add script
mattiasakesson committed Jun 3, 2024
1 parent ba822a6 commit bc7bbea
Showing 4 changed files with 129 additions and 76 deletions.
69 changes: 66 additions & 3 deletions examples/monai-2D-mednist/client/data.py
@@ -6,10 +6,48 @@
import torch
import torchvision
from monai.apps import download_and_extract
from torch.utils.data import Dataset as Dataset_

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)

import yaml
import numpy as np
DATA_CLASSES = {'AbdomenCT': 0, 'BreastMRI': 1, 'CXR': 2, 'ChestCT': 3, 'Hand': 4, 'HeadCT': 5}


def split_data(data_path='data/MedNIST', splits=100, validation_split=0.9):
    # Create one empty train/validation split per client.
    clients = {'client ' + str(i): {"train": [], "validation": []} for i in range(splits)}

    for class_ in os.listdir(data_path):
        if os.path.isdir(os.path.join(data_path, class_)):
            patients_in_class = [os.path.join(class_, patient)
                                 for patient in os.listdir(os.path.join(data_path, class_))]
            np.random.shuffle(patients_in_class)
            # Cut the shuffled class into `splits` roughly equal chunks, e.g. for
            # 1000 images and 100 splits: chops = [0, 10, 20, ..., 1000].
            chops = np.int32(np.linspace(0, len(patients_in_class), splits + 1))
            for split in range(splits):
                p = patients_in_class[chops[split]:chops[split + 1]]
                # validation_split is the *training* fraction: the first 90% of each
                # chunk goes to train, the remainder to validation.
                valsplit = np.int32(len(p) * validation_split)
                clients['client ' + str(split)]["train"] += p[:valsplit]
                clients['client ' + str(split)]["validation"] += p[valsplit:]

    with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), 'w') as file:
        yaml.dump(clients, file, default_flow_style=False)
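For reference, get_data() calls this after the download, so each client can later look up its own file lists. The generated data_splits.yaml maps every client key (note the space in 'client 0') to lists of image paths relative to data_path; a rough sketch of its shape, with hypothetical file names:

client 0:
  train:
  - AbdomenCT/001234.jpeg
  - BreastMRI/005678.jpeg
  validation:
  - AbdomenCT/009999.jpeg
client 1:
  train:
  - CXR/002345.jpeg
  validation:
  - HeadCT/003456.jpeg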





def get_data(out_dir="data"):
"""Get data from the external repository.
@@ -37,6 +75,10 @@ def get_data(out_dir="data"):
else:
print('files already exist.')

split_data()



def get_classes(data_path):
"""Get a list of classes from the dataset
@@ -50,6 +92,12 @@
class_names = sorted(x for x in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, x)))
    return class_names

def load_data(data_path, sample_size=None, is_train=True):
"""Load data from disk.
@@ -114,11 +162,26 @@ def load_data(data_path, sample_size=None, is_train=True):
if is_train:
return train_x, train_y
else:
-        return val_x, val_y
+        return val_x, val_y, class_names



class MedNISTDataset(torch.utils.data.Dataset):
    """Map MedNIST image paths (relative to data_path) to (image, class-index) samples."""

    def __init__(self, data_path, image_files, transforms):
self.data_path = data_path
self.image_files = image_files
self.transforms = transforms

def __len__(self):
return len(self.image_files)

    def __getitem__(self, index):
        return (self.transforms(os.path.join(self.data_path, self.image_files[index])),
                DATA_CLASSES[os.path.dirname(self.image_files[index])])
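A minimal usage sketch for the class (the file name is hypothetical; the transforms are the MONAI pipeline this commit uses elsewhere):

from monai.transforms import Compose, EnsureChannelFirst, LoadImage, ScaleIntensity

ds = MedNISTDataset(data_path='data/MedNIST',
                    image_files=['AbdomenCT/001234.jpeg'],
                    transforms=Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()]))
image, label = ds[0]  # label == DATA_CLASSES['AbdomenCT'] == 0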

if __name__ == "__main__":
# Prepare data if not already done
-    if not os.path.exists(abs_path + "/data"):
-        get_data()
+    get_data()
70 changes: 31 additions & 39 deletions examples/monai-2D-mednist/client/train.py
@@ -6,7 +6,7 @@

import torch
from model import load_parameters, save_parameters
-from data import load_data, get_classes
+from data import load_data, get_classes, MedNISTDataset
from fedn.utils.helpers.helpers import save_metadata

dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -29,37 +29,16 @@
import numpy as np
from monai.data import decollate_batch, DataLoader

-def pre_training_settings(num_class, batch_size, train_x, train_y, num_workers=2):
-
-    train_transforms = Compose(
-        [
-            LoadImage(image_only=True),
-            EnsureChannelFirst(),
-            ScaleIntensity(),
-            RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
-            RandFlip(spatial_axis=0, prob=0.5),
-            RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
-        ]
-    )
-
-    class MedNISTDataset(torch.utils.data.Dataset):
-        def __init__(self, image_files, labels, transforms):
-            self.image_files = image_files
-            self.labels = labels
-            self.transforms = transforms
-
-        def __len__(self):
-            return len(self.image_files)
-
-        def __getitem__(self, index):
-            return self.transforms(self.image_files[index]), self.labels[index]
-
-    train_ds = MedNISTDataset(train_x, train_y, train_transforms)
-    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-
-    return train_loader
+train_transforms = Compose(
+    [
+        LoadImage(image_only=True),
+        EnsureChannelFirst(),
+        ScaleIntensity(),
+        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
+        RandFlip(spatial_axis=0, prob=0.5),
+        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5)
+    ]
+)
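The pipeline loads an image from its file path and applies random augmentation on every access, so each epoch sees slightly different training images. A sketch of a single application (file path hypothetical; MedNIST images are 64x64 grayscale, so a channel-first tensor is expected):

img = train_transforms('data/MedNIST/Hand/000042.jpeg')
print(img.shape)  # expected: (1, 64, 64) after EnsureChannelFirst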


def train(in_model_path, out_model_path, data_path=None, client_settings_path=None):
@@ -82,23 +61,36 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
if client_settings_path is None:
client_settings_path = os.environ.get("FEDN_CLIENT_SETTINGS_PATH", dir_path + "/client_settings.yaml")

print("client_settings_path: ", client_settings_path)
with open(client_settings_path, 'r') as fh: # Used by CJG for local training

try:
client_settings = dict(yaml.safe_load(fh))
except yaml.YAMLError as e:
raise


print("client settings: ", client_settings)
batch_size = client_settings['batch_size']
max_epochs = client_settings['local_epochs']
num_workers = client_settings['num_workers']
sample_size = client_settings['sample_size']
split_index = client_settings['split_index']
lr = client_settings['lr']

-    num_class = len(get_classes(data_path))
if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH")


with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), 'r') as file:
clients = yaml.safe_load(file)

image_list = clients['client ' + str(split_index)]['train']

    train_ds = MedNISTDataset(data_path='data/MedNIST', transforms=train_transforms, image_files=image_list)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)

-    # Load data
-    x_train, y_train = load_data(data_path, sample_size)
-    train_loader = pre_training_settings(num_class, batch_size, x_train, y_train, num_workers)
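A quick sanity check for the loader before training (a sketch only; it assumes get_data() has already produced data/MedNIST and data_splits.yaml):

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. torch.Size([256, 1, 64, 64]) torch.Size([256]) with batch_size 256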

    # Load parameters and initialize model
model = load_parameters(in_model_path)
@@ -125,7 +117,7 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print(f"{step}/{len(sample_size) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
print(f"{step}/{len(train_loader) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")

epoch_loss /= step
epoch_loss_values.append(epoch_loss)
@@ -136,7 +128,7 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
# Metadata needed for aggregation server side
metadata = {
# num_examples are mandatory
"num_examples": len(x_train),
"num_examples": len(train_loader),
"batch_size": batch_size,
"epochs": max_epochs,
"lr": lr,
59 changes: 25 additions & 34 deletions examples/monai-2D-mednist/client/validate.py
@@ -7,7 +7,7 @@

import torch
from model import load_parameters, save_parameters
-from data import load_data, get_classes
+from data import load_data, get_classes, MedNISTDataset, DATA_CLASSES
from fedn.utils.helpers.helpers import save_metadata

from monai.data import decollate_batch, DataLoader
@@ -35,30 +35,7 @@
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))


-def pre_validation_settings(batch_size, train_x, train_y, val_x, val_y, num_workers=2):
-
-    val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])
-
-    class MedNISTDataset(torch.utils.data.Dataset):
-        def __init__(self, image_files, labels, transforms):
-            self.image_files = image_files
-            self.labels = labels
-            self.transforms = transforms
-
-        def __len__(self):
-            return len(self.image_files)
-
-        def __getitem__(self, index):
-            return self.transforms(self.image_files[index]), self.labels[index]
-
-    val_ds = MedNISTDataset(val_x, val_y, val_transforms)
-    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers)
-
-    return val_loader
+val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])

def validate(in_model_path, out_json_path, data_path=None, client_settings_path=None):
"""Validate model.
@@ -84,14 +61,21 @@ def validate(in_model_path, out_json_path, data_path=None, client_settings_path=
raise

num_workers = client_settings['num_workers']
-    sample_size = client_settings['sample_size']
    batch_size = client_settings['batch_size']
+    split_index = client_settings['split_index']

if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH")

with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), 'r') as file:
clients = yaml.safe_load(file)

-    # Load data
-    x_train, y_train = load_data(data_path, sample_size)
-    x_val, y_val = load_data(data_path, sample_size, is_train=False)
-    val_loader = pre_validation_settings(batch_size, x_train, y_train, x_val, y_val, num_workers)
+    image_list = clients['client ' + str(split_index)]['validation']
+
+    val_ds = MedNISTDataset(data_path='data/MedNIST', transforms=val_transforms, image_files=image_list)
+
+    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers)

# Load model
model = load_parameters(in_model_path)
@@ -112,14 +96,21 @@ def validate(in_model_path, out_json_path, data_path=None, client_settings_path=
y_true.append(val_labels[i].item())
y_pred.append(pred[i].item())


-    print(classification_report(y_true, y_pred, digits=4))

+    class_names = list(DATA_CLASSES.keys())
+    print("class names: ", class_names)
+    # Pin the label order so target_names always lines up with the label indices.
+    cr = classification_report(y_true, y_pred, digits=4, output_dict=True,
+                               labels=list(DATA_CLASSES.values()), target_names=class_names)
+    # Flatten per-class metrics into top-level keys such as "AbdomenCT_precision".
+    report = {class_name + "_" + metric: cr[class_name][metric]
+              for class_name in cr if isinstance(cr[class_name], dict)
+              for metric in cr[class_name]}
+    # Carry over the scalar entries (e.g. overall accuracy).
+    report.update({k: v for k, v in cr.items() if not isinstance(v, dict)})

# JSON schema
    # JSON schema
-    report = {
+    report.update({
        "test_accuracy": accuracy_score(y_true, y_pred),
        "test_f1_score": f1_score(y_true, y_pred, average='macro')
-    }
+    })
+    for r in report:
+        print(r, ": ", report[r])

# Save JSON
save_metrics(report, out_json_path)
7 changes: 7 additions & 0 deletions examples/monai-2D-mednist/client_settings.yaml
@@ -0,0 +1,7 @@
lr: 0.01
batch_size: 256
local_epochs: 10
num_workers: 1
sample_size: 100
split_index: 0
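These settings are read by train.py and validate.py, which look for the file at FEDN_CLIENT_SETTINGS_PATH and fall back to the client_settings.yaml next to the scripts. A minimal sketch of the lookup (path hypothetical):

import os
import yaml

settings_path = os.environ.get("FEDN_CLIENT_SETTINGS_PATH", "client_settings.yaml")
with open(settings_path) as fh:
    client_settings = yaml.safe_load(fh)
print(client_settings["split_index"])  # 0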
