From 9df6d95171204ba1ad1d88a84871c1a93a69840a Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Thu, 15 Feb 2024 13:53:59 +0000 Subject: [PATCH] update entrypoint file --- docs/tutorial.rst | 86 +++++++++++++++++++++++++++-------------------- 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/docs/tutorial.rst b/docs/tutorial.rst index 9fc0bf26e..916058774 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -78,9 +78,23 @@ A *entrypoint.py* example can look like this: from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics HELPER_MODULE = 'numpyhelper' + helper = get_helper(HELPER_MODULE) + NUM_CLASSES = 10 - def _compile_model(): + + def _get_data_path(): + """ For test automation using docker-compose. """ + # Figure out FEDn client number from container name + client = docker.from_env() + container = client.containers.get(os.environ['HOSTNAME']) + number = container.name[-1] + + # Return data path + return f"/var/data/clients/{number}/mnist.pt" + + + def compile_model(): """ Compile the pytorch model. :return: The compiled model. @@ -100,12 +114,11 @@ A *entrypoint.py* example can look like this: x = torch.nn.functional.log_softmax(self.fc3(x), dim=1) return x - # Return model return Net() - def _load_data(data_path, is_train=True): - """ Load data from disk. + def load_data(data_path, is_train=True): + """ Load data from disk. :param data_path: Path to data file. :type data_path: str @@ -132,54 +145,52 @@ A *entrypoint.py* example can look like this: return X, y - def _save_model(model, out_path): - """ Save model to disk. + def save_parameters(model, out_path): + """ Save model parameters to file. - :param model: The model to save. + :param model: The model to serialize. :type model: torch.nn.Module :param out_path: The path to save to. 
:type out_path: str """ - weights = model.state_dict() - weights_np = collections.OrderedDict() - for w in weights: - weights_np[w] = weights[w].cpu().detach().numpy() - helper = get_helper(HELPER_MODULE) - helper.save(weights, out_path) + parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()] + helper.save(parameters_np, out_path) - def _load_model(model_path): - """ Load model from disk. + def load_parameters(model_path): + """ Load model parameters from file and populate model. param model_path: The path to load from. :type model_path: str :return: The loaded model. :rtype: torch.nn.Module """ - helper = get_helper(HELPER_MODULE) - weights_np = helper.load(model_path) - weights = collections.OrderedDict() - for w in weights_np: - weights[w] = torch.tensor(weights_np[w]) - model = _compile_model() - model.load_state_dict(weights) - model.eval() + model = compile_model() + parameters_np = helper.load(model_path) + + params_dict = zip(model.state_dict().keys(), parameters_np) + state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict}) + model.load_state_dict(state_dict, strict=True) return model def init_seed(out_path='seed.npz'): - """ Initialize seed model. + """ Initialize seed model and save it to file. :param out_path: The path to save the seed model to. :type out_path: str """ # Init and save - model = _compile_model() - _save_model(model, out_path) + model = compile_model() + save_parameters(model, out_path) def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01): - """ Train model. + """ Complete a model update. + + Load model parameters from in_model_path (managed by the FEDn client), + perform a model update, and write updated parameters + to out_model_path (picked up by the FEDn client). :param in_model_path: The path to the input model. 
:type in_model_path: str @@ -195,10 +206,10 @@ A *entrypoint.py* example can look like this: :type lr: float """ # Load data - x_train, y_train = _load_data(data_path) + x_train, y_train = load_data(data_path) - # Load model - model = _load_model(in_model_path) + # Load parameters and initialize model + model = load_parameters(in_model_path) # Train optimizer = torch.optim.SGD(model.parameters(), lr=lr) @@ -222,17 +233,18 @@ A *entrypoint.py* example can look like this: # Metadata needed for aggregation server side metadata = { + # num_examples is mandatory 'num_examples': len(x_train), 'batch_size': batch_size, 'epochs': epochs, 'lr': lr } - # Save JSON metadata file + # Save JSON metadata file (mandatory) save_metadata(metadata, out_model_path) - # Save model update - _save_model(model, out_model_path) + # Save model update (mandatory) + save_parameters(model, out_model_path) def validate(in_model_path, out_json_path, data_path=None): @@ -246,11 +258,12 @@ A *entrypoint.py* example can look like this: :type data_path: str """ # Load data - x_train, y_train = _load_data(data_path) - x_test, y_test = _load_data(data_path, is_train=False) + x_train, y_train = load_data(data_path) + x_test, y_test = load_data(data_path, is_train=False) # Load model - model = _load_model(in_model_path) + model = load_parameters(in_model_path) + model.eval() # Evaluate criterion = torch.nn.NLLLoss() @@ -281,7 +294,6 @@ A *entrypoint.py* example can look like this: 'init_seed': init_seed, 'train': train, 'validate': validate, - # '_get_data_path': _get_data_path, # for testing })