diff --git a/examples/FedSimSiam/client/train.py b/examples/FedSimSiam/client/train.py
index 26c498955..0e7c565f6 100644
--- a/examples/FedSimSiam/client/train.py
+++ b/examples/FedSimSiam/client/train.py
@@ -89,7 +89,7 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
         trainset, batch_size=batch_size, shuffle=True)
 
     device = torch.device(
-        'cuda') if torch.cuda.is_available() else torch.device('cpu')
+        "cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
 
     model.train()
@@ -103,7 +103,7 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
             optimizer.zero_grad()
             data_dict = model.forward(images[0].to(
                 device, non_blocking=True), images[1].to(device, non_blocking=True))
-            loss = data_dict['loss'].mean()
+            loss = data_dict["loss"].mean()
             print(loss)
             loss.backward()
             optimizer.step()
@@ -112,10 +112,10 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
     # Metadata needed for aggregation server side
     metadata = {
         # num_examples are mandatory
-        'num_examples': len(x_train),
-        'batch_size': batch_size,
-        'epochs': epochs,
-        'lr': lr
+        "num_examples": len(x_train),
+        "batch_size": batch_size,
+        "epochs": epochs,
+        "lr": lr
     }
 
     # Save JSON metadata file (mandatory)