diff --git a/examples/monai-2D-mednist/client/data.py b/examples/monai-2D-mednist/client/data.py
index 1bacbc1d8..1d1b655b6 100644
--- a/examples/monai-2D-mednist/client/data.py
+++ b/examples/monai-2D-mednist/client/data.py
@@ -6,10 +6,48 @@ import torch
 import torchvision
 from monai.apps import download_and_extract
+import yaml
+import numpy as np
 
 dir_path = os.path.dirname(os.path.realpath(__file__))
 abs_path = os.path.abspath(dir_path)
 
+DATA_CLASSES = {'AbdomenCT': 0, 'BreastMRI': 1, 'CXR': 2, 'ChestCT': 3, 'Hand': 4, 'HeadCT': 5}
+
+
+def split_data(data_path='data/MedNIST', splits=100, validation_split=0.9):
+    """Partition each class folder into per-client train/validation file lists and write data_splits.yaml."""
+    clients = {'client ' + str(i): {"train": [], "validation": []} for i in range(splits)}
+
+    for class_ in os.listdir(data_path):
+        if os.path.isdir(os.path.join(data_path, class_)):
+            patients_in_class = [os.path.join(class_, patient) for patient in
+                                 os.listdir(os.path.join(data_path, class_))]
+            np.random.shuffle(patients_in_class)
+            chops = np.int32(np.linspace(0, len(patients_in_class), splits + 1))
+            for split in range(splits):
+                p = patients_in_class[chops[split]:chops[split + 1]]
+                valsplit = np.int32(len(p) * validation_split)
+                clients['client ' + str(split)]["train"] += p[:valsplit]
+                clients['client ' + str(split)]["validation"] += p[valsplit:]
+
+    with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), 'w') as file:
+        yaml.dump(clients, file, default_flow_style=False)
+
+
 def get_data(out_dir="data"):
     """Get data from the external repository.
@@ -37,6 +75,10 @@ def get_data(out_dir="data"):
     else:
         print('files already exist.')
+
+    split_data()
 
 
 def get_classes(data_path):
     """Get a list of classes from the dataset
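For reference, `split_data()` above writes `data_splits.yaml` next to the data directory, mapping each client to its train and validation file lists (paths are relative to `data/MedNIST`). A minimal sketch of reading that file back, assuming the default `data/` layout and a hypothetical client index 0:

```python
# Illustrative only: inspect the split produced by split_data() (defaults: 100 clients, 90/10 split).
import yaml

with open("data/data_splits.yaml", "r") as fh:
    clients = yaml.safe_load(fh)

split = clients["client 0"]  # hypothetical client index
print(len(split["train"]), "training files")
print(len(split["validation"]), "validation files")
print(split["train"][:3])  # relative paths such as 'AbdomenCT/<image>.jpeg'
```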
@@ -114,11 +162,26 @@ def load_data(data_path, sample_size=None, is_train=True):
     if is_train:
         return train_x, train_y
     else:
-        return val_x, val_y
+        return val_x, val_y, class_names
+
+
+class MedNISTDataset(torch.utils.data.Dataset):
+    """Dataset over MedNIST image paths; the label is derived from the class directory name via DATA_CLASSES."""
+
+    def __init__(self, data_path, image_files, transforms):
+        self.data_path = data_path
+        self.image_files = image_files
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.image_files)
+
+    def __getitem__(self, index):
+        return (self.transforms(os.path.join(self.data_path, self.image_files[index])),
+                DATA_CLASSES[os.path.dirname(self.image_files[index])])
 
 
 if __name__ == "__main__":
     # Prepare data if not already done
-    if not os.path.exists(abs_path + "/data"):
-        get_data()
+    get_data()
     #load_data('./data')
diff --git a/examples/monai-2D-mednist/client/train.py b/examples/monai-2D-mednist/client/train.py
index 6003fd4b1..a3ffac004 100644
--- a/examples/monai-2D-mednist/client/train.py
+++ b/examples/monai-2D-mednist/client/train.py
@@ -6,7 +6,7 @@ import torch
 
 from model import load_parameters, save_parameters
 
-from data import load_data, get_classes
+from data import load_data, get_classes, MedNISTDataset
 from fedn.utils.helpers.helpers import save_metadata
 
 dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -29,37 +29,16 @@ import numpy as np
 
 from monai.data import decollate_batch, DataLoader
 
-def pre_training_settings(num_class, batch_size, train_x, train_y, num_workers=2):
-
-    train_transforms = Compose(
-        [
-            LoadImage(image_only=True),
-            EnsureChannelFirst(),
-            ScaleIntensity(),
-            RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
-            RandFlip(spatial_axis=0, prob=0.5),
-            RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
-        ]
-    )
-
-
-    class MedNISTDataset(torch.utils.data.Dataset):
-        def __init__(self, image_files, labels, transforms):
-            self.image_files = image_files
-            self.labels = labels
-            self.transforms = transforms
-
-        def __len__(self):
-            return len(self.image_files)
-
-        def __getitem__(self, index):
-            return self.transforms(self.image_files[index]), self.labels[index]
-
-
-    train_ds = MedNISTDataset(train_x, train_y, train_transforms)
-    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers )
-
-    return train_loader
+train_transforms = Compose(
+    [
+        LoadImage(image_only=True),
+        EnsureChannelFirst(),
+        ScaleIntensity(),
+        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
+        RandFlip(spatial_axis=0, prob=0.5),
+        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5)
+    ]
+)
 
 
 def train(in_model_path, out_model_path, data_path=None, client_settings_path=None):
@@ -82,23 +61,36 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
     if client_settings_path is None:
         client_settings_path = os.environ.get("FEDN_CLIENT_SETTINGS_PATH", dir_path + "/client_settings.yaml")
 
     with open(client_settings_path, 'r') as fh: # Used by CJG for local training
         try:
             client_settings = dict(yaml.safe_load(fh))
         except yaml.YAMLError as e:
             raise
 
     batch_size = client_settings['batch_size']
     max_epochs = client_settings['local_epochs']
     num_workers = client_settings['num_workers']
-    sample_size = client_settings['sample_size']
+    split_index = client_settings['split_index']
     lr = client_settings['lr']
 
-    num_class = len(get_classes(data_path))
+    if data_path is None:
+        data_path = os.environ.get("FEDN_DATA_PATH")
+
+    with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), 'r') as file:
+        clients = yaml.safe_load(file)
+
+    image_list = clients['client ' + str(split_index)]['train']
+
+    train_ds = MedNISTDataset(data_path=data_path, transforms=train_transforms,
+                              image_files=image_list)
+
+    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
 
-    # Load data
-    x_train, y_train = load_data(data_path, sample_size)
-    train_loader = pre_training_settings(num_class, batch_size, x_train, y_train, num_workers)
 
     # Load parmeters and initialize model
     model = load_parameters(in_model_path)
@@ -125,7 +117,7 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
             loss.backward()
             optimizer.step()
             epoch_loss += loss.item()
-            print(f"{step}/{len(sample_size) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
+            print(f"{step}/{len(train_loader)}, " f"train_loss: {loss.item():.4f}")
 
         epoch_loss /= step
         epoch_loss_values.append(epoch_loss)
@@ -136,7 +128,7 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
     # Metadata needed for aggregation server side
     metadata = {
         # num_examples are mandatory
-        "num_examples": len(x_train),
+        "num_examples": len(train_ds),
         "batch_size": batch_size,
         "epochs": max_epochs,
         "lr": lr,
diff --git a/examples/monai-2D-mednist/client/validate.py b/examples/monai-2D-mednist/client/validate.py
index 22cae61dd..2ad6d23ad 100644
--- a/examples/monai-2D-mednist/client/validate.py
+++ b/examples/monai-2D-mednist/client/validate.py
@@ -7,7 +7,7 @@ import torch
 
 from model import load_parameters, save_parameters
 
-from data import load_data, get_classes
+from data import load_data, get_classes, MedNISTDataset, DATA_CLASSES
 from fedn.utils.helpers.helpers import save_metadata
 
 from monai.data import decollate_batch, DataLoader
@@ -35,30 +35,7 @@ dir_path = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(os.path.abspath(dir_path))
 
-
-def pre_validation_settings(batch_size, train_x, train_y, val_x, val_y, num_workers=2):
-
-    val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])
-
-    class MedNISTDataset(torch.utils.data.Dataset):
-        def __init__(self, image_files, labels, transforms):
-            self.image_files = image_files
-            self.labels = labels
-            self.transforms = transforms
-
-        def __len__(self):
-            return len(self.image_files)
-
-        def __getitem__(self, index):
-            return self.transforms(self.image_files[index]), self.labels[index]
-
-
-    val_ds = MedNISTDataset(val_x, val_y, val_transforms)
-    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers)
-
-    return val_loader
-
-
+val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])
 
 def validate(in_model_path, out_json_path, data_path=None, client_settings_path=None):
     """Validate model.
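For reference, the module-level `val_transforms` above is meant to be combined with the shared `MedNISTDataset` from `data.py`. A minimal usage sketch, assuming the script is run from the client directory with the default `data/` layout; the client index and batch size here are illustrative:

```python
# Usage sketch only: build a validation loader for one client's split.
import yaml
from monai.data import DataLoader
from monai.transforms import Compose, EnsureChannelFirst, LoadImage, ScaleIntensity

from data import MedNISTDataset

val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])

with open("data/data_splits.yaml", "r") as fh:
    clients = yaml.safe_load(fh)

val_ds = MedNISTDataset(data_path="data/MedNIST",
                        image_files=clients["client 0"]["validation"],
                        transforms=val_transforms)
val_loader = DataLoader(val_ds, batch_size=32, num_workers=1)

images, labels = next(iter(val_loader))  # labels are the integer codes from DATA_CLASSES
print(images.shape, labels[:8])
```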
@@ -84,14 +61,21 @@ def validate(in_model_path, out_json_path, data_path=None, client_settings_path=
             raise
 
     num_workers = client_settings['num_workers']
-    sample_size = client_settings['sample_size']
     batch_size = client_settings['batch_size']
+    split_index = client_settings['split_index']
+
+    if data_path is None:
+        data_path = os.environ.get("FEDN_DATA_PATH")
+
+    with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), 'r') as file:
+        clients = yaml.safe_load(file)
 
-    # Load data
-    x_train, y_train = load_data(data_path, sample_size)
-    x_val, y_val = load_data(data_path, sample_size, is_train=False)
+    image_list = clients['client ' + str(split_index)]['validation']
 
-    val_loader = pre_validation_settings(batch_size, x_train, y_train, x_val, y_val, num_workers)
+    val_ds = MedNISTDataset(data_path=data_path, transforms=val_transforms,
+                            image_files=image_list)
+
+    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers)
 
     # Load model
     model = load_parameters(in_model_path)
@@ -112,14 +96,21 @@ def validate(in_model_path, out_json_path, data_path=None, client_settings_path=
                 y_true.append(val_labels[i].item())
                 y_pred.append(pred[i].item())
 
-
-    print(classification_report(y_true, y_pred, digits=4))
+    class_names = list(DATA_CLASSES.keys())
+    cr = classification_report(y_true, y_pred, digits=4, output_dict=True, target_names=class_names)
+
+    # Flatten per-class metrics into "<class>_<metric>" keys; keep scalar entries (e.g. accuracy) as-is
+    report = {class_name + "_" + metric: value
+              for class_name, metrics in cr.items() if isinstance(metrics, dict)
+              for metric, value in metrics.items()}
+    report.update({key: value for key, value in cr.items() if not isinstance(value, dict)})
 
     # JSON schema
-    report = {
+    report.update({
         "test_accuracy": accuracy_score(y_true, y_pred),
         "test_f1_score": f1_score(y_true, y_pred, average='macro')
-    }
+    })
+
+    for r in report:
+        print(r, ": ", report[r])
 
     # Save JSON
     save_metrics(report, out_json_path)
diff --git a/examples/monai-2D-mednist/client_settings.yaml b/examples/monai-2D-mednist/client_settings.yaml
new file mode 100644
index 000000000..d9eb1a081
--- /dev/null
+++ b/examples/monai-2D-mednist/client_settings.yaml
@@ -0,0 +1,7 @@
+lr: 0.01
+batch_size: 256
+local_epochs: 10
+num_workers: 1
+sample_size: 100
+split_index: 0
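For reference, the flattening in `validate.py` above turns scikit-learn's nested `classification_report` dictionary into flat `<class>_<metric>` keys with scalar values. A standalone sketch with dummy labels (not MedNIST data):

```python
# Standalone sketch of the metric flattening; the label values here are made up.
from sklearn.metrics import accuracy_score, classification_report, f1_score

class_names = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]
y_true = [0, 1, 2, 3, 4, 5, 0, 1]
y_pred = [0, 1, 2, 3, 4, 5, 1, 1]

cr = classification_report(y_true, y_pred, output_dict=True, target_names=class_names)

report = {name + "_" + metric: value
          for name, metrics in cr.items() if isinstance(metrics, dict)
          for metric, value in metrics.items()}
report.update({key: value for key, value in cr.items() if not isinstance(value, dict)})
report.update({"test_accuracy": accuracy_score(y_true, y_pred),
               "test_f1_score": f1_score(y_true, y_pred, average="macro")})

print(sorted(report))  # keys like 'AbdomenCT_precision', 'macro avg_f1-score', 'test_accuracy', ...
```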