Skip to content

Commit

Permalink
ruff linting
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankJonasmoelle committed May 13, 2024
1 parent d31cb6a commit 955f40c
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 37 deletions.
5 changes: 2 additions & 3 deletions examples/FedSimSiam/client/data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
from math import floor

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
from torchvision import transforms

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)
Expand Down
12 changes: 6 additions & 6 deletions examples/FedSimSiam/client/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch.nn.functional as F
from torchvision.models import resnet18
import torch.nn as nn
import collections

import torch
import torch.nn.functional as f
from torch import nn
from torchvision.models import resnet18

from fedn.utils.helpers.helpers import get_helper

Expand All @@ -14,12 +14,12 @@
def D(p, z, version='simplified'):  # negative cosine similarity
    """Negative cosine similarity between prediction p and projection z (SimSiam loss term).

    Args:
        p: predictor output, shape (batch, dim).
        z: projector output, shape (batch, dim); gradients are stopped
           through z (stop-gradient, as in the SimSiam paper).
        version: 'original' normalizes manually and averages the dot
            products; 'simplified' uses cosine_similarity directly.
            Both compute the same value; 'simplified' is faster.

    Returns:
        Scalar tensor: mean of -cos(p, z) over the batch.

    Raises:
        ValueError: if version is not 'original' or 'simplified'.
    """
    if version == 'original':
        z = z.detach()  # stop gradient
        p = f.normalize(p, dim=1)  # l2-normalize
        z = f.normalize(z, dim=1)  # l2-normalize
        return -(p * z).sum(dim=1).mean()
    elif version == 'simplified':  # same thing, much faster
        return - f.cosine_similarity(p, z.detach(), dim=-1).mean()
    else:
        # ValueError subclasses Exception, so existing broad handlers still work.
        raise ValueError(f"unknown version: {version!r}")

Expand Down
18 changes: 8 additions & 10 deletions examples/FedSimSiam/client/monitoring.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
"""
knn monitor as in InstDisc https://arxiv.org/abs/1805.01978.
""" knn monitor as in InstDisc https://arxiv.org/abs/1805.01978.
This implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
"""
import torch.nn.functional as F
import torch
import torch.nn.functional as f


def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1, hide_progress=False):
def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
net.eval()
classes = len(memory_data_loader.dataset.classes)
total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
total_top1, total_num, feature_bank = 0.0, 0, []
with torch.no_grad():
# generate feature bank
for data, target in memory_data_loader:
# feature = net(data.cuda(non_blocking=True))
feature = net(data.to(device))
feature = F.normalize(feature, dim=1)
feature = f.normalize(feature, dim=1)
feature_bank.append(feature)
# [D, N]
feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
Expand All @@ -26,11 +25,10 @@ def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1,
memory_data_loader.dataset.targets, device=feature_bank.device)
# loop test data to predict the label by weighted knn search
for data, target in test_data_loader:
# data, target = data.cuda(
# non_blocking=True), target.cuda(non_blocking=True)
data, target = data.to(device), target.to(device)
data, target = data.cuda(
non_blocking=True), target.cuda(non_blocking=True)
feature = net(data)
feature = F.normalize(feature, dim=1)
feature = f.normalize(feature, dim=1)

pred_labels = knn_predict(
feature, feature_bank, feature_labels, classes, k, t)
Expand Down
12 changes: 4 additions & 8 deletions examples/FedSimSiam/client/train.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import os
import sys

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim

import torch
from data import load_data
from model import load_parameters, save_parameters
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from utils import init_lrscheduler

from fedn.utils.helpers.helpers import save_metadata
Expand Down
6 changes: 3 additions & 3 deletions examples/FedSimSiam/client/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import numpy as np
import torch


class LR_Scheduler(object):
class LrScheduler(object):
def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
self.base_lr = base_lr
self.constant_predictor_lr = constant_predictor_lr
Expand Down Expand Up @@ -69,7 +69,7 @@ def init_lrscheduler(model, total_epochs, dataloader):
momentum=momentum,
weight_decay=weight_decay)

lr_scheduler = LR_Scheduler(
lr_scheduler = LrScheduler(
optimizer, warmup_epochs, warmup_lr*batch_size/256,
total_epochs, base_lr*batch_size/256, final_lr*batch_size/256,
len(dataloader),
Expand Down
13 changes: 6 additions & 7 deletions examples/FedSimSiam/client/validate.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import os
import sys

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
import torch
from data import load_knn_monitoring_dataset
from model import load_parameters
from monitoring import knn_monitor
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from model import load_parameters
from data import load_data, load_knn_monitoring_dataset
from monitoring import *
from fedn.utils.helpers.helpers import save_metrics

dir_path = os.path.dirname(os.path.realpath(__file__))
Expand Down

0 comments on commit 955f40c

Please sign in to comment.