Skip to content

Commit

Permalink
clean script
Browse files Browse the repository at this point in the history
  • Loading branch information
mattiasakesson committed Apr 9, 2024
1 parent 858a62c commit deb2598
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions examples/mnist-pytorch-fedprox/client/entrypoint
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,8 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1

mu = client_settings['mu']

print("mu: ", mu)
print("mu value: ", mu)

print("data_path: ", data_path)
print(os.getcwd())
print("list data path: ", os.listdir('/var/data'))

print("list data/clients path: ", os.listdir('/var/data/clients'))
# Load data
x_train, y_train = load_data(data_path)

Expand All @@ -176,12 +171,8 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
proximal_term = 0.0
for w, w_t in zip(model.parameters(), global_model.parameters()):
proximal_term += (w - w_t).norm(2)
#print("proximal_term: ", proximal_term)

# loss = criterion(outputs, batch_y) # <-- old

# loss = loss_function(y_pred, label) + (args.mu / 2) * proximal_term <-- fed prox term
loss = criterion(outputs, batch_y) + (mu / 2) * proximal_term # <-- new
loss = criterion(outputs, batch_y) + (mu / 2) * proximal_term

loss.backward()
optimizer.step()
Expand Down

0 comments on commit deb2598

Please sign in to comment.