Skip to content

Commit

Permalink
Add a client_settings.yaml file to provide client-side settings.
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed May 21, 2024
1 parent a7a8f7f commit f2f7b51
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 97 deletions.
6 changes: 6 additions & 0 deletions examples/monai-2D-mednist/client/client_settings.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Client-side training settings for the MONAI 2D-MedNIST FEDn example.
# Loaded by train.py / validate.py via FEDN_CLIENT_SETTINGS_PATH (or the
# default client_settings.yaml next to the client code).
lr: 0.01           # learning rate reported in the training metadata
batch_size: 30     # DataLoader batch size for training and validation
local_epochs: 1    # number of local epochs per federated round (max_epochs)
num_workers: 1     # DataLoader worker processes
sample_size: 100   # images sampled per class in load_data (null/omitted = use all)

30 changes: 15 additions & 15 deletions examples/monai-2D-mednist/client/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_classes(data_path):
class_names = sorted(x for x in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, x)))
return(class_names)

def load_data(data_path, is_train=True):
def load_data(data_path, sample_size=None, is_train=True):
"""Load data from disk.
:param data_path: Path to data directory.
Expand All @@ -64,17 +64,21 @@ def load_data(data_path, is_train=True):
data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/MedNIST")

class_names = get_classes(data_path)

num_class = len(class_names)

image_files_all = [
[os.path.join(data_path, class_names[i], x) for x in os.listdir(os.path.join(data_path, class_names[i]))]
for i in range(num_class)
]

# To make the dataset small, we are using 100 images of each class.
sample_size = 100
image_files = [random.sample(inner_list, sample_size) for inner_list in image_files_all]
# To make the dataset small, we are using sample_size=100 images of each class.

if sample_size is None:

image_files = image_files_all

else:
image_files = [random.sample(inner_list, sample_size) for inner_list in image_files_all]

num_each = [len(image_files[i]) for i in range(num_class)]
image_files_list = []
Expand All @@ -91,30 +95,26 @@ def load_data(data_path, is_train=True):
print(f"Label counts: {num_each}")

val_frac = 0.1
test_frac = 0.1
#test_frac = 0.1
length = len(image_files_list)
indices = np.arange(length)
np.random.shuffle(indices)

test_split = int(test_frac * length)
val_split = int(val_frac * length) + test_split
test_indices = indices[:test_split]
val_indices = indices[test_split:val_split]
val_split = int(val_frac * length)
val_indices = indices[:val_split]
train_indices = indices[val_split:]

train_x = [image_files_list[i] for i in train_indices]
train_y = [image_class[i] for i in train_indices]
val_x = [image_files_list[i] for i in val_indices]
val_y = [image_class[i] for i in val_indices]
test_x = [image_files_list[i] for i in test_indices]
test_y = [image_class[i] for i in test_indices]

print(f"Training count: {len(train_x)}, Validation count: " f"{len(val_x)}, Test count: {len(test_x)}")
print(f"Training count: {len(train_x)}, Validation count: " f"{len(val_x)}")

if is_train:
return train_x, train_y, val_x, val_y
return train_x, train_y
else:
return test_x, test_y
return val_x, val_y


if __name__ == "__main__":
Expand Down
94 changes: 30 additions & 64 deletions examples/monai-2D-mednist/client/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import os
import sys

import yaml

import torch
from model import load_parameters, save_parameters
from monai.metrics import ROCAUCMetric
from data import load_data, get_classes
from fedn.utils.helpers.helpers import save_metadata

Expand All @@ -28,7 +29,7 @@
import numpy as np
from monai.data import decollate_batch, DataLoader

def pre_training_settings(num_class, batch_size, train_x, train_y, val_x, val_y):
def pre_training_settings(num_class, batch_size, train_x, train_y, num_workers=2):

train_transforms = Compose(
[
Expand All @@ -41,10 +42,6 @@ def pre_training_settings(num_class, batch_size, train_x, train_y, val_x, val_y)
]
)

val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])

y_pred_trans = Compose([Activations(softmax=True)])
y_trans = Compose([AsDiscrete(to_onehot=num_class)])

class MedNISTDataset(torch.utils.data.Dataset):
def __init__(self, image_files, labels, transforms):
Expand All @@ -60,15 +57,12 @@ def __getitem__(self, index):


train_ds = MedNISTDataset(train_x, train_y, train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)

val_ds = MedNISTDataset(val_x, val_y, val_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers )

return train_ds, train_loader, val_ds, val_loader, y_pred_trans, y_trans
return train_ds, train_loader


def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01):
def train(in_model_path, out_model_path, data_path=None, client_settings_path=None):
"""Complete a model update.
Load model paramters from in_model_path (managed by the FEDn client),
Expand All @@ -81,21 +75,31 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
:type out_model_path: str
:param data_path: The path to the data directory.
:type data_path: str
:param batch_size: The batch size to use.
:type batch_size: int
:param epochs: The number of epochs to train.
:type epochs: int
:param lr: The learning rate to use.
:type lr: float
:param client_settings_path: path to a local client settings file.
:type client_settings_path: str
"""
max_epochs = epochs
val_interval = 1
auc_metric = ROCAUCMetric()


if client_settings_path is None:
client_settings_path = os.environ.get("FEDN_CLIENT_SETTINGS_PATH", dir_path + "/client_settings.yaml")

with open(client_settings_path, 'r') as fh: # Used by CJG for local training

try:
client_settings = dict(yaml.safe_load(fh))
except yaml.YAMLError as e:
raise
batch_size = client_settings['batch_size']
max_epochs = client_settings['local_epochs']
num_workers = client_settings['num_workers']
sample_size = client_settings['sample_size']
lr = client_settings['lr']

#val_interval = 1
num_class = len(get_classes(data_path))

# Load data
x_train, y_train, x_val, y_val = load_data(data_path)
train_ds, train_loader, val_ds, val_loader, y_pred_trans, y_trans = pre_training_settings(num_class, batch_size, x_train, y_train, x_val, y_val)
x_train, y_train = load_data(data_path, sample_size)
train_ds, train_loader = pre_training_settings(num_class, batch_size, x_train, y_train, num_workers)


# Load parmeters and initialize model
Expand Down Expand Up @@ -133,52 +137,14 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
y_pred = torch.tensor([], dtype=torch.float32, device=device)
y = torch.tensor([], dtype=torch.long, device=device)
for val_data in val_loader:
val_images, val_labels = (
val_data[0].to(device),
val_data[1].to(device),
)
y_pred = torch.cat([y_pred, model(val_images)], dim=0)
y = torch.cat([y, val_labels], dim=0)
y_onehot = [y_trans(i) for i in decollate_batch(y, detach=False)]
y_pred_act = [y_pred_trans(i) for i in decollate_batch(y_pred)]
auc_metric(y_pred_act, y_onehot)
result = auc_metric.aggregate()
auc_metric.reset()
del y_pred_act, y_onehot
metric_values.append(result)
acc_value = torch.eq(y_pred.argmax(dim=1), y)
acc_metric = acc_value.sum().item() / len(acc_value)
if result > best_metric:
best_metric = result
best_metric_epoch = epoch + 1
torch.save(model.state_dict(), os.path.join("best_metric_model.pth"))
print("saved new best metric model")
print(
f"current epoch: {epoch + 1} current AUC: {result:.4f}"
f" current accuracy: {acc_metric:.4f}"
f" best AUC: {best_metric:.4f}"
f" at epoch: {best_metric_epoch}"
)
#writer.add_scalar("val_accuracy", acc_metric, epoch + 1)

print(f"train completed, best_metric: {best_metric:.4f} " f"at epoch: {best_metric_epoch}")
#writer.close()


#print(f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}")
print(f"train completed!")

# Metadata needed for aggregation server side
metadata = {
# num_examples are mandatory
"num_examples": len(x_train),
"batch_size": batch_size,
"epochs": epochs,
"epochs": max_epochs,
"lr": lr,
}

Expand Down
51 changes: 33 additions & 18 deletions examples/monai-2D-mednist/client/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
from model import load_parameters
import yaml

import torch
from model import load_parameters, save_parameters
Expand Down Expand Up @@ -35,7 +36,7 @@
sys.path.append(os.path.abspath(dir_path))


def pre_validation_settings(num_class,train_x, train_y, test_x, test_y):
def pre_validation_settings(num_class, batch_size, train_x, train_y, val_x, val_y, num_workers=2):

train_transforms = Compose(
[
Expand All @@ -50,9 +51,6 @@ def pre_validation_settings(num_class,train_x, train_y, test_x, test_y):

val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])

y_pred_trans = Compose([Activations(softmax=True)])
y_trans = Compose([AsDiscrete(to_onehot=num_class)])

class MedNISTDataset(torch.utils.data.Dataset):
def __init__(self, image_files, labels, transforms):
self.image_files = image_files
Expand All @@ -66,16 +64,16 @@ def __getitem__(self, index):
return self.transforms(self.image_files[index]), self.labels[index]

train_ds = MedNISTDataset(train_x, train_y, val_transforms)
train_loader = DataLoader(train_ds, batch_size=30, num_workers=1)
train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers)

test_ds = MedNISTDataset(test_x, test_y, val_transforms)
test_loader = DataLoader(test_ds, batch_size=30, num_workers=1)
val_ds = MedNISTDataset(val_x, val_y, val_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers)

return train_ds, train_loader, test_ds, test_loader
return train_ds, train_loader, val_ds, val_loader



def validate(in_model_path, out_json_path, data_path=None):
def validate(in_model_path, out_json_path, data_path=None, client_settings_path=None):
"""Validate model.
:param in_model_path: The path to the input model.
Expand All @@ -84,13 +82,30 @@ def validate(in_model_path, out_json_path, data_path=None):
:type out_json_path: str
:param data_path: The path to the data file.
:type data_path: str
:param client_settings_path: The path to the local client settings file.
:type client_settings_path: str
"""

if client_settings_path is None:
client_settings_path = os.environ.get("FEDN_CLIENT_SETTINGS_PATH", dir_path + "/client_settings.yaml")

with open(client_settings_path, 'r') as fh: # Used by CJG for local training

try:
client_settings = dict(yaml.safe_load(fh))
except yaml.YAMLError as e:
raise

num_workers = client_settings['num_workers']
sample_size = client_settings['sample_size']
batch_size = client_settings['batch_size']

# Load data
x_train, y_train, _, _ = load_data(data_path)
x_test, y_test = load_data(data_path, is_train=False)
x_train, y_train = load_data(data_path, sample_size)
x_val, y_val = load_data(data_path, sample_size, is_train=False)

num_class = len(get_classes(data_path))
train_ds, train_loader, test_ds, test_loader = pre_validation_settings(num_class, x_train, y_train, x_test, y_test)
train_ds, train_loader, val_ds, val_loader = pre_validation_settings(num_class, batch_size, x_train, y_train, x_val, y_val, num_workers)

# Load model
model = load_parameters(in_model_path)
Expand All @@ -101,14 +116,14 @@ def validate(in_model_path, out_json_path, data_path=None):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with torch.no_grad():
for test_data in test_loader:
test_images, test_labels = (
test_data[0].to(device),
test_data[1].to(device),
for val_data in val_loader:
val_images, val_labels = (
val_data[0].to(device),
val_data[1].to(device),
)
pred = model(test_images).argmax(dim=1)
pred = model(val_images).argmax(dim=1)
for i in range(len(pred)):
y_true.append(test_labels[i].item())
y_true.append(val_labels[i].item())
y_pred.append(pred[i].item())


Expand Down
8 changes: 8 additions & 0 deletions examples/monai-2D-mednist/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Dependencies for the MONAI 2D-MedNIST FEDn example client.
# torch/torchvision are pinned to tested versions; monai-weekly pulls in
# pillow and tqdm extras used for image loading and progress reporting.
setuptools
wheel==0.37.1
torch==2.2.1
torchvision==0.17.1
fedn==0.9.0
monai-weekly[pillow, tqdm]
scikit-learn
tensorboard

0 comments on commit f2f7b51

Please sign in to comment.