Skip to content

Commit

Permalink
fixing quotes in train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankJonasmoelle committed May 13, 2024
1 parent 9b79822 commit bfc544a
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions examples/FedSimSiam/client/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
trainset, batch_size=batch_size, shuffle=True)

device = torch.device(
'cuda') if torch.cuda.is_available() else torch.device('cpu')
"cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

model.train()
Expand All @@ -103,7 +103,7 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
optimizer.zero_grad()
data_dict = model.forward(images[0].to(
device, non_blocking=True), images[1].to(device, non_blocking=True))
loss = data_dict['loss'].mean()
loss = data_dict["loss"].mean()
print(loss)
loss.backward()
optimizer.step()
Expand All @@ -112,10 +112,10 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
# Metadata needed for aggregation server side
metadata = {
# num_examples are mandatory
'num_examples': len(x_train),
'batch_size': batch_size,
'epochs': epochs,
'lr': lr
"num_examples": len(x_train),
"batch_size": batch_size,
"epochs": epochs,
"lr": lr
}

# Save JSON metadata file (mandatory)
Expand Down

0 comments on commit bfc544a

Please sign in to comment.