diff --git a/examples/FedSimSiam/.dockerignore b/examples/FedSimSiam/.dockerignore
new file mode 100644
index 000000000..8ba9024ad
--- /dev/null
+++ b/examples/FedSimSiam/.dockerignore
@@ -0,0 +1,4 @@
+data
+seed.npz
+*.tgz
+*.tar.gz
\ No newline at end of file
diff --git a/examples/FedSimSiam/.gitignore b/examples/FedSimSiam/.gitignore
new file mode 100644
index 000000000..047341d71
--- /dev/null
+++ b/examples/FedSimSiam/.gitignore
@@ -0,0 +1,6 @@
+data
+*.npz
+*.tgz
+*.tar.gz
+.fedsimsiam
+client.yaml
\ No newline at end of file
diff --git a/examples/FedSimSiam/client/data.py b/examples/FedSimSiam/client/data.py
new file mode 100644
index 000000000..f18ea1f32
--- /dev/null
+++ b/examples/FedSimSiam/client/data.py
@@ -0,0 +1,151 @@
+import os
+from math import floor
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+abs_path = os.path.abspath(dir_path)
+
+
+def get_data(out_dir='data'):
+    # Make dir if necessary
+    if not os.path.exists(out_dir):
+        os.mkdir(out_dir)
+
+    # Only download if not already downloaded
+    if not os.path.exists(f'{out_dir}/train'):
+        torchvision.datasets.CIFAR10(
+            root=f'{out_dir}/train', train=True, download=True)
+
+    if not os.path.exists(f'{out_dir}/test'):
+        torchvision.datasets.CIFAR10(
+            root=f'{out_dir}/test', train=False, download=True)
+
+
+def load_data(data_path, is_train=True):
+    """ Load data from disk.
+
+    :param data_path: Path to data file.
+    :type data_path: str
+    :param is_train: Whether to load training or test data.
+    :type is_train: bool
+    :return: Tuple of data and labels.
+    :rtype: tuple
+    """
+    if data_path is None:
+        data_path = os.environ.get(
+            "FEDN_DATA_PATH", abs_path+'/data/clients/1/cifar10.pt')
+
+    data = torch.load(data_path)
+
+    if is_train:
+        X = data['x_train']
+        y = data['y_train']
+    else:
+        X = data['x_test']
+        y = data['y_test']
+
+    return X, y
+
+
+def create_knn_monitoring_dataset(out_dir='data'):
+    """ Creates the dataset that is used to monitor training progress via kNN accuracy. """
+    if not os.path.exists(out_dir):
+        os.mkdir(out_dir)
+
+    n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))
+
+    # Make dir
+    if not os.path.exists(f'{out_dir}/clients'):
+        os.mkdir(f'{out_dir}/clients')
+
+    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
+                                     std=[0.247, 0.243, 0.261])
+
+    memoryset = torchvision.datasets.CIFAR10(root=out_dir, train=True,
+                                             download=True, transform=transforms.Compose([transforms.ToTensor(), normalize]))
+    testset = torchvision.datasets.CIFAR10(root=out_dir, train=False,
+                                           download=True, transform=transforms.Compose([transforms.ToTensor(), normalize]))
+
+    # Save monitoring datasets to all clients
+    for i in range(n_splits):
+        subdir = f'{out_dir}/clients/{str(i+1)}'
+        if not os.path.exists(subdir):
+            os.mkdir(subdir)
+        torch.save(memoryset, f'{subdir}/knn_memoryset.pt')
+        torch.save(testset, f'{subdir}/knn_testset.pt')
+
+
+def load_knn_monitoring_dataset(data_path, batch_size=16):
+    """ Loads the kNN monitoring dataset. """
+    if data_path is None:
+        data_path = os.environ.get(
+            "FEDN_DATA_PATH", abs_path+'/data/clients/1/cifar10.pt')
+
+    data_directory = os.path.dirname(data_path)
+    memory_path = os.path.join(data_directory, 'knn_memoryset.pt')
+    testset_path = os.path.join(data_directory, 'knn_testset.pt')
+
+    memoryset = torch.load(memory_path)
+    testset = torch.load(testset_path)
+
+    memoryset_loader = torch.utils.data.DataLoader(
+        memoryset, batch_size=batch_size, shuffle=False)
+    testset_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
+                                                 shuffle=False)
+    return memoryset_loader, testset_loader
+
+
+def splitset(dataset, parts):
+    n = dataset.shape[0]
+    local_n = floor(n/parts)
+    result = []
+    for i in range(parts):
+        result.append(dataset[i*local_n: (i+1)*local_n])
+    return result
+
+
+def split(out_dir='data'):
+
+    n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))
+
+    # Make dir
+    if not os.path.exists(f'{out_dir}/clients'):
+        os.mkdir(f'{out_dir}/clients')
+
+    train_data = torchvision.datasets.CIFAR10(
+        root=f'{out_dir}/train', train=True)
+    test_data = torchvision.datasets.CIFAR10(
+        root=f'{out_dir}/test', train=False)
+
+    data = {
+        'x_train': splitset(train_data.data, n_splits),
+        'y_train': splitset(np.array(train_data.targets), n_splits),
+        'x_test': splitset(test_data.data, n_splits),
+        'y_test': splitset(np.array(test_data.targets), n_splits),
+    }
+
+    # Make splits
+    for i in range(n_splits):
+        subdir = f'{out_dir}/clients/{str(i+1)}'
+        if not os.path.exists(subdir):
+            os.mkdir(subdir)
+        torch.save({
+            'x_train': data['x_train'][i],
+            'y_train': data['y_train'][i],
+            'x_test': data['x_test'][i],
+            'y_test': data['y_test'][i],
+        },
+            f'{subdir}/cifar10.pt')
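+
+
+# Expected on-disk layout after running this script (illustrative, with the
+# default FEDN_NUM_DATA_SPLITS=2):
+#   data/clients/1/cifar10.pt        local train/test split for client 1
+#   data/clients/1/knn_memoryset.pt  kNN monitoring memory bank
+#   data/clients/1/knn_testset.pt    kNN monitoring test set
+#   data/clients/2/...               the same files for client 2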
+
+
+if __name__ == '__main__':
+    # Prepare data if not already done
+    if not os.path.exists(abs_path+'/data/clients/1'):
+        get_data()
+        split()
+        create_knn_monitoring_dataset()
diff --git a/examples/FedSimSiam/client/fedn.yaml b/examples/FedSimSiam/client/fedn.yaml
new file mode 100644
index 000000000..b05504102
--- /dev/null
+++ b/examples/FedSimSiam/client/fedn.yaml
@@ -0,0 +1,10 @@
+python_env: python_env.yaml
+entry_points:
+  build:
+    command: python model.py
+  startup:
+    command: python data.py
+  train:
+    command: python train.py
+  validate:
+    command: python validate.py
\ No newline at end of file
diff --git a/examples/FedSimSiam/client/model.py b/examples/FedSimSiam/client/model.py
new file mode 100644
index 000000000..7ba59b426
--- /dev/null
+++ b/examples/FedSimSiam/client/model.py
@@ -0,0 +1,144 @@
+import collections
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.models import resnet18
+
+from fedn.utils.helpers.helpers import get_helper
+
+HELPER_MODULE = 'numpyhelper'
+helper = get_helper(HELPER_MODULE)
+
+
+def D(p, z, version='simplified'):  # negative cosine similarity
+    if version == 'original':
+        z = z.detach()  # stop gradient
+        p = F.normalize(p, dim=1)  # l2-normalize
+        z = F.normalize(z, dim=1)  # l2-normalize
+        return -(p*z).sum(dim=1).mean()
+
+    elif version == 'simplified':  # same as 'original', but faster
+        return - F.cosine_similarity(p, z.detach(), dim=-1).mean()
+    else:
+        raise ValueError(f"Unknown version: {version}")
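+
+# The SimSiam objective used below is L = D(p1, z2)/2 + D(p2, z1)/2; the
+# stop-gradient on z is what prevents representational collapse.
+# Illustrative equivalence check of the two versions (not part of the
+# training code; shapes are arbitrary):
+#   p, z = torch.randn(8, 2048), torch.randn(8, 2048)
+#   assert torch.allclose(D(p, z, 'original'), D(p, z, 'simplified'), atol=1e-6)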
+
+
+class ProjectionMLP(nn.Module):
+    """Projection MLP f"""
+
+    def __init__(self, in_features, h1_features, h2_features, out_features):
+        super(ProjectionMLP, self).__init__()
+        # NOTE: this is the two-layer CIFAR variant of the projector, so
+        # h2_features is accepted for API compatibility but not used.
+        self.l1 = nn.Sequential(
+            nn.Linear(in_features, h1_features),
+            nn.BatchNorm1d(h1_features),
+            nn.ReLU(inplace=True)
+        )
+        self.l2 = nn.Sequential(
+            nn.Linear(h1_features, out_features),
+            nn.BatchNorm1d(out_features)
+        )
+
+    def forward(self, x):
+        x = self.l1(x)
+        x = self.l2(x)
+        return x
+
+
+class PredictionMLP(nn.Module):
+    """Prediction MLP h"""
+
+    def __init__(self, in_features, hidden_features, out_features):
+        super(PredictionMLP, self).__init__()
+        self.l1 = nn.Sequential(
+            nn.Linear(in_features, hidden_features),
+            nn.BatchNorm1d(hidden_features),
+            nn.ReLU(inplace=True)
+        )
+        self.l2 = nn.Linear(hidden_features, out_features)
+
+    def forward(self, x):
+        x = self.l1(x)
+        x = self.l2(x)
+        return x
+
+
+class SimSiam(nn.Module):
+    def __init__(self):
+        super(SimSiam, self).__init__()
+        backbone = resnet18(pretrained=False)
+        backbone.output_dim = backbone.fc.in_features
+        backbone.fc = torch.nn.Identity()
+
+        self.backbone = backbone
+
+        self.projector = ProjectionMLP(backbone.output_dim, 2048, 2048, 2048)
+        self.encoder = nn.Sequential(
+            self.backbone,
+            self.projector
+        )
+        self.predictor = PredictionMLP(2048, 512, 2048)
+
+    def forward(self, x1, x2):
+        f, h = self.encoder, self.predictor
+        z1, z2 = f(x1), f(x2)
+        p1, p2 = h(z1), h(z2)
+        L = D(p1, z2) / 2 + D(p2, z1) / 2
+        return {'loss': L}
+
+
+def compile_model():
+    """ Compile the pytorch model.
+
+    :return: The compiled model.
+    :rtype: torch.nn.Module
+    """
+    model = SimSiam()
+
+    return model
+
+
+def save_parameters(model, out_path):
+    """ Save model parameters to file.
+
+    :param model: The model to serialize.
+    :type model: torch.nn.Module
+    :param out_path: The path to save to.
+    :type out_path: str
+    """
+    parameters_np = [val.cpu().numpy()
+                     for _, val in model.state_dict().items()]
+    helper.save(parameters_np, out_path)
+
+
+def load_parameters(model_path):
+    """ Load model parameters from file and populate model.
+
+    :param model_path: The path to load from.
+    :type model_path: str
+    :return: The loaded model.
+    :rtype: torch.nn.Module
+    """
+    model = compile_model()
+    parameters_np = helper.load(model_path)
+
+    params_dict = zip(model.state_dict().keys(), parameters_np)
+    state_dict = collections.OrderedDict(
+        {key: torch.tensor(x) for key, x in params_dict})
+    model.load_state_dict(state_dict, strict=True)
+    return model
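+
+# The numpyhelper serializes parameters as an ordered list of numpy arrays,
+# so save/load rely on the state_dict key order being stable between calls.
+# Illustrative round trip (the path is hypothetical):
+#   save_parameters(compile_model(), '/tmp/model.npz')
+#   model = load_parameters('/tmp/model.npz')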
+
+
+def init_seed(out_path='seed.npz'):
+    """ Initialize seed model and save it to file.
+
+    :param out_path: The path to save the seed model to.
+    :type out_path: str
+    """
+    # Init and save
+    model = compile_model()
+    save_parameters(model, out_path)
+
+
+if __name__ == "__main__":
+    init_seed('../seed.npz')
diff --git a/examples/FedSimSiam/client/monitoring.py b/examples/FedSimSiam/client/monitoring.py
new file mode 100644
index 000000000..60da3e12a
--- /dev/null
+++ b/examples/FedSimSiam/client/monitoring.py
@@ -0,0 +1,63 @@
+""" kNN monitor as in InstDisc https://arxiv.org/abs/1805.01978.
+    Implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR.
+"""
+import torch
+import torch.nn.functional as F
+
+
+def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    net = net.to(device)
+    net.eval()
+    classes = len(memory_data_loader.dataset.classes)
+    total_top1, total_num, feature_bank = 0.0, 0, []
+    with torch.no_grad():
+        # generate feature bank
+        for data, target in memory_data_loader:
+            feature = net(data.to(device))
+            feature = F.normalize(feature, dim=1)
+            feature_bank.append(feature)
+        # [D, N]
+        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
+        # [N]
+        feature_labels = torch.tensor(
+            memory_data_loader.dataset.targets, device=feature_bank.device)
+        # loop over the test data to predict labels by weighted knn search
+        for data, target in test_data_loader:
+            data, target = data.to(device), target.to(device)
+            feature = net(data)
+            feature = F.normalize(feature, dim=1)
+
+            pred_labels = knn_predict(
+                feature, feature_bank, feature_labels, classes, k, t)
+
+            total_num += data.size(0)
+            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
+    return total_top1 / total_num  # top-1 accuracy in [0, 1]
+
+
+def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
+    # compute cos similarity between each feature vector and the feature bank ---> [B, N]
+    sim_matrix = torch.mm(feature, feature_bank)
+    # [B, K]
+    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
+    # [B, K]
+    sim_labels = torch.gather(feature_labels.expand(
+        feature.size(0), -1), dim=-1, index=sim_indices)
+    sim_weight = (sim_weight / knn_t).exp()
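+    # The temperature knn_t sharpens the similarity weights before the class
+    # vote below: the smaller knn_t, the larger the say of the nearest
+    # neighbors. Shape walk-through (B=batch, N=bank size, K=knn_k, C=classes):
+    #   sim_matrix [B, N] -> sim_weight, sim_indices [B, K] -> pred_scores [B, C]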
+
+    # counts for each class
+    one_hot_label = torch.zeros(feature.size(
+        0) * knn_k, classes, device=sim_labels.device)
+    # [B*K, C]
+    one_hot_label = one_hot_label.scatter(
+        dim=-1, index=sim_labels.view(-1, 1), value=1.0)
+    # weighted score ---> [B, C]
+    pred_scores = torch.sum(one_hot_label.view(feature.size(
+        0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
+
+    pred_labels = pred_scores.argsort(dim=-1, descending=True)
+    return pred_labels
diff --git a/examples/FedSimSiam/client/python_env.yaml b/examples/FedSimSiam/client/python_env.yaml
new file mode 100644
index 000000000..e898a34da
--- /dev/null
+++ b/examples/FedSimSiam/client/python_env.yaml
@@ -0,0 +1,9 @@
+name: fedsimsiam
+build_dependencies:
+  - pip
+  - setuptools
+  - wheel==0.37.1
+dependencies:
+  - torch==2.2.1
+  - torchvision==0.17.1
+  - fedn==0.9.0
\ No newline at end of file
diff --git a/examples/FedSimSiam/client/train.py b/examples/FedSimSiam/client/train.py
new file mode 100644
index 000000000..cc41ea9ed
--- /dev/null
+++ b/examples/FedSimSiam/client/train.py
@@ -0,0 +1,138 @@
+import os
+import sys
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+from PIL import Image
+
+from data import load_data
+from model import load_parameters, save_parameters
+from utils import init_lrscheduler
+
+from fedn.utils.helpers.helpers import save_metadata
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.abspath(dir_path))
+
+
+class SimSiamDataset(Dataset):
+    def __init__(self, x, y, is_train=True):
+        self.x = x
+        self.y = y
+        self.is_train = is_train
+
+        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
+                                         std=[0.247, 0.243, 0.261])
+        augmentation = [
+            transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
+            transforms.RandomApply([
+                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
+            ], p=0.8),
+            transforms.RandomGrayscale(p=0.2),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize
+        ]
+        # Build the transforms once here instead of on every __getitem__ call
+        self.train_transform = transforms.Compose(augmentation)
+        self.test_transform = transforms.Compose(
+            [transforms.ToTensor(), normalize])
+
+    def __getitem__(self, idx):
+        x = self.x[idx]
+        x = Image.fromarray(x.astype(np.uint8))
+        y = self.y[idx]
+
+        if self.is_train:
+            x1 = self.train_transform(x)
+            x2 = self.train_transform(x)
+            return [x1, x2], y
+
+        x = self.test_transform(x)
+        return x, y
+
+    def __len__(self):
+        return len(self.x)
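+
+# Each __getitem__ call in train mode returns two independently augmented
+# views of the same image (illustrative shapes for CIFAR-10):
+#   (x1, x2), y = SimSiamDataset(x_train, y_train)[0]
+#   # x1, x2: float tensors of shape [3, 32, 32]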
+
+
+def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01):
+    """ Complete a model update.
+
+    Load model parameters from in_model_path (managed by the FEDn client),
+    perform a model update, and write updated parameters
+    to out_model_path (picked up by the FEDn client).
+
+    :param in_model_path: The path to the input model.
+    :type in_model_path: str
+    :param out_model_path: The path to save the output model to.
+    :type out_model_path: str
+    :param data_path: The path to the data file.
+    :type data_path: str
+    :param batch_size: The batch size to use.
+    :type batch_size: int
+    :param epochs: The number of epochs to train.
+    :type epochs: int
+    :param lr: The learning rate to use.
+    :type lr: float
+    """
+    # Load data
+    x_train, y_train = load_data(data_path)
+
+    # Load parameters and initialize model
+    model = load_parameters(in_model_path)
+
+    trainset = SimSiamDataset(x_train, y_train, is_train=True)
+    trainloader = DataLoader(
+        trainset, batch_size=batch_size, shuffle=True)
+
+    device = torch.device(
+        'cuda') if torch.cuda.is_available() else torch.device('cpu')
+    model = model.to(device)
+    model.train()
+
+    # The SGD optimizer and cosine lr schedule come from init_lrscheduler;
+    # the lr argument above is only recorded in the metadata.
+    optimizer, lr_scheduler = init_lrscheduler(
+        model, 500, trainloader)  # TODO: Change num epochs
+
+    print("starting training with lr ", optimizer.param_groups[0]['lr'])
+    for epoch in range(epochs):
+        for idx, data in enumerate(trainloader):
+            images = data[0]
+            optimizer.zero_grad()
+            data_dict = model(images[0].to(
+                device, non_blocking=True), images[1].to(device, non_blocking=True))
+            loss = data_dict['loss'].mean()
+            print(f"epoch {epoch}, batch {idx}: loss {loss.item():.4f}")
+            loss.backward()
+            optimizer.step()
+            lr_scheduler.step()
+
+    print('last learning rate: ', optimizer.param_groups[0]['lr'])
+
+    # Metadata needed for aggregation server side
+    metadata = {
+        # num_examples are mandatory
+        'num_examples': len(x_train),
+        'batch_size': batch_size,
+        'epochs': epochs,
+        'lr': lr
+    }
+
+    # Save JSON metadata file (mandatory)
+    save_metadata(metadata, out_model_path)
+
+    # Save model update (mandatory)
+    save_parameters(model, out_model_path)
+
+
+if __name__ == "__main__":
+    train(sys.argv[1], sys.argv[2])
diff --git a/examples/FedSimSiam/client/utils.py b/examples/FedSimSiam/client/utils.py
new file mode 100644
index 000000000..bfb8be88c
--- /dev/null
+++ b/examples/FedSimSiam/client/utils.py
@@ -0,0 +1,78 @@
+import numpy as np
+import torch
+
+
+class LR_Scheduler(object):
+    def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
+        self.base_lr = base_lr
+        self.constant_predictor_lr = constant_predictor_lr
+        warmup_iter = iter_per_epoch * warmup_epochs
+        warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
+        decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
+        cosine_lr_schedule = final_lr+0.5 * \
+            (base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter))
+
+        self.lr_schedule = np.concatenate(
+            (warmup_lr_schedule, cosine_lr_schedule))
+        self.optimizer = optimizer
+        self.iter = 0
+        self.current_lr = 0
+
+    def step(self):
+        for param_group in self.optimizer.param_groups:
+
+            if self.constant_predictor_lr and param_group['name'] == 'predictor':
+                param_group['lr'] = self.base_lr
+            else:
+                lr = param_group['lr'] = self.lr_schedule[self.iter]
+
+        self.iter += 1
+        self.current_lr = lr
+        return lr
+
+    def get_lr(self):
+        return self.current_lr
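+
+# The resulting schedule ramps linearly from warmup_lr to base_lr over the
+# warmup epochs, then follows a half-cosine decay to final_lr:
+#   lr(t) = final_lr + 0.5 * (base_lr - final_lr) * (1 + cos(pi * t / T))
+# where t is the post-warmup iteration and T = iter_per_epoch * (num_epochs - warmup_epochs).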
+
+
+def get_optimizer(name, model, lr, momentum, weight_decay):
+
+    predictor_prefix = ('module.predictor', 'predictor')
+    parameters = [{
+        'name': 'base',
+        'params': [param for name, param in model.named_parameters() if not name.startswith(predictor_prefix)],
+        'lr': lr
+    }, {
+        'name': 'predictor',
+        'params': [param for name, param in model.named_parameters() if name.startswith(predictor_prefix)],
+        'lr': lr
+    }]
+
+    if name == 'sgd':
+        optimizer = torch.optim.SGD(
+            parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
+    else:
+        raise ValueError(f"Unknown optimizer: {name}")
+
+    return optimizer
+
+
+def init_lrscheduler(model, total_epochs, dataloader):
+    warmup_epochs = 10
+    warmup_lr = 0
+    base_lr = 0.03
+    final_lr = 0
+    momentum = 0.9
+    weight_decay = 0.0005
+    batch_size = 64
+
+    # Learning rates are scaled by batch_size/256 (linear scaling rule)
+    optimizer = get_optimizer(
+        'sgd', model,
+        lr=base_lr*batch_size/256,
+        momentum=momentum,
+        weight_decay=weight_decay)
+
+    lr_scheduler = LR_Scheduler(
+        optimizer, warmup_epochs, warmup_lr*batch_size/256,
+        total_epochs, base_lr*batch_size/256, final_lr*batch_size/256,
+        len(dataloader),
+        constant_predictor_lr=True
+    )
+    return optimizer, lr_scheduler
diff --git a/examples/FedSimSiam/client/validate.py b/examples/FedSimSiam/client/validate.py
new file mode 100644
index 000000000..fa329bf9c
--- /dev/null
+++ b/examples/FedSimSiam/client/validate.py
@@ -0,0 +1,167 @@
+import os
+import sys
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from torch import nn
+from torch.utils.data import DataLoader, Dataset
+from PIL import Image
+
+from data import load_data, load_knn_monitoring_dataset
+from model import load_parameters
+from monitoring import knn_monitor
+from fedn.utils.helpers.helpers import save_metrics
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.abspath(dir_path))
+
+
+class Cifar10(Dataset):
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],  # approx. CIFAR-10 channel means
+                                 std=[0.247, 0.243, 0.261])  # approx. CIFAR-10 channel std devs
+        ])
+
+    def __len__(self):
+        return len(self.x)
+
+    def __getitem__(self, idx):
+        x = self.x[idx]
+        x = Image.fromarray(x.astype(np.uint8))
+        x = self.transform(x)
+        y = self.y[idx]
+        return x, y
+
+
+class LinearEvaluationSimSiam(nn.Module):
+    def __init__(self, in_model_path):
+        super(LinearEvaluationSimSiam, self).__init__()
+        model = load_parameters(in_model_path)
+        device = torch.device(
+            'cuda') if torch.cuda.is_available() else torch.device('cpu')
+        self.encoder = model.encoder.to(device)
+
+        # freeze the encoder parameters
+        for param in self.encoder.parameters():
+            param.requires_grad = False
+
+        self.classifier = nn.Linear(2048, 10).to(device)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.classifier(x)
+        return x
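+
+# Linear evaluation protocol (summary): the encoder stays frozen and only the
+# linear classifier on top is trained, so the resulting test accuracy measures
+# the quality of the learned representations rather than classifier capacity.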
+
+
+def linear_evaluation(in_model_path, out_json_path, data_path=None, train_data_percentage=0.1, epochs=5):
+    model = LinearEvaluationSimSiam(in_model_path)
+
+    device = torch.device(
+        'cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    x_train, y_train = load_data(data_path)
+    x_test, y_test = load_data(data_path, is_train=False)
+
+    # for linear evaluation, train only on a small subset of the training data
+    n_training_data = int(train_data_percentage * len(x_train))
+    print("number of training points: ", n_training_data)
+
+    x_train = x_train[:n_training_data]
+    y_train = y_train[:n_training_data]
+
+    traindata = Cifar10(x_train, y_train)
+    trainloader = DataLoader(traindata, batch_size=4, shuffle=True)
+
+    testdata = Cifar10(x_test, y_test)
+    testloader = DataLoader(testdata, batch_size=4, shuffle=False)
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.001)
+
+    model.encoder.eval()  # this is linear evaluation, only train the classifier
+    model.classifier.train()
+
+    for epoch in range(epochs):
+        correct = 0
+        total = 0
+        total_loss = 0.0
+        for i, data in enumerate(trainloader):
+            inputs, labels = data[0].to(device), data[1].to(device)
+            optimizer.zero_grad()
+
+            with torch.no_grad():
+                features = model.encoder(inputs)
+            outputs = model.classifier(features)
+
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+            total_loss += loss.item() * labels.size(0)
+
+        training_accuracy = correct / total
+        print(f"Accuracy: {training_accuracy * 100:.2f}%")
+
+        training_loss = total_loss / total
+        print("train loss: ", training_loss)
+
+    # evaluate on the test set
+    model.eval()
+    total_loss = 0.0
+    correct_preds = 0
+    total_samples = 0
+
+    with torch.no_grad():
+        for i, data in enumerate(testloader):
+            inputs, labels = data[0].to(device), data[1].to(device)
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+
+            _, predicted = torch.max(outputs.data, 1)
+            total_loss += loss.item() * inputs.size(0)  # multiply by batch size
+            total_samples += labels.size(0)
+            correct_preds += (predicted == labels).sum().item()
+
+    test_accuracy = correct_preds / total_samples
+    print(f"Test accuracy: {test_accuracy * 100:.2f}%")
+
+    test_loss = total_loss / total_samples
+    print("test loss: ", test_loss)
+
+    return training_loss, training_accuracy, test_loss, test_accuracy
+
+
+def validate(in_model_path, out_json_path, data_path=None, train_data_percentage=1, epochs=3):
+
+    memory_loader, test_loader = load_knn_monitoring_dataset(data_path)
+
+    model = load_parameters(in_model_path)
+
+    knn_accuracy = knn_monitor(model.encoder, memory_loader, test_loader,
+                               k=min(25, len(memory_loader.dataset)))
+
+    print("knn accuracy: ", knn_accuracy)
+
+    # JSON schema
+    report = {
+        "knn_accuracy": knn_accuracy,
+    }
+
+    # Save JSON
+    save_metrics(report, out_json_path)
+
+
+if __name__ == "__main__":
+    validate(sys.argv[1], sys.argv[2])
diff --git a/examples/FedSimSiam/docker-compose.override.yaml b/examples/FedSimSiam/docker-compose.override.yaml
new file mode 100644
index 000000000..524e39d1d
--- /dev/null
+++ b/examples/FedSimSiam/docker-compose.override.yaml
@@ -0,0 +1,35 @@
+# Compose schema version
+version: '3.4'
+
+# Overriding requirements
+
+x-env: &defaults
+  GET_HOSTS_FROM: dns
+  FEDN_PACKAGE_EXTRACT_DIR: package
+  FEDN_NUM_DATA_SPLITS: 2
+
+services:
+
+  client1:
+    extends:
+      file: ${HOST_REPO_DIR:-.}/docker-compose.yaml
+      service: client
+    environment:
+      <<: *defaults
+      FEDN_DATA_PATH: /app/package/client/data/clients/1/cifar10.pt
+    deploy:
+      replicas: 1
+    volumes:
+      - ${HOST_REPO_DIR:-.}/fedn:/app/fedn
+
+  client2:
+    extends:
+      file: ${HOST_REPO_DIR:-.}/docker-compose.yaml
+      service: client
+    environment:
+      <<: *defaults
+      FEDN_DATA_PATH: /app/package/client/data/clients/2/cifar10.pt
+    deploy:
+      replicas: 1
+    volumes:
+      - ${HOST_REPO_DIR:-.}/fedn:/app/fedn