From deb25985bc869d4aa7e8d8e37500d70f8f03daaa Mon Sep 17 00:00:00 2001 From: mattiasakesson Date: Tue, 9 Apr 2024 17:52:25 +0200 Subject: [PATCH] clean script --- examples/mnist-pytorch-fedprox/client/entrypoint | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/examples/mnist-pytorch-fedprox/client/entrypoint b/examples/mnist-pytorch-fedprox/client/entrypoint index f8d00378e..3ddb5b156 100755 --- a/examples/mnist-pytorch-fedprox/client/entrypoint +++ b/examples/mnist-pytorch-fedprox/client/entrypoint @@ -146,13 +146,8 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1 mu = client_settings['mu'] - print("mu: ", mu) + print("mu value: ", mu) - print("data_path: ", data_path) - print(os.getcwd()) - print("list data path: ", os.listdir('/var/data')) - - print("list data/clients path: ", os.listdir('/var/data/clients')) # Load data x_train, y_train = load_data(data_path) @@ -176,12 +171,8 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1 proximal_term = 0.0 for w, w_t in zip(model.parameters(), global_model.parameters()): proximal_term += (w - w_t).norm(2) - #print("proximal_term: ", proximal_term) - - # loss = criterion(outputs, batch_y) # <-- old - # loss = loss_function(y_pred, label) + (args.mu / 2) * proximal_term <-- fed prox term - loss = criterion(outputs, batch_y) + (mu / 2) * proximal_term # <-- new + loss = criterion(outputs, batch_y) + (mu / 2) * proximal_term loss.backward() optimizer.step()