-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fb76c47
commit a976201
Showing
11 changed files
with
805 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
data | ||
seed.npz | ||
*.tgz | ||
*.tar.gz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
data | ||
*.npz | ||
*.tgz | ||
*.tar.gz | ||
.fedsimsiam | ||
client.yaml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import os | ||
from math import floor | ||
|
||
import torch | ||
import torchvision | ||
import torchvision.transforms as transforms | ||
import numpy as np | ||
import os | ||
|
||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
abs_path = os.path.abspath(dir_path) | ||
|
||
|
||
def get_data(out_dir='data'):
    """ Download the CIFAR-10 train and test sets unless already present.

    The training set is placed under ``<out_dir>/train`` and the test set
    under ``<out_dir>/test``. A subset whose directory already exists is
    treated as downloaded and is skipped.

    :param out_dir: Directory that holds the downloaded data.
    :type out_dir: str
    """
    # makedirs creates missing parent directories and tolerates an existing
    # directory, which avoids the check-then-create race of exists()+mkdir().
    os.makedirs(out_dir, exist_ok=True)

    # Only download if not already downloaded
    for subset, is_train in (('train', True), ('test', False)):
        root = f'{out_dir}/{subset}'
        if not os.path.exists(root):
            torchvision.datasets.CIFAR10(
                root=root, train=is_train, download=True)
|
||
|
||
def load_data(data_path, is_train=True):
    """ Load data from disk.

    :param data_path: Path to data file. When None, the path is read from
        the FEDN_DATA_PATH environment variable, falling back to
        ``data/clients/1/cifar10.pt`` next to this module.
    :type data_path: str
    :param is_train: Whether to load training or test data.
    :type is_train: bool
    :return: Tuple of data and labels.
    :rtype: tuple
    """
    if data_path is None:
        data_path = os.environ.get(
            "FEDN_DATA_PATH", abs_path+'/data/clients/1/cifar10.pt')

    # The partition files are plain pickled dicts (split() stores NumPy label
    # arrays in them), so full unpickling is required. Newer PyTorch releases
    # default to weights_only=True, which rejects such files; request the
    # behaviour explicitly.
    # NOTE(review): full unpickling executes arbitrary code from the file -
    # only load partitions produced by this project.
    data = torch.load(data_path, weights_only=False)

    suffix = 'train' if is_train else 'test'
    return data[f'x_{suffix}'], data[f'y_{suffix}']
|
||
|
||
def create_knn_monitoring_dataset(out_dir='data'):
    """ Creates dataset that is used to monitor the training progress via knn accuracies.

    Downloads CIFAR-10 (train split as the kNN memory set, test split as the
    kNN test set), attaches a ToTensor + Normalize transform, and saves both
    dataset objects into every client directory
    ``<out_dir>/clients/<i>/`` as ``knn_memoryset.pt`` and ``knn_testset.pt``.
    The number of clients is read from FEDN_NUM_DATA_SPLITS (default 2).

    :param out_dir: Base data directory; also used as the download root.
    :type out_dir: str
    """
    n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))

    # Creates out_dir and the clients directory in one step; harmless when
    # they already exist.
    os.makedirs(f'{out_dir}/clients', exist_ok=True)

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.247, 0.243, 0.261])
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    # The download root follows out_dir (previously hard-coded to './data',
    # which silently ignored a non-default out_dir).
    memoryset = torchvision.datasets.CIFAR10(root=out_dir, train=True,
                                             download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root=out_dir, train=False,
                                           download=True, transform=transform)

    # save monitoring datasets to all clients
    for i in range(n_splits):
        subdir = f'{out_dir}/clients/{i+1}'
        os.makedirs(subdir, exist_ok=True)
        torch.save(memoryset, f'{subdir}/knn_memoryset.pt')
        torch.save(testset, f'{subdir}/knn_testset.pt')
|
||
|
||
def load_knn_monitoring_dataset(data_path, batch_size=16):
    """ Loads the KNN monitoring dataset.

    The files ``knn_memoryset.pt`` and ``knn_testset.pt`` are looked up in
    the directory that contains ``data_path``.

    :param data_path: Path to the client's data file. When None, the path is
        read from the FEDN_DATA_PATH environment variable, falling back to
        ``data/clients/1/cifar10.pt`` next to this module.
    :type data_path: str
    :param batch_size: Batch size used by both data loaders.
    :type batch_size: int
    :return: Tuple of (memory set loader, test set loader), both unshuffled.
    :rtype: tuple
    """
    if data_path is None:
        data_path = os.environ.get(
            "FEDN_DATA_PATH", abs_path+'/data/clients/1/cifar10.pt')

    data_directory = os.path.dirname(data_path)
    memory_path = os.path.join(data_directory, 'knn_memoryset.pt')
    testset_path = os.path.join(data_directory, 'knn_testset.pt')

    # The files hold whole pickled dataset objects, not tensors/state dicts,
    # so full unpickling is required. Newer PyTorch releases default to
    # weights_only=True, which rejects them; request the behaviour explicitly.
    # NOTE(review): full unpickling executes arbitrary code from the file -
    # only load files produced by this project.
    memoryset = torch.load(memory_path, weights_only=False)
    testset = torch.load(testset_path, weights_only=False)

    memoryset_loader = torch.utils.data.DataLoader(
        memoryset, batch_size=batch_size, shuffle=False)
    testset_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False)
    return memoryset_loader, testset_loader
|
||
|
||
def splitset(dataset, parts):
    """ Cut a dataset into consecutive, equally sized chunks.

    Every chunk holds ``floor(n / parts)`` entries along the first axis,
    where ``n`` is ``dataset.shape[0]``; trailing entries that do not fill a
    whole chunk are left out.

    :param dataset: Sliceable object exposing ``shape``.
    :param parts: Number of chunks to produce.
    :type parts: int
    :return: List with ``parts`` slices of ``dataset``.
    :rtype: list
    """
    chunk_len = floor(dataset.shape[0]/parts)
    return [dataset[k*chunk_len: (k+1)*chunk_len] for k in range(parts)]
|
||
|
||
def split(out_dir='data'):
    """ Partition the downloaded CIFAR-10 data into per-client files.

    Reads the train set from ``<out_dir>/train`` and the test set from
    ``<out_dir>/test``, cuts images and labels into FEDN_NUM_DATA_SPLITS
    (default 2) consecutive chunks with ``splitset``, and writes chunk ``i``
    to ``<out_dir>/clients/<i+1>/cifar10.pt`` as a dict with the keys
    ``x_train``, ``y_train``, ``x_test`` and ``y_test``.

    :param out_dir: Base data directory.
    :type out_dir: str
    """
    n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))

    # makedirs also works when out_dir itself is missing and tolerates an
    # already existing directory.
    os.makedirs(f'{out_dir}/clients', exist_ok=True)

    train_data = torchvision.datasets.CIFAR10(
        root=f'{out_dir}/train', train=True)
    test_data = torchvision.datasets.CIFAR10(
        root=f'{out_dir}/test', train=False)

    data = {
        'x_train': splitset(train_data.data, n_splits),
        'y_train': splitset(np.array(train_data.targets), n_splits),
        'x_test': splitset(test_data.data, n_splits),
        'y_test': splitset(np.array(test_data.targets), n_splits),
    }

    # Make splits
    for i in range(n_splits):
        subdir = f'{out_dir}/clients/{i+1}'
        os.makedirs(subdir, exist_ok=True)
        torch.save({key: chunks[i] for key, chunks in data.items()},
                   f'{subdir}/cifar10.pt')
|
||
|
||
if __name__ == '__main__':
    # Prepare data if not already done: the first client's directory is
    # used as the marker that a previous run finished.
    first_client_dir = abs_path+'/data/clients/1'
    if not os.path.exists(first_client_dir):
        get_data()
        split()
        create_knn_monitoring_dataset()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
python_env: python_env.yaml | ||
entry_points: | ||
build: | ||
command: python model.py | ||
startup: | ||
command: python data.py | ||
train: | ||
command: python train.py | ||
validate: | ||
command: python validate.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import torch.nn.functional as F | ||
from torchvision.models import resnet18 | ||
import torch.nn as nn | ||
import collections | ||
|
||
import torch | ||
|
||
from fedn.utils.helpers.helpers import get_helper | ||
|
||
HELPER_MODULE = 'numpyhelper' | ||
helper = get_helper(HELPER_MODULE) | ||
|
||
|
||
def D(p, z, version='simplified'):
    """ Negative cosine similarity between p and z, averaged over the batch.

    No gradient flows through ``z`` (it is detached in both variants).

    :param p: Predictor output.
    :type p: torch.Tensor
    :param z: Projector output used as the target.
    :type z: torch.Tensor
    :param version: 'original' (explicit l2-normalisation along dim 1 and dot
        product) or 'simplified' (``F.cosine_similarity`` along the last dim).
    :type version: str
    :return: Scalar tensor holding the mean negative cosine similarity.
    :rtype: torch.Tensor
    :raises ValueError: If ``version`` is not one of the supported values.
    """
    if version == 'original':
        z = z.detach()  # stop gradient
        p = F.normalize(p, dim=1)  # l2-normalize
        z = F.normalize(z, dim=1)  # l2-normalize
        return -(p*z).sum(dim=1).mean()

    if version == 'simplified':
        return - F.cosine_similarity(p, z.detach(), dim=-1).mean()

    # ValueError is a subclass of Exception, so callers that caught the
    # previous bare Exception keep working.
    raise ValueError(
        f"Unknown version {version!r}; expected 'original' or 'simplified'")
|
||
|
||
class ProjectionMLP(nn.Module):
    """Projection MLP f.

    Two stages: Linear -> BatchNorm1d -> ReLU, followed by
    Linear -> BatchNorm1d.
    """

    def __init__(self, in_features, h1_features, h2_features, out_features):
        # h2_features is part of the signature but no layer below uses it.
        super().__init__()
        first_stage = [
            nn.Linear(in_features, h1_features),
            nn.BatchNorm1d(h1_features),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Linear(h1_features, out_features),
            nn.BatchNorm1d(out_features),
        ]
        self.l1 = nn.Sequential(*first_stage)
        self.l2 = nn.Sequential(*second_stage)

    def forward(self, x):
        """Apply both stages in order and return the projection."""
        return self.l2(self.l1(x))
|
||
|
||
class PredictionMLP(nn.Module):
    """Prediction MLP h.

    Linear -> BatchNorm1d -> ReLU into a hidden layer, then a plain Linear
    output layer.
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        hidden_stage = [
            nn.Linear(in_features, hidden_features),
            nn.BatchNorm1d(hidden_features),
            nn.ReLU(inplace=True),
        ]
        self.l1 = nn.Sequential(*hidden_stage)
        self.l2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Run the hidden stage and the output layer, return the prediction."""
        return self.l2(self.l1(x))
|
||
|
||
class SimSiam(nn.Module):
    """SimSiam model: ResNet-18 backbone, projection MLP and prediction MLP.

    ``encoder`` is a Sequential over the very same ``backbone`` and
    ``projector`` module objects (no copies). The registration order of the
    attributes is kept as is because parameters are exchanged as a plain
    ordered list (see save_parameters / load_parameters).
    """

    def __init__(self):
        super().__init__()
        # weights=None builds a randomly initialised network; it replaces the
        # deprecated pretrained=False spelling.
        backbone = resnet18(weights=None)
        backbone.output_dim = backbone.fc.in_features
        # Drop the classification head so the backbone returns features.
        backbone.fc = torch.nn.Identity()

        self.backbone = backbone

        self.projector = ProjectionMLP(backbone.output_dim, 2048, 2048, 2048)
        self.encoder = nn.Sequential(
            self.backbone,
            self.projector
        )
        self.predictor = PredictionMLP(2048, 512, 2048)

    def forward(self, x1, x2):
        """ Compute the symmetric SimSiam loss for two inputs.

        :param x1: First input batch.
        :param x2: Second input batch.
        :return: Dict with the single key 'loss'.
        :rtype: dict
        """
        f, h = self.encoder, self.predictor
        z1, z2 = f(x1), f(x2)
        p1, p2 = h(z1), h(z2)
        # Each prediction is compared with the projection of the other input.
        L = D(p1, z2) / 2 + D(p2, z1) / 2
        return {'loss': L}
|
||
|
||
def compile_model():
    """ Build a fresh SimSiam model.

    :return: The newly constructed model.
    :rtype: torch.nn.Module
    """
    return SimSiam()
|
||
|
||
def save_parameters(model, out_path):
    """ Save model parameters to file.

    Every entry of the model's state dict is moved to the CPU, converted to
    a NumPy array and the resulting list is written with the module-level
    helper.

    :param model: The model to serialize.
    :type model: torch.nn.Module
    :param out_path: The path to save to.
    :type out_path: str
    """
    state = model.state_dict()
    parameters_np = [tensor.cpu().numpy() for tensor in state.values()]
    helper.save(parameters_np, out_path)
|
||
|
||
def load_parameters(model_path):
    """ Load model parameters from file and populate model.

    The arrays read from the file are paired, in order, with the state dict
    keys of a freshly compiled model.

    :param model_path: The path to load from.
    :type model_path: str
    :return: The loaded model.
    :rtype: torch.nn.Module
    """
    model = compile_model()
    parameters_np = helper.load(model_path)

    state_dict = collections.OrderedDict()
    for key, array in zip(model.state_dict().keys(), parameters_np):
        state_dict[key] = torch.tensor(array)

    model.load_state_dict(state_dict, strict=True)
    return model
|
||
|
||
def init_seed(out_path='seed.npz'):
    """ Create a freshly initialized model and write it out as the seed.

    :param out_path: The path to save the seed model to.
    :type out_path: str
    """
    seed_model = compile_model()
    save_parameters(seed_model, out_path)
|
||
|
||
if __name__ == "__main__":
    # Running this module as a script writes the seed model one directory up.
    seed_path = '../seed.npz'
    init_seed(seed_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
""" knn monitor as in InstDisc https://arxiv.org/abs/1805.01978 | ||
implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR | ||
""" | ||
import torch.nn.functional as F | ||
import torch | ||
|
||
|
||
def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1, hide_progress=False):
    """ Estimate top-1 accuracy of ``net`` with a weighted kNN classifier.

    Features of the memory set form a bank; every test sample is classified
    by ``knn_predict`` from its nearest bank entries.

    Side effects: ``net`` is moved to CUDA when available (else CPU) and is
    left in eval mode.

    :param net: Model mapping a batch of inputs to 2-D features.
    :param memory_data_loader: Loader over the memory set; its dataset must
        expose ``classes`` and ``targets``.
    :param test_data_loader: Loader over the test set.
    :param epoch: Accepted for interface compatibility; not used.
    :param k: Number of neighbours passed to ``knn_predict``.
    :param t: Temperature passed to ``knn_predict``.
    :param hide_progress: Accepted for interface compatibility; not used.
    :return: Fraction of correctly classified test samples (0.0 when the
        test loader yields no samples).
    :rtype: float
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    net.eval()
    classes = len(memory_data_loader.dataset.classes)
    total_top1, total_num = 0.0, 0
    with torch.no_grad():
        # generate feature bank
        feature_bank = [F.normalize(net(data.to(device)), dim=1)
                        for data, _ in memory_data_loader]
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.tensor(
            memory_data_loader.dataset.targets, device=feature_bank.device)
        # loop test data to predict the label by weighted knn search
        for data, target in test_data_loader:
            data, target = data.to(device), target.to(device)
            feature = F.normalize(net(data), dim=1)

            pred_labels = knn_predict(
                feature, feature_bank, feature_labels, classes, k, t)

            total_num += data.size(0)
            total_top1 += (pred_labels[:, 0] == target).float().sum().item()

    # An empty test loader would otherwise end in a ZeroDivisionError.
    if total_num == 0:
        return 0.0
    return total_top1 / total_num
|
||
|
||
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """ Rank the classes for each feature by a similarity-weighted kNN vote.

    :param feature: Query features, [B, D].
    :param feature_bank: Bank features, [D, N].
    :param feature_labels: Integer class label of every bank entry, [N].
    :param classes: Number of classes.
    :param knn_k: Number of neighbours to use; capped at the bank size N.
    :param knn_t: Temperature; each neighbour votes with weight
        ``exp(similarity / knn_t)``.
    :return: Class indices sorted by descending score, [B, classes]; column 0
        is the top-1 prediction.
    :rtype: torch.Tensor
    """
    batch = feature.size(0)
    # topk raises when asked for more neighbours than the bank holds.
    knn_k = min(knn_k, feature_bank.size(1))

    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(
        feature_labels.expand(batch, -1), dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # counts for each class
    one_hot_label = torch.zeros(
        batch * knn_k, classes, device=sim_labels.device)
    # [B*K, C]
    one_hot_label = one_hot_label.scatter(
        dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(
        one_hot_label.view(batch, knn_k, classes) * sim_weight.unsqueeze(dim=-1),
        dim=1)

    return pred_scores.argsort(dim=-1, descending=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
name: fedsimsiam | ||
build_dependencies: | ||
- pip | ||
- setuptools | ||
- wheel==0.37.1 | ||
dependencies: | ||
- torch==2.2.1 | ||
- torchvision==0.17.1 | ||
- fedn==0.9.0 |
Oops, something went wrong.