From 955f40c0f3af84ac0bff40c86a00f18b43fe5fe9 Mon Sep 17 00:00:00 2001
From: Jonas
Date: Mon, 13 May 2024 13:07:25 +0200
Subject: [PATCH] ruff linting

---
 examples/FedSimSiam/client/data.py       |  5 ++---
 examples/FedSimSiam/client/model.py      | 12 ++++++------
 examples/FedSimSiam/client/monitoring.py | 18 ++++++++----------
 examples/FedSimSiam/client/train.py      | 12 ++++--------
 examples/FedSimSiam/client/utils.py      |  6 +++---
 examples/FedSimSiam/client/validate.py   | 13 ++++++-------
 6 files changed, 29 insertions(+), 37 deletions(-)

diff --git a/examples/FedSimSiam/client/data.py b/examples/FedSimSiam/client/data.py
index f18ea1f32..0ed185274 100644
--- a/examples/FedSimSiam/client/data.py
+++ b/examples/FedSimSiam/client/data.py
@@ -1,11 +1,10 @@
 import os
 from math import floor

+import numpy as np
 import torch
 import torchvision
-import torchvision.transforms as transforms
-import numpy as np
-import os
+from torchvision import transforms

 dir_path = os.path.dirname(os.path.realpath(__file__))
 abs_path = os.path.abspath(dir_path)
diff --git a/examples/FedSimSiam/client/model.py b/examples/FedSimSiam/client/model.py
index 7ba59b426..c926000e7 100644
--- a/examples/FedSimSiam/client/model.py
+++ b/examples/FedSimSiam/client/model.py
@@ -1,9 +1,9 @@
-import torch.nn.functional as F
-from torchvision.models import resnet18
-import torch.nn as nn
 import collections

 import torch
+import torch.nn.functional as f
+from torch import nn
+from torchvision.models import resnet18

 from fedn.utils.helpers.helpers import get_helper
@@ -14,12 +14,12 @@ def D(p, z, version='simplified'):  # negative cosine similarity
     if version == 'original':
         z = z.detach()  # stop gradient
-        p = F.normalize(p, dim=1)  # l2-normalize
-        z = F.normalize(z, dim=1)  # l2-normalize
+        p = f.normalize(p, dim=1)  # l2-normalize
+        z = f.normalize(z, dim=1)  # l2-normalize
         return -(p*z).sum(dim=1).mean()

     elif version == 'simplified':  # same thing, much faster. Scroll down, speed test in __main__
-        return - F.cosine_similarity(p, z.detach(), dim=-1).mean()
+        return - f.cosine_similarity(p, z.detach(), dim=-1).mean()
     else:
         raise Exception
diff --git a/examples/FedSimSiam/client/monitoring.py b/examples/FedSimSiam/client/monitoring.py
index 9e1f83339..245e7f308 100644
--- a/examples/FedSimSiam/client/monitoring.py
+++ b/examples/FedSimSiam/client/monitoring.py
@@ -1,23 +1,22 @@
-"""
-knn monitor as in InstDisc https://arxiv.org/abs/1805.01978.
+""" knn monitor as in InstDisc https://arxiv.org/abs/1805.01978.
 This implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
 """
-import torch.nn.functional as F
 import torch
+import torch.nn.functional as f


-def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1, hide_progress=False):
+def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     net = net.to(device)
     net.eval()
     classes = len(memory_data_loader.dataset.classes)
-    total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
+    total_top1, total_num, feature_bank = 0.0, 0, []
     with torch.no_grad():
         # generate feature bank
         for data, target in memory_data_loader:
             # feature = net(data.cuda(non_blocking=True))
             feature = net(data.to(device))
-            feature = F.normalize(feature, dim=1)
+            feature = f.normalize(feature, dim=1)
             feature_bank.append(feature)
         # [D, N]
         feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
@@ -26,11 +25,10 @@ def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1,
             memory_data_loader.dataset.targets, device=feature_bank.device)
         # loop test data to predict the label by weighted knn search
         for data, target in test_data_loader:
-            # data, target = data.cuda(
-            #     non_blocking=True), target.cuda(non_blocking=True)
-            data, target = data.to(device), target.to(device)
+            data, target = data.cuda(
+                non_blocking=True), target.cuda(non_blocking=True)
             feature = net(data)
-            feature = F.normalize(feature, dim=1)
+            feature = f.normalize(feature, dim=1)

             pred_labels = knn_predict(
                 feature, feature_bank, feature_labels, classes, k, t)
diff --git a/examples/FedSimSiam/client/train.py b/examples/FedSimSiam/client/train.py
index 54536de61..26c498955 100644
--- a/examples/FedSimSiam/client/train.py
+++ b/examples/FedSimSiam/client/train.py
@@ -1,17 +1,13 @@
 import os
 import sys

-import torch
-from torch.utils.data import Dataset
-from torchvision import transforms
-from PIL import Image
-import torch
-from torch.utils.data import DataLoader
 import numpy as np
-import torch.optim as optim
-
+import torch
 from data import load_data
 from model import load_parameters, save_parameters
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
 from utils import init_lrscheduler

 from fedn.utils.helpers.helpers import save_metadata
diff --git a/examples/FedSimSiam/client/utils.py b/examples/FedSimSiam/client/utils.py
index bfb8be88c..55b0d9cad 100644
--- a/examples/FedSimSiam/client/utils.py
+++ b/examples/FedSimSiam/client/utils.py
@@ -1,8 +1,8 @@
-import torch
 import numpy as np
+import torch


-class LR_Scheduler(object):
+class LrScheduler(object):
     def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
         self.base_lr = base_lr
         self.constant_predictor_lr = constant_predictor_lr
@@ -69,7 +69,7 @@ def init_lrscheduler(model, total_epochs, dataloader):
                                 momentum=momentum, weight_decay=weight_decay)

-    lr_scheduler = LR_Scheduler(
+    lr_scheduler = LrScheduler(
         optimizer,
         warmup_epochs, warmup_lr*batch_size/256,
         total_epochs, base_lr*batch_size/256, final_lr*batch_size/256,
         len(dataloader),
diff --git a/examples/FedSimSiam/client/validate.py b/examples/FedSimSiam/client/validate.py
index 74855d1db..47d7026c0 100644
--- a/examples/FedSimSiam/client/validate.py
+++ b/examples/FedSimSiam/client/validate.py
@@ -1,16 +1,15 @@
 import os
 import sys

-import torch
-from torch import nn
-from torch.utils.data import Dataset, DataLoader
-import torchvision.transforms as transforms
 import numpy as np
+import torch
+from data import load_knn_monitoring_dataset
+from model import load_parameters
+from monitoring import knn_monitor
 from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms

-from model import load_parameters
-from data import load_data, load_knn_monitoring_dataset
-from monitoring import *
 from fedn.utils.helpers.helpers import save_metrics

 dir_path = os.path.dirname(os.path.realpath(__file__))
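
Note (not part of the patch): the D() function touched in model.py above is SimSiam's negative cosine similarity with a stop-gradient on z, and its two branches are intended to return the same value. A minimal, self-contained sketch that checks this equivalence follows; it assumes only that torch is installed, and the lowercase function name and the random test tensors are illustrative rather than taken from the repository.

    import torch
    import torch.nn.functional as f


    def d(p, z, version="simplified"):
        """Negative cosine similarity with stop-gradient on z, mirroring D() in model.py."""
        if version == "original":
            z = z.detach()                      # stop gradient, as in SimSiam
            p = f.normalize(p, dim=1)           # l2-normalize each row
            z = f.normalize(z, dim=1)
            return -(p * z).sum(dim=1).mean()
        if version == "simplified":             # same value, computed via cosine_similarity
            return -f.cosine_similarity(p, z.detach(), dim=-1).mean()
        raise ValueError(version)


    # The two branches agree up to floating-point noise on random inputs.
    p, z = torch.randn(8, 128), torch.randn(8, 128)
    assert torch.allclose(d(p, z, "original"), d(p, z, "simplified"), atol=1e-5)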