
Commit 057ff5b

cuiboyuan and baochunli authored
Supported federated learning with generative adversarial networks (GANs). (#189)
Co-authored-by: Baochun Li <bli@ece.toronto.edu>
1 parent 6eb8e26 commit 057ff5b

File tree

11 files changed: +680 -4 lines


configs/CelebA/fedavg_gan_dcgan.yml

+75
@@ -0,0 +1,75 @@
clients:
    # Type
    type: simple

    # The total number of clients
    total_clients: 2

    # The number of clients selected in each round
    per_round: 2

    # Should the clients compute test accuracy locally?
    do_test: false

server:
    address: 127.0.0.1
    port: 8000

    # Which network to send to clients at the end of each round.
    # The value should be one of 'none', 'generator', 'discriminator', or 'both'
    network_to_sync: generator

data:
    # The training and testing dataset
    datasource: CelebA

    # Only add face identity as labels for training
    celeba_targets:
        attr: false
        identity: true

    # Resize all images to 64x64
    celeba_img_size: 64

    # Where the dataset is located
    data_path: data

    # Number of samples in each partition
    partition_size: 81000

    # IID or non-IID?
    sampler: iid

    # The concentration parameter for the Dirichlet distribution
    concentration: 0.5

    # The random seed for sampling data
    random_seed: 1

trainer:
    # The type of the trainer
    type: gan

    # The maximum number of training rounds
    rounds: 5

    # The maximum number of clients running concurrently
    max_concurrency: 3

    # The target Frechet Distance
    target_perplexity: 0

    # Number of epochs for local training in each communication round
    epochs: 5
    batch_size: 128
    optimizer: Adam
    learning_rate: 0.0002
    weight_decay: 0.0

    # The machine learning model
    model_name: dcgan

algorithm:
    # Aggregation algorithm
    type: fedavg_gan
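
For a quick sanity check of this new configuration outside of Plato itself, the file can be parsed directly. A minimal sketch, assuming PyYAML is installed and the repository root as the working directory; every key read below comes from the configuration file above.

import yaml

# Parse the new GAN configuration and print the fields that drive
# federated GAN training.
with open('configs/CelebA/fedavg_gan_dcgan.yml', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

print(config['server']['network_to_sync'])   # 'generator'
print(config['trainer']['type'])             # 'gan'
print(config['trainer']['model_name'])       # 'dcgan'
print(config['algorithm']['type'])           # 'fedavg_gan'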

configs/CelebA/fedavg_resnet18.yml

+3
@@ -24,6 +24,9 @@ data:
         # For ResNet, do not set <attr> to True since it does not match the expected output of ResNet
         attr: false
         identity: true
+
+    # Resize all images to 32x32; the default is 64x64
+    celeba_img_size: 32
 
     # Number of identities in CelebA
     num_classes: 10178

plato/algorithms/fedavg_gan.py

+78
@@ -0,0 +1,78 @@
"""
The federated averaging algorithm for GAN models.
"""
from collections import OrderedDict

from plato.algorithms import fedavg
from plato.trainers.base import Trainer


class Algorithm(fedavg.Algorithm):
    """ Federated averaging algorithm for GAN models, used by both the client and the server. """

    def __init__(self, trainer: Trainer):
        super().__init__(trainer=trainer)
        self.generator = self.model.generator
        self.discriminator = self.model.discriminator

    def compute_weight_deltas(self, weights_received):
        """ Extract the weights received from a client and compute the deltas. """
        baseline_weights_gen, baseline_weights_disc = self.extract_weights()

        deltas = []
        for weight_gen, weight_disc in weights_received:
            delta_gen = OrderedDict()
            for name, current_weight in weight_gen.items():
                baseline = baseline_weights_gen[name]
                delta = current_weight - baseline
                delta_gen[name] = delta

            delta_disc = OrderedDict()
            for name, current_weight in weight_disc.items():
                baseline = baseline_weights_disc[name]
                delta = current_weight - baseline
                delta_disc[name] = delta

            deltas.append((delta_gen, delta_disc))

        return deltas

    def update_weights(self, deltas):
        """ Update the existing model weights. """
        baseline_weights_gen, baseline_weights_disc = self.extract_weights()
        update_gen, update_disc = deltas

        updated_weights_gen = OrderedDict()
        for name, weight in baseline_weights_gen.items():
            updated_weights_gen[name] = weight + update_gen[name]

        updated_weights_disc = OrderedDict()
        for name, weight in baseline_weights_disc.items():
            updated_weights_disc[name] = weight + update_disc[name]

        return updated_weights_gen, updated_weights_disc

    def extract_weights(self, model=None):
        """ Extract weights from the model. """
        generator = self.generator
        discriminator = self.discriminator
        if model is not None:
            generator = model.generator
            discriminator = model.discriminator

        gen_weight = generator.cpu().state_dict()
        disc_weight = discriminator.cpu().state_dict()

        return gen_weight, disc_weight

    def load_weights(self, weights):
        """ Load the model weights passed in as a parameter. """
        weights_gen, weights_disc = weights
        # The client might receive only one of the Generator and
        # Discriminator weights, or neither.
        if weights_gen is not None:
            self.generator.load_state_dict(weights_gen, strict=True)
        if weights_disc is not None:
            self.discriminator.load_state_dict(weights_disc, strict=True)
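
The deltas here are plain FedAvg differences, computed separately for the generator and discriminator state dicts. A minimal sketch of the same round trip on toy tensors, independent of the Plato trainer machinery; the parameter name 'weight' is made up for illustration.

from collections import OrderedDict

import torch

# Server baseline and a client's locally trained weights (toy, single tensor).
baseline_gen = OrderedDict(weight=torch.zeros(3))
client_gen = OrderedDict(weight=torch.tensor([0.2, -0.1, 0.4]))

# compute_weight_deltas: client weights minus the server baseline.
delta_gen = OrderedDict(
    (name, client_gen[name] - baseline_gen[name]) for name in client_gen)

# update_weights: add the (aggregated) delta back onto the baseline.
updated_gen = OrderedDict(
    (name, baseline_gen[name] + delta_gen[name]) for name in baseline_gen)

assert torch.equal(updated_gen['weight'], client_gen['weight'])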

plato/algorithms/registry.py

+2
@@ -30,11 +30,13 @@
 from plato.algorithms import (
     fedavg,
     mistnet,
+    fedavg_gan,
 )
 
 registered_algorithms = OrderedDict([
     ('fedavg', fedavg.Algorithm),
     ('mistnet', mistnet.Algorithm),
+    ('fedavg_gan', fedavg_gan.Algorithm),
 ])
 
 

plato/datasources/celeba.py

+13 -4
@@ -56,26 +56,32 @@ def __init__(self):
         else:
             target_types = ['attr', 'identity']
 
-        image_size = 32
+        image_size = 64
+        if hasattr(Config().data, 'celeba_img_size'):
+            image_size = Config().data.celeba_img_size
+
         _transform = transforms.Compose([
             transforms.Resize(image_size),
             transforms.CenterCrop(image_size),
             transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
         ])
+        _target_transform = None
+        if target_types:
+            _target_transform = DataSource._target_transform
 
         self.trainset = CelebA(root=_path,
                                split='train',
                                target_type=target_types,
                                download=False,
                                transform=_transform,
-                               target_transform=DataSource._target_transform)
+                               target_transform=_target_transform)
         self.testset = CelebA(root=_path,
                               split='test',
                               target_type=target_types,
                               download=False,
                               transform=_transform,
-                              target_transform=DataSource._target_transform)
+                              target_transform=_target_transform)
 
     @staticmethod
     def _target_transform(label):
@@ -100,7 +106,10 @@ def _target_transform(label):
 
     @staticmethod
     def input_shape():
-        return [162770, 3, 32, 32]
+        image_size = 64
+        if hasattr(Config().data, 'celeba_img_size'):
+            image_size = Config().data.celeba_img_size
+        return [162770, 3, image_size, image_size]
 
     def num_train_examples(self):
         return 162770
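
The added celeba_img_size option only changes the Resize and CenterCrop targets in the transform above. A minimal sketch of that behaviour on a synthetic image, assuming torchvision and Pillow are installed; the dummy image stands in for a CelebA sample (aligned CelebA images are 178x218).

from PIL import Image
from torchvision import transforms

# Mirrors the transform built in celeba.py; 64 is the new default when
# 'celeba_img_size' is not set in the configuration.
image_size = 64

_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dummy = Image.new('RGB', (178, 218))  # stands in for one CelebA image
print(_transform(dummy).shape)        # torch.Size([3, 64, 64])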

plato/models/dcgan.py

+120
@@ -0,0 +1,120 @@
"""
The DCGAN model.

Reference:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
"""

from torch import nn

nz = 100
nc = 3
ngf = 64
ndf = 64


class Generator(nn.Module):
    """ Generator network of DCGAN. """

    def __init__(self):
        super().__init__()

        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input_data):
        """ Forward pass. """
        return self.main(input_data)


class Discriminator(nn.Module):
    """ Discriminator network of DCGAN. """

    def __init__(self):
        super().__init__()

        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid())

    def forward(self, input_data):
        return self.main(input_data)


class Model:
    """ A wrapper class to hold the Generator and Discriminator models of DCGAN. """

    def __init__(self) -> None:
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.loss_criterion = nn.BCELoss()

        self.nz = nz
        self.nc = nc
        self.ngf = ngf
        self.ndf = ndf

    def weights_init(self, model):
        classname = model.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(model.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(model.weight.data, 1.0, 0.02)
            nn.init.constant_(model.bias.data, 0)

    def cpu(self):
        self.generator.cpu()
        self.discriminator.cpu()

    def to(self, device):
        self.generator.to(device)
        self.discriminator.to(device)

    def train(self):
        self.generator.train()
        self.discriminator.train()

    def eval(self):
        self.generator.eval()
        self.discriminator.eval()

    @staticmethod
    def get_model(*args):
        """ Obtaining an instance of this model. """
        return Model()
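
With nz = 100 and the strided (transposed) convolutions above, the Generator maps a batch of latent vectors to 64x64 RGB images, and the Discriminator maps them back to a single score per image. A short shape-check sketch, assuming the plato package is importable (for example, from the repository root) and torch is installed.

import torch

from plato.models import dcgan

model = dcgan.Model.get_model()
noise = torch.randn(4, model.nz, 1, 1)           # batch of 4 latent vectors

model.eval()
with torch.no_grad():
    fake_images = model.generator(noise)         # -> (4, 3, 64, 64)
    scores = model.discriminator(fake_images)    # -> (4, 1, 1, 1), in (0, 1)

print(fake_images.shape, scores.shape)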

plato/models/registry.py

+2
@@ -37,6 +37,7 @@
     hybrid,
     efficientnet,
     regnet,
+    dcgan
 )
 registered_models = OrderedDict([
     ('lenet5', lenet5.Model),
@@ -52,6 +53,7 @@
     ('hybrid', hybrid.Model),
     ('efficientnet', efficientnet.Model),
     ('regnet', regnet.Model),
+    ('dcgan', dcgan.Model),
 ])
 
 
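
These two entries make the configuration's model_name: dcgan and algorithm type: fedavg_gan resolvable by name. A minimal sketch of the model-side lookup that mirrors the new registry entry rather than importing the full registry (which pulls in every model); only the dcgan module added in this commit is assumed.

from collections import OrderedDict

from plato.models import dcgan

# Mirrors the new ('dcgan', dcgan.Model) entry in registered_models.
registered_models = OrderedDict([
    ('dcgan', dcgan.Model),
])

model = registered_models['dcgan'].get_model()
print(type(model.generator).__name__)       # 'Generator'
print(type(model.discriminator).__name__)   # 'Discriminator'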
