Skip to content

Commit 570d0bf

Browse files
cying17 and baochunli authored
Fixed an issue related to client samplers. (#177)
Co-authored-by: Baochun Li <bli@ece.toronto.edu>
1 parent 25b0104 commit 570d0bf

File tree

10 files changed

+25
-26
lines changed

10 files changed

+25
-26
lines changed

docs/Configuration.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Attributes in **bold** must be included in a configuration file, while attribute
2323
|||`mistnet`|A client for MistNet|
2424
|**total_clients**|The total number of clients|A positive number||
2525
|**per_round**|The number of clients selected in each round| Any positive integer that is not larger than **total_clients**||
26-
|**do_test**|Whether the clients compute test accuracy locally| `true` or `false`|if `true` and the configuration file has `results` section, a CSV file will log test accuracy of every selected client in each round|
26+
|do_test|Whether the clients compute test accuracy locally| `true` or `false`|if `true` and the configuration file has `results` section, a CSV file will log test accuracy of every selected client in each round|
2727
|speed_simulation|Whether we simulate client heterogeneity in training speed|
2828
|simulation_distribution|Parameters for simulating client heterogeneity in training speed|`distribution`|`normal` for normal or `zipf` for Zipf|
2929
|||`s`|the parameter `s` in Zipf distribution|
@@ -76,6 +76,7 @@ Attributes in **bold** must be included in a configuration file, while attribute
7676
|disable_clients|If this optional setting is enabled as `true`, the server will not launch client processes on the same machine.||
7777
|s3_endpoint_url|The endpoint URL for an S3-compatible storage service, used for transferring payloads between clients and servers.||
7878
|s3_bucket|The bucket name for an S3-compatible storage service, used for transferring payloads between clients and servers.||
79+
|random_seed|Use a fixed random seed for selecting clients (and sampling testset if needed) so that experiments are reproducible||
7980
|ping_interval|The time interval in seconds at which the server pings the client. ||default: 3600|
8081
|ping_timeout| The time in seconds that the client waits for the server to respond before disconnecting.|| default: 3600|
8182
|synchronous|Synchronous or asynchronous mode|`true` or `false`||
@@ -120,7 +121,7 @@ Attributes in **bold** must be included in a configuration file, while attribute
120121
|||`mixed`|Some data are iid, while others are non-iid. Must have *non_iid_clients* attributes|
121122
|test_set_sampler|How to sample the test set when clients test locally|Could be any **sampler**|Without this parameter, every client's test set is the test set of the datasource|
122123
|edge_test_set_sampler|How to sample the test set when edge servers test locally|Could be any **sampler**|Without this parameter, edge servers' test sets are the test set of the datasource if they locally test their aggregated models in cross-silo FL|
123-
|random_seed|Use a fixed random seed so that experiments are reproducible (clients always have the same datasets)||
124+
|random_seed|Use a fixed random seed to sample each client's dataset so that experiments are reproducible||
124125
|**partition_size**|Number of samples in each client's dataset|Any positive integer||
125126
|concentration| The concentration parameter of symmetric Dirichlet distribution, used by `noniid` **sampler** || default: 1|
126127
|*non_iid_clients*|Indexs of clients whose datasets are non-iid. Other clients' datasets are iid|e.g., 4|Must have this attribute if the **sampler** is `mixed`|

examples/adaptive_hgb/adaptive_hgb_client.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def load_data(self) -> None:
9595

9696
self.valset = self.datasource.get_val_set()
9797

98-
if Config().clients.do_test:
98+
if hasattr(Config().clients, 'do_test') and Config().clients.do_test:
9999
# Set the testset if local testing is needed
100100
self.testset = self.datasource.get_test_set()
101101

@@ -225,15 +225,15 @@ async def train(self):
225225
delta_o, delta_g = self.obtain_delta_og()
226226

227227
# Generate a report for the server, performing model testing if applicable
228-
if Config().clients.do_test:
228+
if hasattr(Config().clients, 'do_test') and Config().clients.do_test:
229229
accuracy = self.trainer.test(self.testset)
230230

231231
if accuracy == 0:
232232
# The testing process failed, disconnect from the server
233233
await self.sio.disconnect()
234234

235-
logging.info("[Client #{:d}] Test accuracy: {:.2f}%".format(
236-
self.client_id, 100 * accuracy))
235+
logging.info('[Client #%d] Test accuracy: %.2f%%.', self.client_id,
236+
100 * accuracy)
237237
else:
238238
accuracy = 0
239239

examples/fedasync/fedasync_server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ async def process_reports(self):
7676
self.algorithm.load_weights(updated_weights)
7777

7878
# Testing the global model accuracy
79-
if Config().clients.do_test:
79+
if hasattr(Config().server, 'do_test') and not Config().server.do_test:
8080
# Compute the average accuracy from client reports
8181
self.accuracy = self.accuracy_averaging(self.updates)
8282
logging.info('[%s] Average client accuracy: %.2f%%.', self,

examples/fedunlearning_baseline/fedunlearning_client.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def process_server_response(self, server_response):
4646
self,
4747
Config().clients.deleted_data_ratio * 100)
4848

49-
if (hasattr(Config().data, 'reload_data')
50-
and Config().data.reload_data) or not self.data_loaded:
49+
if not hasattr(Config().data,
50+
'reload_data') or Config().data.reload_data:
5151
logging.info("[%s] Loading the dataset.", self)
5252
self.load_data()
5353

plato/clients/base.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def __init__(self) -> None:
8585
self.sio = None
8686
self.chunks = []
8787
self.server_payload = None
88-
self.data_loaded = False # is training data already loaded from the disk?
8988
self.s3_client = None
9089
self.outbound_processor = None
9190
self.inbound_processor = None
@@ -156,8 +155,8 @@ async def payload_to_arrive(self, response) -> None:
156155

157156
logging.info("[Client #%d] Selected by the server.", self.client_id)
158157

159-
if (hasattr(Config().data, 'reload_data')
160-
and Config().data.reload_data) or not self.data_loaded:
158+
if not hasattr(Config().data,
159+
'reload_data') or Config().data.reload_data:
161160
self.load_data()
162161

163162
if self.comm_simulation:

plato/clients/simple.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ def load_data(self) -> None:
8686
self.datasource = datasources_registry.get(
8787
client_id=self.client_id)
8888

89-
self.data_loaded = True
90-
9189
logging.info("[%s] Dataset size: %s", self,
9290
self.datasource.num_train_examples())
9391

@@ -102,7 +100,7 @@ def load_data(self) -> None:
102100
# PyTorch uses samplers when loading data with a data loader
103101
self.trainset = self.datasource.get_train_set()
104102

105-
if Config().clients.do_test:
103+
if hasattr(Config().clients, 'do_test') and Config().clients.do_test:
106104
# Set the testset if local testing is needed
107105
self.testset = self.datasource.get_test_set()
108106
if hasattr(Config().data, 'testset_sampler'):
@@ -130,9 +128,9 @@ async def train(self):
130128
weights = self.algorithm.extract_weights()
131129

132130
# Generate a report for the server, performing model testing if applicable
133-
if Config().clients.do_test and (
134-
not hasattr(Config().clients, 'test_interval')
135-
or self.current_round % Config().clients.test_interval == 0):
131+
if (hasattr(Config().clients, 'do_test') and Config().clients.do_test
132+
) and (not hasattr(Config().clients, 'test_interval') or
133+
self.current_round % Config().clients.test_interval == 0):
136134
accuracy = self.trainer.test(self.testset, self.testset_sampler)
137135

138136
if accuracy == -1:

plato/samplers/dirichlet.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
class Sampler(base.Sampler):
1313
"""Create a data sampler for each client to use a divided partition of the
1414
dataset, biased across labels according to the Dirichlet distribution."""
15+
1516
def __init__(self, datasource, client_id, testing):
1617
super().__init__()
1718

@@ -26,12 +27,11 @@ def __init__(self, datasource, client_id, testing):
2627

2728
if dist.distribution.lower() == "uniform":
2829
self.partition_size *= np.random.uniform(dist.low, dist.high)
29-
30+
3031
if dist.distribution.lower() == "normal":
3132
self.partition_size *= np.random.normal(dist.mean, dist.high)
3233

33-
self.partition_size = int(self.partition_size)
34-
34+
self.partition_size = int(self.partition_size)
3535

3636
# Concentration parameter to be used in the Dirichlet distribution
3737
concentration = Config().data.concentration if hasattr(

plato/samplers/iid.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Sampler(base.Sampler):
1515

1616
def __init__(self, datasource, client_id, testing):
1717
super().__init__()
18+
1819
if testing:
1920
dataset = datasource.get_test_set()
2021
else:
@@ -45,8 +46,8 @@ def get(self):
4546
"""Obtains an instance of the sampler. """
4647
gen = torch.Generator()
4748
gen.manual_seed(self.random_seed)
48-
version = torch.__version__
49-
if int(version[0]) <= 1 and int(version[2]) <= 5:
49+
version = torch.__version__.split(".")
50+
if int(version[0]) <= 1 and int(version[1]) <= 5:
5051
return SubsetRandomSampler(self.subset_indices)
5152
return SubsetRandomSampler(self.subset_indices, generator=gen)
5253

plato/servers/fedavg.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def configure(self):
108108
Config().params['result_path'])
109109

110110
# Initialize the test accuracy csv file if clients compute locally
111-
if Config().clients.do_test:
111+
if hasattr(Config().clients, 'do_test') and Config().clients.do_test:
112112
accuracy_csv_file = f"{Config().params['result_path']}/{os.getpid()}_accuracy.csv"
113113
accuracy_headers = ["round", "client_id", "accuracy"]
114114
csv_processor.initialize_csv(accuracy_csv_file, accuracy_headers,
@@ -223,7 +223,7 @@ async def wrap_up_processing_reports(self):
223223
result_csv_file = f"{Config().params['result_path']}/{os.getpid()}.csv"
224224
csv_processor.write_csv(result_csv_file, new_row)
225225

226-
if Config().clients.do_test:
226+
if hasattr(Config().clients, 'do_test') and Config().clients.do_test:
227227
# Updates the log for client test accuracies
228228
accuracy_csv_file = f"{Config().params['result_path']}/{os.getpid()}_accuracy.csv"
229229

plato/servers/mistnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ async def process_reports(self):
4949
Config().algorithm.cut_layer)
5050

5151
# Test the updated model
52-
if not Config().clients.do_test:
52+
if not hasattr(Config().server, 'do_test') or Config().server.do_test:
5353
self.accuracy = self.trainer.test(self.testset)
5454
logging.info('[%s] Global model accuracy: %.2f%%\n', self,
5555
100 * self.accuracy)

0 commit comments

Comments (0)