Skip to content

Commit

Permalink
initial fedsimsiam code
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankJonasmoelle committed May 11, 2024
1 parent fb76c47 commit a976201
Show file tree
Hide file tree
Showing 11 changed files with 805 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/FedSimSiam/.dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
data
seed.npz
*.tgz
*.tar.gz
6 changes: 6 additions & 0 deletions examples/FedSimSiam/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
data
*.npz
*.tgz
*.tar.gz
.fedsimsiam
client.yaml
151 changes: 151 additions & 0 deletions examples/FedSimSiam/client/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import os
from math import floor

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)


def get_data(out_dir='data'):
# Make dir if necessary
if not os.path.exists(out_dir):
os.mkdir(out_dir)

# Only download if not already downloaded
if not os.path.exists(f'{out_dir}/train'):
torchvision.datasets.CIFAR10(
root=f'{out_dir}/train', train=True, download=True)

if not os.path.exists(f'{out_dir}/test'):
torchvision.datasets.CIFAR10(
root=f'{out_dir}/test', train=False, download=True)


def load_data(data_path, is_train=True):
""" Load data from disk.
:param data_path: Path to data file.
:type data_path: str
:param is_train: Whether to load training or test data.
:type is_train: bool
:return: Tuple of data and labels.
:rtype: tuple
"""
if data_path is None:
data_path = os.environ.get(
"FEDN_DATA_PATH", abs_path+'/data/clients/1/cifar10.pt')

data = torch.load(data_path)

if is_train:
X = data['x_train']
y = data['y_train']
else:
X = data['x_test']
y = data['y_test']

return X, y


def create_knn_monitoring_dataset(out_dir='data'):
""" Creates dataset that is used to monitor the training progress via knn accuracies """
if not os.path.exists(out_dir):
os.mkdir(out_dir)

n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))

# Make dir
if not os.path.exists(f'{out_dir}/clients'):
os.mkdir(f'{out_dir}/clients')

normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.247, 0.243, 0.261])

memoryset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transforms.Compose([transforms.ToTensor(), normalize]))
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transforms.Compose([transforms.ToTensor(), normalize]))

# save monitoring datasets to all clients
for i in range(n_splits):
subdir = f'{out_dir}/clients/{str(i+1)}'
if not os.path.exists(subdir):
os.mkdir(subdir)
torch.save(memoryset, f'{subdir}/knn_memoryset.pt')
torch.save(testset, f'{subdir}/knn_testset.pt')


def load_knn_monitoring_dataset(data_path, batch_size=16):
""" Loads the KNN monitoring dataset."""
if data_path is None:
data_path = os.environ.get(
"FEDN_DATA_PATH", abs_path+'/data/clients/1/cifar10.pt')

data_directory = os.path.dirname(data_path)
memory_path = os.path.join(data_directory, 'knn_memoryset.pt')
testset_path = os.path.join(data_directory, 'knn_testset.pt')

memoryset = torch.load(memory_path)
testset = torch.load(testset_path)

memoryset_loader = torch.utils.data.DataLoader(
memoryset, batch_size=batch_size, shuffle=False)
testset_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=False)
return memoryset_loader, testset_loader


def splitset(dataset, parts):
n = dataset.shape[0]
local_n = floor(n/parts)
result = []
for i in range(parts):
result.append(dataset[i*local_n: (i+1)*local_n])
return result


def split(out_dir='data'):

n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))

# Make dir
if not os.path.exists(f'{out_dir}/clients'):
os.mkdir(f'{out_dir}/clients')

train_data = torchvision.datasets.CIFAR10(
root=f'{out_dir}/train', train=True)
test_data = torchvision.datasets.CIFAR10(
root=f'{out_dir}/test', train=False)

data = {
'x_train': splitset(train_data.data, n_splits),
'y_train': splitset(np.array(train_data.targets), n_splits),
'x_test': splitset(test_data.data, n_splits),
'y_test': splitset(np.array(test_data.targets), n_splits),
}

# Make splits
for i in range(n_splits):
subdir = f'{out_dir}/clients/{str(i+1)}'
if not os.path.exists(subdir):
os.mkdir(subdir)
torch.save({
'x_train': data['x_train'][i],
'y_train': data['y_train'][i],
'x_test': data['x_test'][i],
'y_test': data['y_test'][i],
},
f'{subdir}/cifar10.pt')


if __name__ == '__main__':
# Prepare data if not already done
if not os.path.exists(abs_path+'/data/clients/1'):
get_data()
split()
create_knn_monitoring_dataset()
10 changes: 10 additions & 0 deletions examples/FedSimSiam/client/fedn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
python_env: python_env.yaml
entry_points:
build:
command: python model.py
startup:
command: python data.py
train:
command: python train.py
validate:
command: python validate.py
144 changes: 144 additions & 0 deletions examples/FedSimSiam/client/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import torch.nn.functional as F
from torchvision.models import resnet18
import torch.nn as nn
import collections

import torch

from fedn.utils.helpers.helpers import get_helper

HELPER_MODULE = 'numpyhelper'
helper = get_helper(HELPER_MODULE)


def D(p, z, version='simplified'): # negative cosine similarity
if version == 'original':
z = z.detach() # stop gradient
p = F.normalize(p, dim=1) # l2-normalize
z = F.normalize(z, dim=1) # l2-normalize
return -(p*z).sum(dim=1).mean()

elif version == 'simplified': # same thing, much faster. Scroll down, speed test in __main__
return - F.cosine_similarity(p, z.detach(), dim=-1).mean()
else:
raise Exception


class ProjectionMLP(nn.Module):
"""Projection MLP f"""

def __init__(self, in_features, h1_features, h2_features, out_features):
super(ProjectionMLP, self).__init__()
self.l1 = nn.Sequential(
nn.Linear(in_features, h1_features),
nn.BatchNorm1d(h1_features),
nn.ReLU(inplace=True)
)
self.l2 = nn.Sequential(
nn.Linear(h1_features, out_features),
nn.BatchNorm1d(out_features)
)

def forward(self, x):
x = self.l1(x)
x = self.l2(x)
return x


class PredictionMLP(nn.Module):
"""Prediction MLP h"""

def __init__(self, in_features, hidden_features, out_features):
super(PredictionMLP, self).__init__()
self.l1 = nn.Sequential(
nn.Linear(in_features, hidden_features),
nn.BatchNorm1d(hidden_features),
nn.ReLU(inplace=True)
)
self.l2 = nn.Linear(hidden_features, out_features)

def forward(self, x):
x = self.l1(x)
x = self.l2(x)
return x


class SimSiam(nn.Module):
def __init__(self):
super(SimSiam, self).__init__()
backbone = resnet18(pretrained=False)
backbone.output_dim = backbone.fc.in_features
backbone.fc = torch.nn.Identity()

self.backbone = backbone

self.projector = ProjectionMLP(backbone.output_dim, 2048, 2048, 2048)
self.encoder = nn.Sequential(
self.backbone,
self.projector
)
self.predictor = PredictionMLP(2048, 512, 2048)

def forward(self, x1, x2):
f, h = self.encoder, self.predictor
z1, z2 = f(x1), f(x2)
p1, p2 = h(z1), h(z2)
L = D(p1, z2) / 2 + D(p2, z1) / 2
return {'loss': L}


def compile_model():
""" Compile the pytorch model.
:return: The compiled model.
:rtype: torch.nn.Module
"""
model = SimSiam()

return model


def save_parameters(model, out_path):
""" Save model paramters to file.
:param model: The model to serialize.
:type model: torch.nn.Module
:param out_path: The path to save to.
:type out_path: str
"""
parameters_np = [val.cpu().numpy()
for _, val in model.state_dict().items()]
helper.save(parameters_np, out_path)


def load_parameters(model_path):
""" Load model parameters from file and populate model.
param model_path: The path to load from.
:type model_path: str
:return: The loaded model.
:rtype: torch.nn.Module
"""
model = compile_model()
parameters_np = helper.load(model_path)

params_dict = zip(model.state_dict().keys(), parameters_np)
state_dict = collections.OrderedDict(
{key: torch.tensor(x) for key, x in params_dict})
model.load_state_dict(state_dict, strict=True)
return model


def init_seed(out_path='seed.npz'):
""" Initialize seed model and save it to file.
:param out_path: The path to save the seed model to.
:type out_path: str
"""
# Init and save
model = compile_model()
save_parameters(model, out_path)


if __name__ == "__main__":
init_seed('../seed.npz')
63 changes: 63 additions & 0 deletions examples/FedSimSiam/client/monitoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
""" knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
"""
import torch.nn.functional as F
import torch


def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1, hide_progress=False):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
net.eval()
classes = len(memory_data_loader.dataset.classes)
total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
with torch.no_grad():
# generate feature bank
for data, target in memory_data_loader:
# feature = net(data.cuda(non_blocking=True))
feature = net(data.to(device))
feature = F.normalize(feature, dim=1)
feature_bank.append(feature)
# [D, N]
feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
# [N]
feature_labels = torch.tensor(
memory_data_loader.dataset.targets, device=feature_bank.device)
# loop test data to predict the label by weighted knn search
for data, target in test_data_loader:
# data, target = data.cuda(
# non_blocking=True), target.cuda(non_blocking=True)
data, target = data.to(device), target.to(device)
feature = net(data)
feature = F.normalize(feature, dim=1)

pred_labels = knn_predict(
feature, feature_bank, feature_labels, classes, k, t)

total_num += data.size(0)
total_top1 += (pred_labels[:, 0] == target).float().sum().item()
return total_top1 / total_num # * 100


def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
# compute cos similarity between each feature vector and feature bank ---> [B, N]
sim_matrix = torch.mm(feature, feature_bank)
# [B, K]
sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
# [B, K]
sim_labels = torch.gather(feature_labels.expand(
feature.size(0), -1), dim=-1, index=sim_indices)
sim_weight = (sim_weight / knn_t).exp()

# counts for each class
one_hot_label = torch.zeros(feature.size(
0) * knn_k, classes, device=sim_labels.device)
# [B*K, C]
one_hot_label = one_hot_label.scatter(
dim=-1, index=sim_labels.view(-1, 1), value=1.0)
# weighted score ---> [B, C]
pred_scores = torch.sum(one_hot_label.view(feature.size(
0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

pred_labels = pred_scores.argsort(dim=-1, descending=True)
return pred_labels
9 changes: 9 additions & 0 deletions examples/FedSimSiam/client/python_env.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
name: fedsimsiam
build_dependencies:
- pip
- setuptools
- wheel==0.37.1
dependencies:
- torch==2.2.1
- torchvision==0.17.1
- fedn==0.9.0
Loading

0 comments on commit a976201

Please sign in to comment.