Skip to content

Commit

Permalink
update entrypoint file
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Feb 15, 2024
1 parent d5bcfba commit 9df6d95
Showing 1 changed file with 49 additions and 37 deletions.
86 changes: 49 additions & 37 deletions docs/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,23 @@ A *entrypoint.py* example can look like this:
from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics
HELPER_MODULE = 'numpyhelper'
helper = get_helper(HELPER_MODULE)
NUM_CLASSES = 10
def _compile_model():
def _get_data_path():
""" For test automation using docker-compose. """
# Figure out FEDn client number from container name
client = docker.from_env()
container = client.containers.get(os.environ['HOSTNAME'])
number = container.name[-1]
# Return data path
return f"/var/data/clients/{number}/mnist.pt"
def compile_model():
""" Compile the pytorch model.
:return: The compiled model.
Expand All @@ -100,12 +114,11 @@ A *entrypoint.py* example can look like this:
x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
return x
# Return model
return Net()
def _load_data(data_path, is_train=True):
""" Load data from disk.
def load_data(data_path, is_train=True):
""" Load data from disk.
:param data_path: Path to data file.
:type data_path: str
Expand All @@ -132,54 +145,52 @@ A *entrypoint.py* example can look like this:
return X, y
def _save_model(model, out_path):
""" Save model to disk.
def save_parameters(model, out_path):
""" Save model paramters to file.
:param model: The model to save.
:param model: The model to serialize.
:type model: torch.nn.Module
:param out_path: The path to save to.
:type out_path: str
"""
weights = model.state_dict()
weights_np = collections.OrderedDict()
for w in weights:
weights_np[w] = weights[w].cpu().detach().numpy()
helper = get_helper(HELPER_MODULE)
helper.save(weights, out_path)
parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
helper.save(parameters_np, out_path)
def _load_model(model_path):
""" Load model from disk.
def load_parameters(model_path):
""" Load model parameters from file and populate model.
param model_path: The path to load from.
:type model_path: str
:return: The loaded model.
:rtype: torch.nn.Module
"""
helper = get_helper(HELPER_MODULE)
weights_np = helper.load(model_path)
weights = collections.OrderedDict()
for w in weights_np:
weights[w] = torch.tensor(weights_np[w])
model = _compile_model()
model.load_state_dict(weights)
model.eval()
model = compile_model()
parameters_np = helper.load(model_path)
params_dict = zip(model.state_dict().keys(), parameters_np)
state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
model.load_state_dict(state_dict, strict=True)
return model
def init_seed(out_path='seed.npz'):
""" Initialize seed model.
""" Initialize seed model and save it to file.
:param out_path: The path to save the seed model to.
:type out_path: str
"""
# Init and save
model = _compile_model()
_save_model(model, out_path)
model = compile_model()
save_parameters(model, out_path)
def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01):
""" Train model.
""" Complete a model update.
Load model paramters from in_model_path (managed by the FEDn client),
perform a model update, and write updated paramters
to out_model_path (picked up by the FEDn client).
:param in_model_path: The path to the input model.
:type in_model_path: str
Expand All @@ -195,10 +206,10 @@ A *entrypoint.py* example can look like this:
:type lr: float
"""
# Load data
x_train, y_train = _load_data(data_path)
x_train, y_train = load_data(data_path)
# Load model
model = _load_model(in_model_path)
# Load parmeters and initialize model
model = load_parameters(in_model_path)
# Train
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
Expand All @@ -222,17 +233,18 @@ A *entrypoint.py* example can look like this:
# Metadata needed for aggregation server side
metadata = {
# num_examples are mandatory
'num_examples': len(x_train),
'batch_size': batch_size,
'epochs': epochs,
'lr': lr
}
# Save JSON metadata file
# Save JSON metadata file (mandatory)
save_metadata(metadata, out_model_path)
# Save model update
_save_model(model, out_model_path)
# Save model update (mandatory)
save_parameters(model, out_model_path)
def validate(in_model_path, out_json_path, data_path=None):
Expand All @@ -246,11 +258,12 @@ A *entrypoint.py* example can look like this:
:type data_path: str
"""
# Load data
x_train, y_train = _load_data(data_path)
x_test, y_test = _load_data(data_path, is_train=False)
x_train, y_train = load_data(data_path)
x_test, y_test = load_data(data_path, is_train=False)
# Load model
model = _load_model(in_model_path)
model = load_parameters(in_model_path)
model.eval()
# Evaluate
criterion = torch.nn.NLLLoss()
Expand Down Expand Up @@ -281,7 +294,6 @@ A *entrypoint.py* example can look like this:
'init_seed': init_seed,
'train': train,
'validate': validate,
# '_get_data_path': _get_data_path, # for testing
})
Expand Down

0 comments on commit 9df6d95

Please sign in to comment.