From 5f1f43a3b6d7a9d3190f46cff3f22b6a8067135d Mon Sep 17 00:00:00 2001
From: stefanhellander <59477428+stefanhellander@users.noreply.github.com>
Date: Tue, 7 May 2024 14:44:24 +0200
Subject: [PATCH 1/5] Update version (#598)

---
 docs/conf.py   | 2 +-
 pyproject.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index acd52819f..913c35d9c 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -12,7 +12,7 @@
 author = 'Scaleout Systems AB'
 
 # The full version, including alpha/beta/rc tags
-release = '0.9.1'
+release = '0.9.2'
 
 # Add any Sphinx extension module names here, as strings
 extensions = [
diff --git a/pyproject.toml b/pyproject.toml
index 0e149a442..24233b9f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "fedn"
-version = "0.9.1"
+version = "0.9.2"
 description = "Scaleout Federated Learning"
 authors = [{ name = "Scaleout Systems AB", email = "contact@scaleoutsystems.com" }]
 readme = "README.rst"

From cbc7dbcb3ace2541c0def08849c362d98a206500 Mon Sep 17 00:00:00 2001
From: Fredrik Wrede
Date: Tue, 7 May 2024 14:57:49 +0200
Subject: [PATCH 2/5] Feature/SK-838 | Ruff linting (#599)

---
 examples/async-clients/client/entrypoint.py   |  32 +-
 examples/async-clients/init_fedn.py           |   6 +-
 examples/async-clients/run_clients.py         |  62 ++--
 examples/async-clients/run_experiment.py      |  12 +-
 examples/flower-client/client/entrypoint.py   |   4 +-
 examples/flower-client/client/flwr_client.py  |   7 +-
 examples/flower-client/init_fedn.py           |   6 +-
 examples/mnist-keras/client/entrypoint.py     |  64 ++--
 examples/mnist-keras/client/get_data.py       |  25 +-
 examples/mnist-pytorch/client/data.py         |  75 ++--
 examples/mnist-pytorch/client/model.py        |  15 +-
 examples/mnist-pytorch/client/train.py        |  19 +-
 examples/mnist-pytorch/client/validate.py     |  10 +-
 fedn/cli/client_cmd.py                        | 140 ++++---
 fedn/cli/combiner_cmd.py                      |  74 ++--
 fedn/cli/config_cmd.py                        |  42 +--
 fedn/cli/main.py                              |   2 +-
 fedn/cli/model_cmd.py                         |  29 +-
 fedn/cli/package_cmd.py                       |  38 +-
 fedn/cli/round_cmd.py                         |  29 +-
 fedn/cli/run_cmd.py                           | 171 +++++----
 fedn/cli/session_cmd.py                       |  29 +-
 fedn/cli/shared.py                            |  52 +--
 fedn/cli/status_cmd.py                        |  28 +-
 fedn/cli/validation_cmd.py                    |  28 +-
 fedn/common/certificate/certificate.py        |  23 +-
 fedn/common/certificate/certificatemanager.py |  13 +-
 fedn/common/config.py                         |   8 +-
 fedn/common/log_config.py                     |  56 ++-
 fedn/network/api/auth.py                      |  35 +-
 fedn/network/api/client.py                    | 274 +++++++-------
 fedn/network/api/interface.py                 | 104 ++----
 fedn/network/api/network.py                   |  47 ++-
 fedn/network/api/server.py                    |  19 +-
 fedn/network/api/v1/client_routes.py          |  11 +-
 fedn/network/api/v1/combiner_routes.py        |  11 +-
 fedn/network/api/v1/model_routes.py           |  15 +-
 fedn/network/api/v1/package_routes.py         |  12 +-
 fedn/network/api/v1/round_routes.py           |  11 +-
 fedn/network/api/v1/session_routes.py         |  11 +-
 fedn/network/api/v1/shared.py                 |   7 +-
 fedn/network/api/v1/status_routes.py          |  32 +-
 fedn/network/api/v1/validation_routes.py      |  35 +-
 fedn/network/clients/client.py                | 249 ++++++-------
 fedn/network/clients/connect.py               |  63 ++--
 fedn/network/clients/package.py               |  63 ++--
 fedn/network/clients/state.py                 |   5 +-
 fedn/network/combiner/combiner.py             | 170 ++++-----
 fedn/network/combiner/connect.py              |  42 +--
 fedn/network/combiner/interfaces.py           |  89 ++---
 fedn/network/combiner/modelservice.py         |  39 +-
 fedn/network/combiner/roundhandler.py         | 107 +++---
 fedn/network/config.py                        |   8 +-
 fedn/network/controller/control.py            |  76 ++--
 fedn/network/controller/controlbase.py        |  44 +--
 fedn/network/grpc/auth.py                     |  58 ++-
 fedn/network/loadbalancer/firstavailable.py   |   4 +-
 fedn/network/loadbalancer/leastpacked.py      |   4 +-
 fedn/network/loadbalancer/loadbalancerbase.py |   4 +-
 fedn/network/state.py                         |   7 +-
 .../storage/models/memorymodelstorage.py      |   7 +-
 fedn/network/storage/models/modelstorage.py   |  13 +-
 .../storage/models/tempmodelstorage.py        |  20 +-
 fedn/network/storage/s3/base.py               |   6 +-
 fedn/network/storage/s3/miniorepository.py    |  47 ++-
 fedn/network/storage/s3/repository.py         |  33 +-
 .../storage/statestore/mongostatestore.py     | 144 +++-----
 .../storage/statestore/statestorebase.py      |  15 +-
 fedn/utils/checksum.py                        |   2 +-
 fedn/utils/dispatcher.py                      |  57 +--
 fedn/utils/flowercompat/client_app_adapter.py |  49 +--
 fedn/utils/helpers/helperbase.py              |  14 +-
 fedn/utils/helpers/helpers.py                 |  10 +-
 fedn/utils/helpers/plugins/androidhelper.py   |  16 +-
 fedn/utils/helpers/plugins/numpyhelper.py     |  35 +-
 fedn/utils/plots.py                           | 346 ++++++++----------
 76 files changed, 1609 insertions(+), 1960 deletions(-)

diff --git a/examples/async-clients/client/entrypoint.py b/examples/async-clients/client/entrypoint.py
index 4ddddd956..220b5299b 100644
--- a/examples/async-clients/client/entrypoint.py
+++ b/examples/async-clients/client/entrypoint.py
@@ -8,7 +8,7 @@
 
 from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics
 
-HELPER_MODULE = 'numpyhelper'
+HELPER_MODULE = "numpyhelper"
 ARRAY_SIZE = 10000
 
@@ -22,7 +22,7 @@ def compile_model(max_iter=1):
 
 
 def save_parameters(model, out_path):
-    """ Save model to disk.
+    """Save model to disk.
 
     :param model: The model to save.
     :type model: torch.nn.Module
@@ -36,7 +36,7 @@ def save_parameters(model, out_path):
 
 
 def load_parameters(model_path):
-    """ Load model from disk.
+    """Load model from disk.
 
     :param model_path: The path to load from.
     :type model_path: str
@@ -49,8 +49,8 @@ def load_parameters(model_path):
     return parameters
 
 
-def init_seed(out_path='seed.npz'):
-    """ Initialize seed model.
+def init_seed(out_path="seed.npz"):
+    """Initialize seed model.
 
     :param out_path: The path to save the seed model to.
     :type out_path: str
@@ -61,7 +61,7 @@ def init_seed(out_path='seed.npz'):
 
 
 def make_data(n_min=50, n_max=100):
-    """ Generate / simulate a random number n data points.
+    """Generate / simulate a random number n data points.
 
     n will fall in the interval (n_min, n_max)
 
@@ -78,14 +78,12 @@ def make_data(n_min=50, n_max=100):
 
 
 def train(in_model_path, out_model_path):
-    """ Train model.
-
-    """
+    """Train model."""
     # Load model
     parameters = load_parameters(in_model_path)
     model = compile_model()
-    n = len(parameters)//2
+    n = len(parameters) // 2
     model.coefs_ = parameters[:n]
     model.intercepts_ = parameters[n:]
 
@@ -97,7 +95,7 @@ def train(in_model_path, out_model_path):
 
     # Metadata needed for aggregation server side
     metadata = {
-        'num_examples': len(X_train),
+        "num_examples": len(X_train),
     }
 
     # Save JSON metadata file
@@ -108,7 +106,7 @@ def train(in_model_path, out_model_path):
 
 
 def validate(in_model_path, out_json_path):
-    """ Validate model.
+    """Validate model.
 
     :param in_model_path: The path to the input model.
     :type in_model_path: str
@@ -119,7 +117,7 @@ def validate(in_model_path, out_json_path):
     """
     parameters = load_parameters(in_model_path)
     model = compile_model()
-    n = len(parameters)//2
+    n = len(parameters) // 2
     model.coefs_ = parameters[:n]
     model.intercepts_ = parameters[n:]
 
@@ -134,9 +132,5 @@ def validate(in_model_path, out_json_path):
     save_metrics(report, out_json_path)
 
 
-if __name__ == '__main__':
-    fire.Fire({
-        'init_seed': init_seed,
-        'train': train,
-        'validate': validate
-    })
+if __name__ == "__main__":
+    fire.Fire({"init_seed": init_seed, "train": train, "validate": validate})
diff --git a/examples/async-clients/init_fedn.py b/examples/async-clients/init_fedn.py
index 2aa298602..8677c472d 100644
--- a/examples/async-clients/init_fedn.py
+++ b/examples/async-clients/init_fedn.py
@@ -1,8 +1,8 @@
 from fedn import APIClient
 
-DISCOVER_HOST = '127.0.0.1'
+DISCOVER_HOST = "127.0.0.1"
 DISCOVER_PORT = 8092
 
 client = APIClient(DISCOVER_HOST, DISCOVER_PORT)
-client.set_active_package('package.tgz', 'numpyhelper')
-client.set_active_model('seed.npz')
+client.set_active_package("package.tgz", "numpyhelper")
+client.set_active_model("seed.npz")
diff --git a/examples/async-clients/run_clients.py b/examples/async-clients/run_clients.py
index 5b56c53c9..82da30ad9 100644
--- a/examples/async-clients/run_clients.py
+++ b/examples/async-clients/run_clients.py
@@ -26,24 +26,39 @@
 
 # Use with a local deployment
 settings = {
-    'DISCOVER_HOST': '127.0.0.1',
-    'DISCOVER_PORT': 8092,
-    'TOKEN': None,
-    'N_CLIENTS': 10,
-    'N_CYCLES': 100,
-    'CLIENTS_MAX_DELAY': 10,
-    'CLIENTS_ONLINE_FOR_SECONDS': 120
+    "DISCOVER_HOST": "127.0.0.1",
+    "DISCOVER_PORT": 8092,
+    "TOKEN": None,
+    "N_CLIENTS": 10,
+    "N_CYCLES": 100,
+    "CLIENTS_MAX_DELAY": 10,
+    "CLIENTS_ONLINE_FOR_SECONDS": 120,
 }
 
-client_config = {'discover_host': settings['DISCOVER_HOST'], 'discover_port': settings['DISCOVER_PORT'], 'token': settings['TOKEN'], 'name': 'testclient',
-                 'client_id': 1, 'remote_compute_context': True, 'force_ssl': False, 'dry_run': False, 'secure': False,
-                 'preshared_cert': False, 'verify': False, 'preferred_combiner': False,
-                 'validator': True, 'trainer': True, 'init': None, 'logfile': 'test.log', 'heartbeat_interval': 2,
-                 'reconnect_after_missed_heartbeat': 30}
+client_config = {
+    "discover_host": settings["DISCOVER_HOST"],
+    "discover_port": settings["DISCOVER_PORT"],
+    "token": settings["TOKEN"],
+    "name": "testclient",
+    "client_id": 1,
+    "remote_compute_context": True,
+    "force_ssl": False,
+    "dry_run": False,
+    "secure": False,
+    "preshared_cert": False,
+    "verify": False,
+    "preferred_combiner": False,
+    "validator": True,
+    "trainer": True,
+    "init": None,
+    "logfile": "test.log",
+    "heartbeat_interval": 2,
+    "reconnect_after_missed_heartbeat": 30,
+}
 
 
-def run_client(online_for=120, name='client'):
-    """ Simulates a client that starts and stops
+def run_client(online_for=120, name="client"):
+    """Simulates a client that starts and stops
     at random intervals.
     The client will start after a random time 'mean_delay',
@@ -55,23 +70,28 @@ def run_client(online_for=120, name='client'):
     """
     conf = copy.deepcopy(client_config)
-    conf['name'] = name
+    conf["name"] = name
 
-    for i in range(settings['N_CYCLES']):
+    for i in range(settings["N_CYCLES"]):
         # Sample a delay until the client starts
-        t_start = np.random.randint(0, settings['CLIENTS_MAX_DELAY'])
+        t_start = np.random.randint(0, settings["CLIENTS_MAX_DELAY"])
         time.sleep(t_start)
         fl_client = Client(conf)
         time.sleep(online_for)
         fl_client.disconnect()
 
 
-if __name__ == '__main__':
-
+if __name__ == "__main__":
     # We start N_CLIENTS independent client processes
     processes = []
-    for i in range(settings['N_CLIENTS']):
-        p = Process(target=run_client, args=(settings['CLIENTS_ONLINE_FOR_SECONDS'], 'client{}'.format(i),))
+    for i in range(settings["N_CLIENTS"]):
+        p = Process(
+            target=run_client,
+            args=(
+                settings["CLIENTS_ONLINE_FOR_SECONDS"],
+                "client{}".format(i),
+            ),
+        )
         processes.append(p)
         p.start()
diff --git a/examples/async-clients/run_experiment.py b/examples/async-clients/run_experiment.py
index d8d12dca2..f6d5e1f6d 100644
--- a/examples/async-clients/run_experiment.py
+++ b/examples/async-clients/run_experiment.py
@@ -3,16 +3,14 @@
 
 from fedn import APIClient
 
-DISCOVER_HOST = '127.0.0.1'
+DISCOVER_HOST = "127.0.0.1"
 DISCOVER_PORT = 8092
 client = APIClient(DISCOVER_HOST, DISCOVER_PORT)
 
-if __name__ == '__main__':
-
+if __name__ == "__main__":
     # Run six sessions, each with 100 rounds.
     num_sessions = 6
     for s in range(num_sessions):
-
         session_config = {
             "helper": "numpyhelper",
             "id": str(uuid.uuid4()),
@@ -23,12 +21,12 @@
         }
 
         session = client.start_session(**session_config)
-        if session['success'] is False:
-            print(session['message'])
+        if session["success"] is False:
+            print(session["message"])
             exit(0)
 
         print("Started session: {}".format(session))
 
         # Wait for session to finish
-        while not client.session_is_finished(session_config['id']):
+        while not client.session_is_finished(session_config["id"]):
             time.sleep(2)
diff --git a/examples/flower-client/client/entrypoint.py b/examples/flower-client/client/entrypoint.py
index e80e7d4b5..1a9a8b8cf 100755
--- a/examples/flower-client/client/entrypoint.py
+++ b/examples/flower-client/client/entrypoint.py
@@ -56,9 +56,7 @@ def train(in_model_path, out_model_path):
     parameters_np = helper.load(in_model_path)
 
     # Train on flower client
-    params, num_examples = flwr_adapter.train(
-        parameters=parameters_np, partition_id=_get_node_id(), config={}
-    )
+    params, num_examples = flwr_adapter.train(parameters=parameters_np, partition_id=_get_node_id(), config={})
 
     # Metadata needed for aggregation server side
     metadata = {
diff --git a/examples/flower-client/client/flwr_client.py b/examples/flower-client/client/flwr_client.py
index 297df3ca7..843e9d31d 100644
--- a/examples/flower-client/client/flwr_client.py
+++ b/examples/flower-client/client/flwr_client.py
@@ -3,8 +3,7 @@
 """
 
 from flwr.client import ClientApp, NumPyClient
-from flwr_task import (DEVICE, Net, get_weights, load_data, set_weights, test,
-                       train)
+from flwr_task import DEVICE, Net, get_weights, load_data, set_weights, test, train
 
 
 # Define FlowerClient and client_fn
@@ -12,9 +11,7 @@ class FlowerClient(NumPyClient):
     def __init__(self, cid) -> None:
         super().__init__()
         self.net = Net().to(DEVICE)
-        self.trainloader, self.testloader = load_data(
-            partition_id=int(cid), num_clients=10
-        )
+        self.trainloader, self.testloader = load_data(partition_id=int(cid), num_clients=10)
 
     def get_parameters(self, config):
         return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
diff --git a/examples/flower-client/init_fedn.py b/examples/flower-client/init_fedn.py
index 23078fcd9..81864293b 100644
--- a/examples/flower-client/init_fedn.py
+++ b/examples/flower-client/init_fedn.py
@@ -1,8 +1,8 @@
 from fedn import APIClient
 
-DISCOVER_HOST = '127.0.0.1'
+DISCOVER_HOST = "127.0.0.1"
 DISCOVER_PORT = 8092
 
 client = APIClient(DISCOVER_HOST, DISCOVER_PORT)
-client.set_package('package.tgz', 'numpyhelper')
-client.set_initial_model('seed.npz')
+client.set_package("package.tgz", "numpyhelper")
+client.set_initial_model("seed.npz")
diff --git a/examples/mnist-keras/client/entrypoint.py b/examples/mnist-keras/client/entrypoint.py
index 15c0693f2..5420a78bb 100755
--- a/examples/mnist-keras/client/entrypoint.py
+++ b/examples/mnist-keras/client/entrypoint.py
@@ -7,7 +7,7 @@
 
 from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics
 
-HELPER_MODULE = 'numpyhelper'
+HELPER_MODULE = "numpyhelper"
 helper = get_helper(HELPER_MODULE)
 
 NUM_CLASSES = 10
@@ -17,13 +17,13 @@
 
 
 def _get_data_path():
-    data_path = os.environ.get('FEDN_DATA_PATH', abs_path + '/data/clients/1/mnist.npz')
+    data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.npz")
 
     return data_path
 
 
 def compile_model(img_rows=28, img_cols=28):
-    """ Compile the TF model.
+    """Compile the TF model.
 
     :param img_rows: The number of rows in the image
     :type img_rows: int
@@ -38,13 +38,11 @@ def compile_model(img_rows=28, img_cols=28):
     # Define model
     model = tf.keras.models.Sequential()
     model.add(tf.keras.layers.Flatten(input_shape=input_shape))
-    model.add(tf.keras.layers.Dense(64, activation='relu'))
+    model.add(tf.keras.layers.Dense(64, activation="relu"))
     model.add(tf.keras.layers.Dropout(0.5))
-    model.add(tf.keras.layers.Dense(32, activation='relu'))
-    model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))
-    model.compile(loss=tf.keras.losses.categorical_crossentropy,
-                  optimizer=tf.keras.optimizers.Adam(),
-                  metrics=['accuracy'])
+    model.add(tf.keras.layers.Dense(32, activation="relu"))
+    model.add(tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"))
+    model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=tf.keras.optimizers.Adam(), metrics=["accuracy"])
 
     return model
 
@@ -57,14 +55,14 @@
     data = np.load(data_path)
 
     if is_train:
-        X = data['x_train']
-        y = data['y_train']
+        X = data["x_train"]
+        y = data["y_train"]
     else:
-        X = data['x_test']
-        y = data['y_test']
+        X = data["x_test"]
+        y = data["y_test"]
 
     # Normalize
-    X = X.astype('float32')
+    X = X.astype("float32")
     X = np.expand_dims(X, -1)
     X = X / 255
     y = tf.keras.utils.to_categorical(y, NUM_CLASSES)
@@ -72,8 +70,8 @@
     return X, y
 
 
-def init_seed(out_path='../seed.npz'):
-    """ Initialize seed model and save it to file.
+def init_seed(out_path="../seed.npz"):
+    """Initialize seed model and save it to file.
 
     :param out_path: The path to save the seed model to.
     :type out_path: str
@@ -83,7 +81,7 @@
 
 
 def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1):
-    """ Complete a model update.
+    """Complete a model update.
 
     Load model parameters from in_model_path (managed by the FEDn client),
     perform a model update, and write updated parameters
@@ -114,9 +112,9 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
 
     # Metadata needed for aggregation server side
     metadata = {
         # num_examples are mandatory
-        'num_examples': len(x_train),
-        'batch_size': batch_size,
-        'epochs': epochs,
+        "num_examples": len(x_train),
+        "batch_size": batch_size,
+        "epochs": epochs,
     }
 
     # Save JSON metadata file (mandatory)
@@ -128,7 +126,7 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
 
 
 def validate(in_model_path, out_json_path, data_path=None):
-    """ Validate model.
+    """Validate model.
 
     :param in_model_path: The path to the input model.
     :type in_model_path: str
@@ -182,14 +180,16 @@ def predict(in_model_path, out_json_path, data_path=None):
 
     # Save JSON
     with open(out_json_path, "w") as fh:
-        fh.write(json.dumps({'predictions': y_pred.tolist()}))
-
-
-if __name__ == '__main__':
-    fire.Fire({
-        'init_seed': init_seed,
-        'train': train,
-        'validate': validate,
-        'predict': predict,
-        '_get_data_path': _get_data_path,  # for testing
-    })
+        fh.write(json.dumps({"predictions": y_pred.tolist()}))
+
+
+if __name__ == "__main__":
+    fire.Fire(
+        {
+            "init_seed": init_seed,
+            "train": train,
+            "validate": validate,
+            "predict": predict,
+            "_get_data_path": _get_data_path,  # for testing
+        }
+    )
diff --git a/examples/mnist-keras/client/get_data.py b/examples/mnist-keras/client/get_data.py
index 28a12bd20..ed123a4a3 100755
--- a/examples/mnist-keras/client/get_data.py
+++ b/examples/mnist-keras/client/get_data.py
@@ -7,14 +7,14 @@
 
 def splitset(dataset, parts):
     n = dataset.shape[0]
-    local_n = floor(n/parts)
+    local_n = floor(n / parts)
     result = []
     for i in range(parts):
-        result.append(dataset[i*local_n: (i+1)*local_n])
+        result.append(dataset[i * local_n : (i + 1) * local_n])
     return np.array(result)
 
 
-def split(dataset='data/mnist.npz', outdir='data', n_splits=2):
+def split(dataset="data/mnist.npz", outdir="data", n_splits=2):
     # Load and convert to dict
     package = np.load(dataset)
     data = {}
@@ -22,32 +22,27 @@
         data[key] = splitset(val, n_splits)
 
     # Make dir if necessary
-    if not os.path.exists(f'{outdir}/clients'):
-        os.mkdir(f'{outdir}/clients')
+    if not os.path.exists(f"{outdir}/clients"):
+        os.mkdir(f"{outdir}/clients")
 
     # Make splits
     for i in range(n_splits):
-        subdir = f'{outdir}/clients/{str(i+1)}'
+        subdir = f"{outdir}/clients/{str(i+1)}"
         if not os.path.exists(subdir):
             os.mkdir(subdir)
-        np.savez(f'{subdir}/mnist.npz',
-                 x_train=data['x_train'][i],
-                 y_train=data['y_train'][i],
-                 x_test=data['x_test'][i],
-                 y_test=data['y_test'][i])
+        np.savez(f"{subdir}/mnist.npz", x_train=data["x_train"][i], y_train=data["y_train"][i], x_test=data["x_test"][i], y_test=data["y_test"][i])
 
 
-def get_data(out_dir='data'):
+def get_data(out_dir="data"):
     # Make dir if necessary
     if not os.path.exists(out_dir):
         os.mkdir(out_dir)
 
     # Download data
     (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
-    np.savez(f'{out_dir}/mnist.npz', x_train=x_train,
-             y_train=y_train, x_test=x_test, y_test=y_test)
+    np.savez(f"{out_dir}/mnist.npz", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     get_data()
     split()
diff --git a/examples/mnist-pytorch/client/data.py b/examples/mnist-pytorch/client/data.py
index d67274548..b921f3132 100644
--- a/examples/mnist-pytorch/client/data.py
+++ b/examples/mnist-pytorch/client/data.py
@@ -8,22 +8,20 @@
 abs_path = os.path.abspath(dir_path)
 
 
-def get_data(out_dir='data'):
+def get_data(out_dir="data"):
     # Make dir if necessary
     if not os.path.exists(out_dir):
         os.mkdir(out_dir)
 
     # Only download if not already downloaded
-    if not os.path.exists(f'{out_dir}/train'):
-        torchvision.datasets.MNIST(
-            root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True, download=True)
+    if not os.path.exists(f"{out_dir}/train"):
+        torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True, download=True)
 
-    if not os.path.exists(f'{out_dir}/test'):
-        torchvision.datasets.MNIST(
-            root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False, download=True)
+    if not os.path.exists(f"{out_dir}/test"):
+        torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False, download=True)
 
 
 def load_data(data_path, is_train=True):
-    """ Load data from disk.
+    """Load data from disk.
 
     :param data_path: Path to data file.
     :type data_path: str
@@ -33,16 +31,16 @@ def load_data(data_path, is_train=True):
     :rtype: tuple
     """
     if data_path is None:
-        data_path = os.environ.get("FEDN_DATA_PATH", abs_path+'/data/clients/1/mnist.pt')
+        data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt")
 
     data = torch.load(data_path)
 
     if is_train:
-        X = data['x_train']
-        y = data['y_train']
+        X = data["x_train"]
+        y = data["y_train"]
     else:
-        X = data['x_test']
-        y = data['y_test']
+        X = data["x_test"]
+        y = data["y_test"]
 
     # Normalize
     X = X / 255
@@ -52,49 +50,48 @@ def load_data(data_path, is_train=True):
 
 def splitset(dataset, parts):
     n = dataset.shape[0]
-    local_n = floor(n/parts)
+    local_n = floor(n / parts)
     result = []
     for i in range(parts):
-        result.append(dataset[i*local_n: (i+1)*local_n])
+        result.append(dataset[i * local_n : (i + 1) * local_n])
     return result
 
 
-def split(out_dir='data'):
-
+def split(out_dir="data"):
     n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))
 
     # Make dir
-    if not os.path.exists(f'{out_dir}/clients'):
-        os.mkdir(f'{out_dir}/clients')
+    if not os.path.exists(f"{out_dir}/clients"):
+        os.mkdir(f"{out_dir}/clients")
 
     # Load and convert to dict
-    train_data = torchvision.datasets.MNIST(
-        root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True)
-    test_data = torchvision.datasets.MNIST(
-        root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False)
+    train_data = torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True)
+    test_data = torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False)
     data = {
-        'x_train': splitset(train_data.data, n_splits),
-        'y_train': splitset(train_data.targets, n_splits),
-        'x_test': splitset(test_data.data, n_splits),
-        'y_test': splitset(test_data.targets, n_splits),
+        "x_train": splitset(train_data.data, n_splits),
+        "y_train": splitset(train_data.targets, n_splits),
+        "x_test": splitset(test_data.data, n_splits),
+        "y_test": splitset(test_data.targets, n_splits),
     }
 
     # Make splits
     for i in range(n_splits):
-        subdir = f'{out_dir}/clients/{str(i+1)}'
+        subdir = f"{out_dir}/clients/{str(i+1)}"
         if not os.path.exists(subdir):
             os.mkdir(subdir)
-        torch.save({
-            'x_train': data['x_train'][i],
-            'y_train': data['y_train'][i],
-            'x_test': data['x_test'][i],
-            'y_test': data['y_test'][i],
-        },
-            f'{subdir}/mnist.pt')
+        torch.save(
+            {
+                "x_train": data["x_train"][i],
+                "y_train": data["y_train"][i],
+                "x_test": data["x_test"][i],
+                "y_test": data["y_test"][i],
+            },
+            f"{subdir}/mnist.pt",
+        )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Prepare data if not already done
-    if not os.path.exists(abs_path+'/data/clients/1'):
+    if not os.path.exists(abs_path + "/data/clients/1"):
         get_data()
         split()
diff --git a/examples/mnist-pytorch/client/model.py b/examples/mnist-pytorch/client/model.py
index cb5b7afc2..6ad344770 100644
--- a/examples/mnist-pytorch/client/model.py
+++ b/examples/mnist-pytorch/client/model.py
@@ -4,16 +4,17 @@
 
 from fedn.utils.helpers.helpers import get_helper
 
-HELPER_MODULE = 'numpyhelper'
+HELPER_MODULE = "numpyhelper"
 helper = get_helper(HELPER_MODULE)
 
 
 def compile_model():
-    """ Compile the pytorch model.
+    """Compile the pytorch model.
 
     :return: The compiled model.
     :rtype: torch.nn.Module
     """
+
     class Net(torch.nn.Module):
         def __init__(self):
             super(Net, self).__init__()
@@ -32,7 +33,7 @@ def forward(self, x):
 
 
 def save_parameters(model, out_path):
-    """ Save model paramters to file.
+    """Save model parameters to file.
 
     :param model: The model to serialize.
     :type model: torch.nn.Module
@@ -44,7 +45,7 @@ def save_parameters(model, out_path):
 
 
 def load_parameters(model_path):
-    """ Load model parameters from file and populate model.
+    """Load model parameters from file and populate model.
 
     :param model_path: The path to load from.
     :type model_path: str
@@ -60,8 +61,8 @@ def load_parameters(model_path):
     return model
 
 
-def init_seed(out_path='seed.npz'):
-    """ Initialize seed model and save it to file.
+def init_seed(out_path="seed.npz"):
+    """Initialize seed model and save it to file.
 
     :param out_path: The path to save the seed model to.
     :type out_path: str
@@ -72,4 +73,4 @@ def init_seed(out_path='seed.npz'):
 
 
 if __name__ == "__main__":
-    init_seed('../seed.npz')
+    init_seed("../seed.npz")
diff --git a/examples/mnist-pytorch/client/train.py b/examples/mnist-pytorch/client/train.py
index fdf2480ae..9ac9cce61 100644
--- a/examples/mnist-pytorch/client/train.py
+++ b/examples/mnist-pytorch/client/train.py
@@ -3,9 +3,9 @@
 import sys
 
 import torch
-from data import load_data
 from model import load_parameters, save_parameters
 
+from data import load_data
 from fedn.utils.helpers.helpers import save_metadata
 
 dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -13,7 +13,7 @@
 
 
 def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01):
-    """ Complete a model update.
+    """Complete a model update.
 
     Load model parameters from in_model_path (managed by the FEDn client),
     perform a model update, and write updated parameters
@@ -45,8 +45,8 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
     for e in range(epochs):  # epoch loop
         for b in range(n_batches):  # batch loop
             # Retrieve current batch
-            batch_x = x_train[b * batch_size:(b + 1) * batch_size]
-            batch_y = y_train[b * batch_size:(b + 1) * batch_size]
+            batch_x = x_train[b * batch_size : (b + 1) * batch_size]
+            batch_y = y_train[b * batch_size : (b + 1) * batch_size]
             # Train on batch
             optimizer.zero_grad()
             outputs = model(batch_x)
@@ -55,16 +55,15 @@
             optimizer.step()
             # Log
             if b % 100 == 0:
-                print(
-                    f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}")
+                print(f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}")
 
     # Metadata needed for aggregation server side
     metadata = {
         # num_examples are mandatory
-        'num_examples': len(x_train),
-        'batch_size': batch_size,
-        'epochs': epochs,
-        'lr': lr
+        "num_examples": len(x_train),
+        "batch_size": batch_size,
+        "epochs": epochs,
+        "lr": lr,
     }
 
     # Save JSON metadata file (mandatory)
diff --git a/examples/mnist-pytorch/client/validate.py b/examples/mnist-pytorch/client/validate.py
index 0a5592368..09328181f 100644
--- a/examples/mnist-pytorch/client/validate.py
+++ b/examples/mnist-pytorch/client/validate.py
@@ -2,9 +2,9 @@
 import sys
 
 import torch
-from data import load_data
 from model import load_parameters
 
+from data import load_data
 from fedn.utils.helpers.helpers import save_metrics
 
 dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -12,7 +12,7 @@
 
 
 def validate(in_model_path, out_json_path, data_path=None):
-    """ Validate model.
+    """Validate model.
 
     :param in_model_path: The path to the input model.
     :type in_model_path: str
@@ -34,12 +34,10 @@ def validate(in_model_path, out_json_path, data_path=None):
     with torch.no_grad():
         train_out = model(x_train)
         training_loss = criterion(train_out, y_train)
-        training_accuracy = torch.sum(torch.argmax(
-            train_out, dim=1) == y_train) / len(train_out)
+        training_accuracy = torch.sum(torch.argmax(train_out, dim=1) == y_train) / len(train_out)
 
         test_out = model(x_test)
         test_loss = criterion(test_out, y_test)
-        test_accuracy = torch.sum(torch.argmax(
-            test_out, dim=1) == y_test) / len(test_out)
+        test_accuracy = torch.sum(torch.argmax(test_out, dim=1) == y_test) / len(test_out)
 
     # JSON schema
     report = {
diff --git a/fedn/cli/client_cmd.py b/fedn/cli/client_cmd.py
index f9916985c..e72f29569 100644
--- a/fedn/cli/client_cmd.py
+++ b/fedn/cli/client_cmd.py
@@ -7,8 +7,7 @@
 from fedn.network.clients.client import Client
 
 from .main import main
-from .shared import (CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token,
-                     print_response)
+from .shared import CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token, print_response
 
 
 def validate_client_config(config):
     """
@@ -18,16 +17,15 @@ def validate_client_config(config):
 
     """
     try:
-        if config['discover_host'] is None or \
-                config['discover_host'] == '':
+        if config["discover_host"] is None or config["discover_host"] == "":
             raise InvalidClientConfig("Missing required configuration: discover_host")
-        if 'discover_port' not in config.keys():
-            config['discover_port'] = None
+        if "discover_port" not in config.keys():
+            config["discover_port"] = None
     except Exception:
         raise InvalidClientConfig("Could not load config from file. Check config")
Check config") -@main.group('client') +@main.group("client") @click.pass_context def client_cmd(ctx): """ @@ -37,12 +35,12 @@ def client_cmd(ctx): pass -@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)') -@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)') -@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)') -@click.option('-t', '--token', required=False, help='Authentication token') -@click.option('--n_max', required=False, help='Number of items to list') -@client_cmd.command('list') +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@client_cmd.command("list") @click.pass_context def list_clients(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): """ @@ -52,53 +50,70 @@ def list_clients(ctx, protocol: str, host: str, port: str, token: str = None, n_ - result: list of clients """ - url = get_api_url(protocol=protocol, host=host, port=port, endpoint='clients') + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="clients") headers = {} if n_max: - headers['X-Limit'] = n_max + headers["X-Limit"] = n_max _token = get_token(token) if _token: - headers['Authorization'] = _token + headers["Authorization"] = _token - click.echo(f'\nListing clients: {url}\n') - click.echo(f'Headers: {headers}') + click.echo(f"\nListing clients: {url}\n") + click.echo(f"Headers: {headers}") try: response = requests.get(url, headers=headers) - print_response(response, 'clients') + print_response(response, "clients") except requests.exceptions.ConnectionError: - click.echo(f'Error: Could not connect to {url}') - - -@client_cmd.command('start') -@click.option('-d', '--discoverhost', required=False, help='Hostname for discovery services(reducer).') -@click.option('-p', '--discoverport', required=False, help='Port for discovery services (reducer).') -@click.option('--token', required=False, help='Set token provided by reducer if enabled') -@click.option('-n', '--name', required=False, default="client" + str(uuid.uuid4())[:8]) -@click.option('-i', '--client_id', required=False) -@click.option('--local-package', is_flag=True, help='Enable local compute package') -@click.option('--force-ssl', is_flag=True, help='Force SSL/TLS for REST service') -@click.option('-u', '--dry-run', required=False, default=False) -@click.option('-s', '--secure', required=False, default=False) -@click.option('-pc', '--preshared-cert', required=False, default=False) -@click.option('-v', '--verify', is_flag=True, help='Verify SSL/TLS for REST service') -@click.option('-c', '--preferred-combiner', required=False, default=False) -@click.option('-va', '--validator', required=False, default=True) -@click.option('-tr', '--trainer', required=False, default=True) -@click.option('-in', '--init', required=False, default=None, - help='Set to a filename to (re)init client from file state.') -@click.option('-l', 
-              help='Set logfile for client log to file.')
-@click.option('--heartbeat-interval', required=False, default=2)
-@click.option('--reconnect-after-missed-heartbeat', required=False, default=30)
-@click.option('--verbosity', required=False, default='INFO', type=click.Choice(['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], case_sensitive=False))
+@client_cmd.command("start")
+@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services(reducer).")
+@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).")
+@click.option("--token", required=False, help="Set token provided by reducer if enabled")
+@click.option("-n", "--name", required=False, default="client" + str(uuid.uuid4())[:8])
+@click.option("-i", "--client_id", required=False)
+@click.option("--local-package", is_flag=True, help="Enable local compute package")
+@click.option("--force-ssl", is_flag=True, help="Force SSL/TLS for REST service")
+@click.option("-u", "--dry-run", required=False, default=False)
+@click.option("-s", "--secure", required=False, default=False)
+@click.option("-pc", "--preshared-cert", required=False, default=False)
+@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST service")
+@click.option("-c", "--preferred-combiner", required=False, default=False)
+@click.option("-va", "--validator", required=False, default=True)
+@click.option("-tr", "--trainer", required=False, default=True)
+@click.option("-in", "--init", required=False, default=None, help="Set to a filename to (re)init client from file state.")
+@click.option("-l", "--logfile", required=False, default=None, help="Set logfile for client log to file.")
+@click.option("--heartbeat-interval", required=False, default=2)
+@click.option("--reconnect-after-missed-heartbeat", required=False, default=30)
+@click.option("--verbosity", required=False, default="INFO", type=click.Choice(["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], case_sensitive=False))
 @click.pass_context
-def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_package, force_ssl, dry_run, secure, preshared_cert,
-               verify, preferred_combiner, validator, trainer, init, logfile, heartbeat_interval, reconnect_after_missed_heartbeat,
-               verbosity):
+def client_cmd(
+    ctx,
+    discoverhost,
+    discoverport,
+    token,
+    name,
+    client_id,
+    local_package,
+    force_ssl,
+    dry_run,
+    secure,
+    preshared_cert,
+    verify,
+    preferred_combiner,
+    validator,
+    trainer,
+    init,
+    logfile,
+    heartbeat_interval,
+    reconnect_after_missed_heartbeat,
+    verbosity,
+):
     """
 
     :param ctx:
@@ -121,21 +136,36 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa
     :return:
     """
     remote = False if local_package else True
-    config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'name': name,
-              'client_id': client_id, 'remote_compute_context': remote, 'force_ssl': force_ssl, 'dry_run': dry_run, 'secure': secure,
-              'preshared_cert': preshared_cert, 'verify': verify, 'preferred_combiner': preferred_combiner,
-              'validator': validator, 'trainer': trainer, 'logfile': logfile, 'heartbeat_interval': heartbeat_interval,
-              'reconnect_after_missed_heartbeat': reconnect_after_missed_heartbeat, 'verbosity': verbosity}
+    config = {
+        "discover_host": discoverhost,
+        "discover_port": discoverport,
+        "token": token,
+        "name": name,
+        "client_id": client_id,
"remote_compute_context": remote, + "force_ssl": force_ssl, + "dry_run": dry_run, + "secure": secure, + "preshared_cert": preshared_cert, + "verify": verify, + "preferred_combiner": preferred_combiner, + "validator": validator, + "trainer": trainer, + "logfile": logfile, + "heartbeat_interval": heartbeat_interval, + "reconnect_after_missed_heartbeat": reconnect_after_missed_heartbeat, + "verbosity": verbosity, + } if init: apply_config(init, config) - click.echo(f'\nClient configuration loaded from file: {init}') - click.echo('Values set in file override defaults and command line arguments...\n') + click.echo(f"\nClient configuration loaded from file: {init}") + click.echo("Values set in file override defaults and command line arguments...\n") try: validate_client_config(config) except InvalidClientConfig as e: - click.echo(f'Error: {e}') + click.echo(f"Error: {e}") return client = Client(config) diff --git a/fedn/cli/combiner_cmd.py b/fedn/cli/combiner_cmd.py index 758dda718..2b4447437 100644 --- a/fedn/cli/combiner_cmd.py +++ b/fedn/cli/combiner_cmd.py @@ -6,11 +6,10 @@ from fedn.network.combiner.combiner import Combiner from .main import main -from .shared import (CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token, - print_response) +from .shared import CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token, print_response -@main.group('combiner') +@main.group("combiner") @click.pass_context def combiner_cmd(ctx): """ @@ -20,19 +19,18 @@ def combiner_cmd(ctx): pass -@combiner_cmd.command('start') -@click.option('-d', '--discoverhost', required=False, help='Hostname for discovery services (reducer).') -@click.option('-p', '--discoverport', required=False, help='Port for discovery services (reducer).') -@click.option('-t', '--token', required=False, help='Set token provided by reducer if enabled') -@click.option('-n', '--name', required=False, default="combiner" + str(uuid.uuid4())[:8], help='Set name for combiner.') -@click.option('-h', '--host', required=False, default="combiner", help='Set hostname.') -@click.option('-i', '--port', required=False, default=12080, help='Set port.') -@click.option('-f', '--fqdn', required=False, default=None, help='Set fully qualified domain name') -@click.option('-s', '--secure', is_flag=True, help='Enable SSL/TLS encrypted gRPC channels.') -@click.option('-v', '--verify', is_flag=True, help='Verify SSL/TLS for REST discovery service (reducer)') -@click.option('-c', '--max_clients', required=False, default=30, help='The maximal number of client connections allowed.') -@click.option('-in', '--init', required=False, default=None, - help='Path to configuration file to (re)init combiner.') +@combiner_cmd.command("start") +@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services (reducer).") +@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).") +@click.option("-t", "--token", required=False, help="Set token provided by reducer if enabled") +@click.option("-n", "--name", required=False, default="combiner" + str(uuid.uuid4())[:8], help="Set name for combiner.") +@click.option("-h", "--host", required=False, default="combiner", help="Set hostname.") +@click.option("-i", "--port", required=False, default=12080, help="Set port.") +@click.option("-f", "--fqdn", required=False, default=None, help="Set fully qualified domain name") +@click.option("-s", "--secure", is_flag=True, help="Enable SSL/TLS encrypted gRPC channels.") +@click.option("-v", "--verify", is_flag=True, 
help="Verify SSL/TLS for REST discovery service (reducer)") +@click.option("-c", "--max_clients", required=False, default=30, help="The maximal number of client connections allowed.") +@click.option("-in", "--init", required=False, default=None, help="Path to configuration file to (re)init combiner.") @click.pass_context def start_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init): """ @@ -48,24 +46,34 @@ def start_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, se :param max_clients: :param init: """ - config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'host': host, - 'port': port, 'fqdn': fqdn, 'name': name, 'secure': secure, 'verify': verify, 'max_clients': max_clients} + config = { + "discover_host": discoverhost, + "discover_port": discoverport, + "token": token, + "host": host, + "port": port, + "fqdn": fqdn, + "name": name, + "secure": secure, + "verify": verify, + "max_clients": max_clients, + } if init: apply_config(init, config) - click.echo(f'\nCombiner configuration loaded from file: {init}') - click.echo('Values set in file override defaults and command line arguments...\n') + click.echo(f"\nCombiner configuration loaded from file: {init}") + click.echo("Values set in file override defaults and command line arguments...\n") combiner = Combiner(config) combiner.run() -@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)') -@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)') -@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)') -@click.option('-t', '--token', required=False, help='Authentication token') -@click.option('--n_max', required=False, help='Number of items to list') -@combiner_cmd.command('list') +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@combiner_cmd.command("list") @click.pass_context def list_combiners(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): """ @@ -75,22 +83,22 @@ def list_combiners(ctx, protocol: str, host: str, port: str, token: str = None, - result: list of combiners """ - url = get_api_url(protocol=protocol, host=host, port=port, endpoint='combiners') + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="combiners") headers = {} if n_max: - headers['X-Limit'] = n_max + headers["X-Limit"] = n_max _token = get_token(token) if _token: - headers['Authorization'] = _token + headers["Authorization"] = _token - click.echo(f'\nListing combiners: {url}\n') - click.echo(f'Headers: {headers}') + click.echo(f"\nListing combiners: {url}\n") + click.echo(f"Headers: {headers}") try: response = requests.get(url, headers=headers) - print_response(response, 'combiners') + print_response(response, "combiners") except requests.exceptions.ConnectionError: - click.echo(f'Error: Could not connect to {url}') + 
click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/config_cmd.py b/fedn/cli/config_cmd.py index 856882e62..d5286997f 100644 --- a/fedn/cli/config_cmd.py +++ b/fedn/cli/config_cmd.py @@ -5,50 +5,32 @@ from .main import main envs = [ - { - "name": "FEDN_CONTROLLER_PROTOCOL", - "description": "The protocol to use for communication with the controller." - }, - { - "name": "FEDN_CONTROLLER_HOST", - "description": "The host to use for communication with the controller." - }, - { - "name": "FEDN_CONTROLLER_PORT", - "description": "The port to use for communication with the controller." - }, - { - "name": "FEDN_AUTH_TOKEN", - "description": "The authentication token to use for communication with the controller and combiner." - }, - { - "name": "FEDN_AUTH_SCHEME", - "description": "The authentication scheme to use for communication with the controller and combiner." - }, + {"name": "FEDN_CONTROLLER_PROTOCOL", "description": "The protocol to use for communication with the controller."}, + {"name": "FEDN_CONTROLLER_HOST", "description": "The host to use for communication with the controller."}, + {"name": "FEDN_CONTROLLER_PORT", "description": "The port to use for communication with the controller."}, + {"name": "FEDN_AUTH_TOKEN", "description": "The authentication token to use for communication with the controller and combiner."}, + {"name": "FEDN_AUTH_SCHEME", "description": "The authentication scheme to use for communication with the controller and combiner."}, { "name": "FEDN_CONTROLLER_URL", - "description": "The URL of the controller. Overrides FEDN_CONTROLLER_PROTOCOL, FEDN_CONTROLLER_HOST and FEDN_CONTROLLER_PORT." + "description": "The URL of the controller. Overrides FEDN_CONTROLLER_PROTOCOL, FEDN_CONTROLLER_HOST and FEDN_CONTROLLER_PORT.", }, - { - "name": "FEDN_PACKAGE_EXTRACT_DIR", - "description": "The directory to extract packages to." - } + {"name": "FEDN_PACKAGE_EXTRACT_DIR", "description": "The directory to extract packages to."}, ] -@main.group('config', invoke_without_command=True) +@main.group("config", invoke_without_command=True) @click.pass_context def config_cmd(ctx): """ - Configuration commands for the FEDn CLI. 
""" if ctx.invoked_subcommand is None: - click.echo('\n--- FEDn Cli Configuration ---\n') - click.echo('Current configuration:\n') + click.echo("\n--- FEDn Cli Configuration ---\n") + click.echo("Current configuration:\n") for env in envs: - name = env['name'] + name = env["name"] value = os.environ.get(name) click.echo(f'{name}: {value or "Not set"}') click.echo(f'{env["description"]}\n') - click.echo('\n') + click.echo("\n") diff --git a/fedn/cli/main.py b/fedn/cli/main.py index d004c9605..52276c418 100644 --- a/fedn/cli/main.py +++ b/fedn/cli/main.py @@ -2,7 +2,7 @@ CONTEXT_SETTINGS = dict( # Support -h as a shortcut for --help - help_option_names=['-h', '--help'], + help_option_names=["-h", "--help"], ) diff --git a/fedn/cli/model_cmd.py b/fedn/cli/model_cmd.py index ddccd9e2d..e44793a9f 100644 --- a/fedn/cli/model_cmd.py +++ b/fedn/cli/model_cmd.py @@ -1,4 +1,3 @@ - import click import requests @@ -6,7 +5,7 @@ from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response -@main.group('model') +@main.group("model") @click.pass_context def model_cmd(ctx): """ @@ -16,12 +15,12 @@ def model_cmd(ctx): pass -@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)') -@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)') -@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)') -@click.option('-t', '--token', required=False, help='Authentication token') -@click.option('--n_max', required=False, help='Number of items to list') -@model_cmd.command('list') +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@model_cmd.command("list") @click.pass_context def list_models(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): """ @@ -31,22 +30,22 @@ def list_models(ctx, protocol: str, host: str, port: str, token: str = None, n_m - result: list of models """ - url = get_api_url(protocol=protocol, host=host, port=port, endpoint='models') + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="models") headers = {} if n_max: - headers['X-Limit'] = n_max + headers["X-Limit"] = n_max _token = get_token(token) if _token: - headers['Authorization'] = _token + headers["Authorization"] = _token - click.echo(f'\nListing models: {url}\n') - click.echo(f'Headers: {headers}') + click.echo(f"\nListing models: {url}\n") + click.echo(f"Headers: {headers}") try: response = requests.get(url, headers=headers) - print_response(response, 'models') + print_response(response, "models") except requests.exceptions.ConnectionError: - click.echo(f'Error: Could not connect to {url}') + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/package_cmd.py b/fedn/cli/package_cmd.py index a19ed2f9e..6d503d414 100644 --- a/fedn/cli/package_cmd.py +++ b/fedn/cli/package_cmd.py @@ -10,7 +10,7 @@ from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response 
 
 
-@main.group('package')
+@main.group("package")
 @click.pass_context
 def package_cmd(ctx):
     """
@@ -20,12 +20,12 @@ def package_cmd(ctx):
     pass
 
 
-@package_cmd.command('create')
-@click.option('-p', '--path', required=True, help='Path to package directory containing fedn.yaml')
-@click.option('-n', '--name', required=False, default='package.tgz', help='Name of package tarball')
+@package_cmd.command("create")
+@click.option("-p", "--path", required=True, help="Path to package directory containing fedn.yaml")
+@click.option("-n", "--name", required=False, default="package.tgz", help="Name of package tarball")
 @click.pass_context
 def create_cmd(ctx, path, name):
-    """ Create compute package.
+    """Create compute package.
 
     Make a tar.gz archive of folder given by --path
 
     :param ctx:
     :param path:
     """
     path = os.path.abspath(path)
-    yaml_file = os.path.join(path, 'fedn.yaml')
+    yaml_file = os.path.join(path, "fedn.yaml")
     if not os.path.exists(yaml_file):
         logger.error(f"Could not find fedn.yaml in {path}")
         exit(-1)
@@ -43,12 +43,12 @@ def create_cmd(ctx, path, name):
     logger.info(f"Created package {name}")
 
 
-@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)')
-@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)')
-@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)')
-@click.option('-t', '--token', required=False, help='Authentication token')
-@click.option('--n_max', required=False, help='Number of items to list')
-@package_cmd.command('list')
+@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)")
+@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)")
+@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)")
+@click.option("-t", "--token", required=False, help="Authentication token")
+@click.option("--n_max", required=False, help="Number of items to list")
+@package_cmd.command("list")
 @click.pass_context
 def list_packages(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None):
     """
@@ -58,22 +58,22 @@ def list_packages(ctx, protocol: str, host: str, port: str, token: str = None, n
     - result: list of packages
 
     """
-    url = get_api_url(protocol=protocol, host=host, port=port, endpoint='packages')
+    url = get_api_url(protocol=protocol, host=host, port=port, endpoint="packages")
     headers = {}
 
     if n_max:
-        headers['X-Limit'] = n_max
+        headers["X-Limit"] = n_max
 
     _token = get_token(token)
 
     if _token:
-        headers['Authorization'] = _token
+        headers["Authorization"] = _token
 
-    click.echo(f'\nListing packages: {url}\n')
-    click.echo(f'Headers: {headers}')
+    click.echo(f"\nListing packages: {url}\n")
+    click.echo(f"Headers: {headers}")
 
     try:
         response = requests.get(url, headers=headers)
-        print_response(response, 'packages')
+        print_response(response, "packages")
     except requests.exceptions.ConnectionError:
-        click.echo(f'Error: Could not connect to {url}')
+        click.echo(f"Error: Could not connect to {url}")
diff --git a/fedn/cli/round_cmd.py b/fedn/cli/round_cmd.py
index 31f4accc4..ca23cafe7 100644
--- a/fedn/cli/round_cmd.py
+++ b/fedn/cli/round_cmd.py
@@ -1,4 +1,3 @@
-
 import click
 import requests
 
 from .main import main
 from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response
 
 
-@main.group('round')
+@main.group("round")
 @click.pass_context
 def round_cmd(ctx):
     """
@@ -16,12 +15,12 @@ def round_cmd(ctx):
     pass
 
 
-@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)')
-@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)')
-@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)')
-@click.option('-t', '--token', required=False, help='Authentication token')
-@click.option('--n_max', required=False, help='Number of items to list')
-@round_cmd.command('list')
+@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)")
+@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)")
+@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)")
+@click.option("-t", "--token", required=False, help="Authentication token")
+@click.option("--n_max", required=False, help="Number of items to list")
+@round_cmd.command("list")
 @click.pass_context
 def list_rounds(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None):
     """
@@ -31,22 +30,22 @@ def list_rounds(ctx, protocol: str, host: str, port: str, token: str = None, n_m
     - result: list of rounds
 
     """
-    url = get_api_url(protocol=protocol, host=host, port=port, endpoint='rounds')
+    url = get_api_url(protocol=protocol, host=host, port=port, endpoint="rounds")
     headers = {}
 
     if n_max:
-        headers['X-Limit'] = n_max
+        headers["X-Limit"] = n_max
 
     _token = get_token(token)
 
     if _token:
-        headers['Authorization'] = _token
+        headers["Authorization"] = _token
 
-    click.echo(f'\nListing rounds: {url}\n')
-    click.echo(f'Headers: {headers}')
+    click.echo(f"\nListing rounds: {url}\n")
+    click.echo(f"Headers: {headers}")
 
     try:
         response = requests.get(url, headers=headers)
-        print_response(response, 'rounds')
+        print_response(response, "rounds")
     except requests.exceptions.ConnectionError:
-        click.echo(f'Error: Could not connect to {url}')
+        click.echo(f"Error: Could not connect to {url}")
diff --git a/fedn/cli/run_cmd.py b/fedn/cli/run_cmd.py
index 87a54f7f1..b9fe4528e 100644
--- a/fedn/cli/run_cmd.py
+++ b/fedn/cli/run_cmd.py
@@ -22,7 +22,7 @@ def get_statestore_config_from_file(init):
     :param init:
     :return:
     """
-    with open(init, 'r') as file:
+    with open(init, "r") as file:
         try:
             settings = dict(yaml.safe_load(file))
             return settings
@@ -31,7 +31,7 @@
 
 
 def check_helper_config_file(config):
-    control = config['control']
+    control = config["control"]
     try:
         helper = control["helper"]
     except KeyError:
@@ -40,7 +40,7 @@ def check_helper_config_file(config):
     return helper
 
 
-@main.group('run')
+@main.group("run")
 @click.pass_context
 def run_cmd(ctx):
     """
@@ -50,25 +50,25 @@ def run_cmd(ctx):
     pass
 
 
-@run_cmd.command('build')
-@click.option('-p', '--path', required=True, help='Path to package directory containing fedn.yaml')
+@run_cmd.command("build")
+@click.option("-p", "--path", required=True, help="Path to package directory containing fedn.yaml")
 @click.pass_context
 def build_cmd(ctx, path):
-    """ Execute 'build' entrypoint in fedn.yaml.
+    """Execute 'build' entrypoint in fedn.yaml.
 
     :param ctx:
     :param path: Path to folder containing fedn.yaml
     :type path: str
     """
     path = os.path.abspath(path)
-    yaml_file = os.path.join(path, 'fedn.yaml')
+    yaml_file = os.path.join(path, "fedn.yaml")
     if not os.path.exists(yaml_file):
         logger.error(f"Could not find fedn.yaml in {path}")
         exit(-1)
 
     config = _read_yaml_file(yaml_file)
     # Check that build is defined in fedn.yaml under entry_points
-    if 'build' not in config['entry_points']:
+    if "build" not in config["entry_points"]:
         logger.error("No build command defined in fedn.yaml")
         exit(-1)
 
@@ -82,32 +82,49 @@ def build_cmd(ctx, path):
     shutil.rmtree(dispatcher.python_env_path)
 
 
-@run_cmd.command('client')
-@click.option('-d', '--discoverhost', required=False, help='Hostname for discovery services(reducer).')
-@click.option('-p', '--discoverport', required=False, help='Port for discovery services (reducer).')
-@click.option('--token', required=False, help='Set token provided by reducer if enabled')
-@click.option('-n', '--name', required=False, default="client" + str(uuid.uuid4())[:8])
-@click.option('-i', '--client_id', required=False)
-@click.option('--local-package', is_flag=True, help='Enable local compute package')
-@click.option('--force-ssl', is_flag=True, help='Force SSL/TLS for REST service')
-@click.option('-u', '--dry-run', required=False, default=False)
-@click.option('-s', '--secure', required=False, default=False)
-@click.option('-pc', '--preshared-cert', required=False, default=False)
-@click.option('-v', '--verify', is_flag=True, help='Verify SSL/TLS for REST service')
-@click.option('-c', '--preferred-combiner', required=False, default=False)
-@click.option('-va', '--validator', required=False, default=True)
-@click.option('-tr', '--trainer', required=False, default=True)
-@click.option('-in', '--init', required=False, default=None,
-              help='Set to a filename to (re)init client from file state.')
-@click.option('-l', '--logfile', required=False, default=None,
-              help='Set logfile for client log to file.')
-@click.option('--heartbeat-interval', required=False, default=2)
-@click.option('--reconnect-after-missed-heartbeat', required=False, default=30)
-@click.option('--verbosity', required=False, default='INFO', type=click.Choice(['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], case_sensitive=False))
+@run_cmd.command("client")
+@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services(reducer).")
+@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).")
+@click.option("--token", required=False, help="Set token provided by reducer if enabled")
+@click.option("-n", "--name", required=False, default="client" + str(uuid.uuid4())[:8])
+@click.option("-i", "--client_id", required=False)
+@click.option("--local-package", is_flag=True, help="Enable local compute package")
+@click.option("--force-ssl", is_flag=True, help="Force SSL/TLS for REST service")
+@click.option("-u", "--dry-run", required=False, default=False)
+@click.option("-s", "--secure", required=False, default=False)
+@click.option("-pc", "--preshared-cert", required=False, default=False)
+@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST service")
+@click.option("-c", "--preferred-combiner", required=False, default=False)
+@click.option("-va", "--validator", required=False, default=True)
+@click.option("-tr", "--trainer", required=False, default=True)
+@click.option("-in", "--init", required=False, default=None, help="Set to a filename to (re)init client from file state.")
state.") +@click.option("-l", "--logfile", required=False, default=None, help="Set logfile for client log to file.") +@click.option("--heartbeat-interval", required=False, default=2) +@click.option("--reconnect-after-missed-heartbeat", required=False, default=30) +@click.option("--verbosity", required=False, default="INFO", type=click.Choice(["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], case_sensitive=False)) @click.pass_context -def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_package, force_ssl, dry_run, secure, preshared_cert, - verify, preferred_combiner, validator, trainer, init, logfile, heartbeat_interval, reconnect_after_missed_heartbeat, - verbosity): +def client_cmd( + ctx, + discoverhost, + discoverport, + token, + name, + client_id, + local_package, + force_ssl, + dry_run, + secure, + preshared_cert, + verify, + preferred_combiner, + validator, + trainer, + init, + logfile, + heartbeat_interval, + reconnect_after_missed_heartbeat, + verbosity, +): """ :param ctx: @@ -130,49 +147,58 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa :return: """ remote = False if local_package else True - config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'name': name, - 'client_id': client_id, 'remote_compute_context': remote, 'force_ssl': force_ssl, 'dry_run': dry_run, 'secure': secure, - 'preshared_cert': preshared_cert, 'verify': verify, 'preferred_combiner': preferred_combiner, - 'validator': validator, 'trainer': trainer, 'logfile': logfile, 'heartbeat_interval': heartbeat_interval, - 'reconnect_after_missed_heartbeat': reconnect_after_missed_heartbeat, 'verbosity': verbosity} + config = { + "discover_host": discoverhost, + "discover_port": discoverport, + "token": token, + "name": name, + "client_id": client_id, + "remote_compute_context": remote, + "force_ssl": force_ssl, + "dry_run": dry_run, + "secure": secure, + "preshared_cert": preshared_cert, + "verify": verify, + "preferred_combiner": preferred_combiner, + "validator": validator, + "trainer": trainer, + "logfile": logfile, + "heartbeat_interval": heartbeat_interval, + "reconnect_after_missed_heartbeat": reconnect_after_missed_heartbeat, + "verbosity": verbosity, + } click.echo( - click.style( - '\n*** fedn run client is deprecated and will be removed. Please use fedn client start instead. ***\n', - blink=True, - bold=True, - fg='red' - ) + click.style("\n*** fedn run client is deprecated and will be removed. Please use fedn client start instead. 
***\n", blink=True, bold=True, fg="red") ) if init: apply_config(init, config) - click.echo(f'\nClient configuration loaded from file: {init}') - click.echo('Values set in file override defaults and command line arguments...\n') + click.echo(f"\nClient configuration loaded from file: {init}") + click.echo("Values set in file override defaults and command line arguments...\n") try: validate_client_config(config) except InvalidClientConfig as e: - click.echo(f'Error: {e}') + click.echo(f"Error: {e}") return client = Client(config) client.run() -@run_cmd.command('combiner') -@click.option('-d', '--discoverhost', required=False, help='Hostname for discovery services (reducer).') -@click.option('-p', '--discoverport', required=False, help='Port for discovery services (reducer).') -@click.option('-t', '--token', required=False, help='Set token provided by reducer if enabled') -@click.option('-n', '--name', required=False, default="combiner" + str(uuid.uuid4())[:8], help='Set name for combiner.') -@click.option('-h', '--host', required=False, default="combiner", help='Set hostname.') -@click.option('-i', '--port', required=False, default=12080, help='Set port.') -@click.option('-f', '--fqdn', required=False, default=None, help='Set fully qualified domain name') -@click.option('-s', '--secure', is_flag=True, help='Enable SSL/TLS encrypted gRPC channels.') -@click.option('-v', '--verify', is_flag=True, help='Verify SSL/TLS for REST discovery service (reducer)') -@click.option('-c', '--max_clients', required=False, default=30, help='The maximal number of client connections allowed.') -@click.option('-in', '--init', required=False, default=None, - help='Path to configuration file to (re)init combiner.') +@run_cmd.command("combiner") +@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services (reducer).") +@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).") +@click.option("-t", "--token", required=False, help="Set token provided by reducer if enabled") +@click.option("-n", "--name", required=False, default="combiner" + str(uuid.uuid4())[:8], help="Set name for combiner.") +@click.option("-h", "--host", required=False, default="combiner", help="Set hostname.") +@click.option("-i", "--port", required=False, default=12080, help="Set port.") +@click.option("-f", "--fqdn", required=False, default=None, help="Set fully qualified domain name") +@click.option("-s", "--secure", is_flag=True, help="Enable SSL/TLS encrypted gRPC channels.") +@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST discovery service (reducer)") +@click.option("-c", "--max_clients", required=False, default=30, help="The maximal number of client connections allowed.") +@click.option("-in", "--init", required=False, default=None, help="Path to configuration file to (re)init combiner.") @click.pass_context def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init): """ @@ -188,22 +214,27 @@ def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, :param max_clients: :param init: """ - config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'host': host, - 'port': port, 'fqdn': fqdn, 'name': name, 'secure': secure, 'verify': verify, 'max_clients': max_clients} + config = { + "discover_host": discoverhost, + "discover_port": discoverport, + "token": token, + "host": host, + "port": port, + "fqdn": fqdn, + "name": name, + 
"secure": secure, + "verify": verify, + "max_clients": max_clients, + } click.echo( - click.style( - '\n*** fedn run combiner is deprecated and will be removed. Please use fedn combiner start instead. ***\n', - blink=True, - bold=True, - fg='red' - ) + click.style("\n*** fedn run combiner is deprecated and will be removed. Please use fedn combiner start instead. ***\n", blink=True, bold=True, fg="red") ) if init: apply_config(init, config) - click.echo(f'\nCombiner configuration loaded from file: {init}') - click.echo('Values set in file override defaults and command line arguments...\n') + click.echo(f"\nCombiner configuration loaded from file: {init}") + click.echo("Values set in file override defaults and command line arguments...\n") combiner = Combiner(config) combiner.run() diff --git a/fedn/cli/session_cmd.py b/fedn/cli/session_cmd.py index 37eb3a8a6..55597b5b3 100644 --- a/fedn/cli/session_cmd.py +++ b/fedn/cli/session_cmd.py @@ -1,4 +1,3 @@ - import click import requests @@ -6,7 +5,7 @@ from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response -@main.group('session') +@main.group("session") @click.pass_context def session_cmd(ctx): """ @@ -16,12 +15,12 @@ def session_cmd(ctx): pass -@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)') -@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)') -@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)') -@click.option('-t', '--token', required=False, help='Authentication token') -@click.option('--n_max', required=False, help='Number of items to list') -@session_cmd.command('list') +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@session_cmd.command("list") @click.pass_context def list_sessions(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): """ @@ -31,22 +30,22 @@ def list_sessions(ctx, protocol: str, host: str, port: str, token: str = None, n - result: list of sessions """ - url = get_api_url(protocol=protocol, host=host, port=port, endpoint='sessions') + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="sessions") headers = {} if n_max: - headers['X-Limit'] = n_max + headers["X-Limit"] = n_max _token = get_token(token) if _token: - headers['Authorization'] = _token + headers["Authorization"] = _token - click.echo(f'\nListing sessions: {url}\n') - click.echo(f'Headers: {headers}') + click.echo(f"\nListing sessions: {url}\n") + click.echo(f"Headers: {headers}") try: response = requests.get(url, headers=headers) - print_response(response, 'sessions') + print_response(response, "sessions") except requests.exceptions.ConnectionError: - click.echo(f'Error: Could not connect to {url}') + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/shared.py b/fedn/cli/shared.py index 81e5ebd07..2500d9e2b 100644 --- a/fedn/cli/shared.py +++ b/fedn/cli/shared.py @@ -5,28 +5,16 
@@ from fedn.common.log_config import logger -CONTROLLER_DEFAULTS = { - 'protocol': 'http', - 'host': 'localhost', - 'port': 8092, - 'debug': False -} +CONTROLLER_DEFAULTS = {"protocol": "http", "host": "localhost", "port": 8092, "debug": False} -COMBINER_DEFAULTS = { - 'discover_host': 'localhost', - 'discover_port': 8092, - 'host': 'localhost', - 'port': 12080, - "name": "combiner", - "max_clients": 30 -} +COMBINER_DEFAULTS = {"discover_host": "localhost", "discover_port": 8092, "host": "localhost", "port": 12080, "name": "combiner", "max_clients": 30} CLIENT_DEFAULTS = { - 'discover_host': 'localhost', - 'discover_port': 8092, + "discover_host": "localhost", + "discover_port": 8092, } -API_VERSION = 'v1' +API_VERSION = "v1" def apply_config(path: str, config: dict): @@ -36,11 +24,11 @@ def apply_config(path: str, config: dict): :param config: Client config (dict). """ - with open(path, 'r') as file: + with open(path, "r") as file: try: settings = dict(yaml.safe_load(file)) except Exception: - logger.error('Failed to read config from settings file, exiting.') + logger.error("Failed to read config from settings file, exiting.") return for key, val in settings.items(): @@ -48,16 +36,16 @@ def apply_config(path: str, config: dict): def get_api_url(protocol: str, host: str, port: str, endpoint: str) -> str: - _url = os.environ.get('FEDN_CONTROLLER_URL') + _url = os.environ.get("FEDN_CONTROLLER_URL") if _url: - return f'{_url}/api/{API_VERSION}/{endpoint}/' + return f"{_url}/api/{API_VERSION}/{endpoint}/" - _protocol = protocol or os.environ.get('FEDN_CONTROLLER_PROTOCOL') or CONTROLLER_DEFAULTS['protocol'] - _host = host or os.environ.get('FEDN_CONTROLLER_HOST') or CONTROLLER_DEFAULTS['host'] - _port = port or os.environ.get('FEDN_CONTROLLER_PORT') or CONTROLLER_DEFAULTS['port'] + _protocol = protocol or os.environ.get("FEDN_CONTROLLER_PROTOCOL") or CONTROLLER_DEFAULTS["protocol"] + _host = host or os.environ.get("FEDN_CONTROLLER_HOST") or CONTROLLER_DEFAULTS["host"] + _port = port or os.environ.get("FEDN_CONTROLLER_PORT") or CONTROLLER_DEFAULTS["port"] - return f'{_protocol}://{_host}:{_port}/api/{API_VERSION}/{endpoint}/' + return f"{_protocol}://{_host}:{_port}/api/{API_VERSION}/{endpoint}/" def get_token(token: str) -> str: @@ -72,7 +60,7 @@ def get_token(token: str) -> str: def get_client_package_dir(path: str) -> str: - return path or os.environ.get('FEDN_PACKAGE_DIR', None) + return path or os.environ.get("FEDN_PACKAGE_DIR", None) # Print response from api (list of entities) @@ -90,15 +78,15 @@ def print_response(response, entity_name: str): if response.status_code == 200: json_data = response.json() count, result = json_data.values() - click.echo(f'Found {count} {entity_name}') - click.echo('\n---------------------------------\n') + click.echo(f"Found {count} {entity_name}") + click.echo("\n---------------------------------\n") for obj in result: - click.echo('{') + click.echo("{") for k, v in obj.items(): - click.echo(f'\t{k}: {v}') - click.echo('}') + click.echo(f"\t{k}: {v}") + click.echo("}") elif response.status_code == 500: json_data = response.json() click.echo(f'Error: {json_data["message"]}') else: - click.echo(f'Error: {response.status_code}') + click.echo(f"Error: {response.status_code}") diff --git a/fedn/cli/status_cmd.py b/fedn/cli/status_cmd.py index 457fc9c00..a4f17e349 100644 --- a/fedn/cli/status_cmd.py +++ b/fedn/cli/status_cmd.py @@ -5,7 +5,7 @@ from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response -@main.group('status') 
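Note the resolution order implemented by get_api_url in the fedn/cli/shared.py hunk above: a full FEDN_CONTROLLER_URL wins outright; otherwise each URL component falls back from CLI flag to environment variable to CONTROLLER_DEFAULTS. A minimal sketch of that behavior (the example URL is hypothetical):

    import os
    from fedn.cli.shared import get_api_url

    os.environ["FEDN_CONTROLLER_URL"] = "https://fedn.example.com"
    # -> "https://fedn.example.com/api/v1/rounds/"
    print(get_api_url(protocol=None, host=None, port=None, endpoint="rounds"))

    del os.environ["FEDN_CONTROLLER_URL"]
    # -> "http://localhost:8092/api/v1/rounds/" (CONTROLLER_DEFAULTS)
    print(get_api_url(protocol=None, host=None, port=None, endpoint="rounds"))
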
+@main.group("status") @click.pass_context def status_cmd(ctx): """ @@ -15,12 +15,12 @@ def status_cmd(ctx): pass -@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)') -@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)') -@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)') -@click.option('-t', '--token', required=False, help='Authentication token') -@click.option('--n_max', required=False, help='Number of items to list') -@status_cmd.command('list') +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@status_cmd.command("list") @click.pass_context def list_statuses(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): """ @@ -30,22 +30,22 @@ def list_statuses(ctx, protocol: str, host: str, port: str, token: str = None, n - result: list of statuses """ - url = get_api_url(protocol=protocol, host=host, port=port, endpoint='statuses') + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="statuses") headers = {} if n_max: - headers['X-Limit'] = n_max + headers["X-Limit"] = n_max _token = get_token(token) if _token: - headers['Authorization'] = _token + headers["Authorization"] = _token - click.echo(f'\nListing statuses: {url}\n') - click.echo(f'Headers: {headers}') + click.echo(f"\nListing statuses: {url}\n") + click.echo(f"Headers: {headers}") try: response = requests.get(url, headers=headers) - print_response(response, 'statuses') + print_response(response, "statuses") except requests.exceptions.ConnectionError: - click.echo(f'Error: Could not connect to {url}') + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/validation_cmd.py b/fedn/cli/validation_cmd.py index 3707f9bb8..055be0c65 100644 --- a/fedn/cli/validation_cmd.py +++ b/fedn/cli/validation_cmd.py @@ -5,7 +5,7 @@ from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response -@main.group('validation') +@main.group("validation") @click.pass_context def validation_cmd(ctx): """ @@ -15,12 +15,12 @@ def validation_cmd(ctx): pass -@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)') -@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)') -@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)') -@click.option('-t', '--token', required=False, help='Authentication token') -@click.option('--n_max', required=False, help='Number of items to list') -@validation_cmd.command('list') +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", 
"--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@validation_cmd.command("list") @click.pass_context def list_validations(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): """ @@ -30,22 +30,22 @@ def list_validations(ctx, protocol: str, host: str, port: str, token: str = None - result: list of validations """ - url = get_api_url(protocol=protocol, host=host, port=port, endpoint='validations') + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="validations") headers = {} if n_max: - headers['X-Limit'] = n_max + headers["X-Limit"] = n_max _token = get_token(token) if _token: - headers['Authorization'] = _token + headers["Authorization"] = _token - click.echo(f'\nListing validations: {url}\n') - click.echo(f'Headers: {headers}') + click.echo(f"\nListing validations: {url}\n") + click.echo(f"Headers: {headers}") try: response = requests.get(url, headers=headers) - print_response(response, 'validations') + print_response(response, "validations") except requests.exceptions.ConnectionError: - click.echo(f'Error: Could not connect to {url}') + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/common/certificate/certificate.py b/fedn/common/certificate/certificate.py index a2c059748..857a05e7c 100644 --- a/fedn/common/certificate/certificate.py +++ b/fedn/common/certificate/certificate.py @@ -13,19 +13,18 @@ class Certificate: Utility to generate unsigned certificates. """ + CERT_NAME = "cert.pem" KEY_NAME = "key.pem" BITS = 2048 def __init__(self, cwd, name=None, key_name="key.pem", cert_name="cert.pem", create_dirs=True): - try: os.makedirs(cwd) except OSError: logger.info("Directory exists, will store all cert and keys here.") else: - logger.info( - "Successfully created the directory to store cert and keys in {}".format(cwd)) + logger.info("Successfully created the directory to store cert and keys in {}".format(cwd)) self.key_path = os.path.join(cwd, key_name) self.cert_path = os.path.join(cwd, cert_name) @@ -35,7 +34,9 @@ def __init__(self, cwd, name=None, key_name="key.pem", cert_name="cert.pem", cre else: self.name = str(uuid.uuid4()) - def gen_keypair(self, ): + def gen_keypair( + self, + ): """ Generate keypair. 
@@ -70,21 +71,19 @@ def set_keypair_raw(self, certificate, privatekey): :param privatekey: """ with open(self.key_path, "wb") as keyfile: - keyfile.write(crypto.dump_privatekey( - crypto.FILETYPE_PEM, privatekey)) + keyfile.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, privatekey)) with open(self.cert_path, "wb") as certfile: - certfile.write(crypto.dump_certificate( - crypto.FILETYPE_PEM, certificate)) + certfile.write(crypto.dump_certificate(crypto.FILETYPE_PEM, certificate)) def get_keypair_raw(self): """ :return: """ - with open(self.key_path, 'rb') as keyfile: + with open(self.key_path, "rb") as keyfile: key_buf = keyfile.read() - with open(self.cert_path, 'rb') as certfile: + with open(self.cert_path, "rb") as certfile: cert_buf = certfile.read() return copy.deepcopy(cert_buf), copy.deepcopy(key_buf) @@ -93,7 +92,7 @@ def get_key(self): :return: """ - with open(self.key_path, 'rb') as keyfile: + with open(self.key_path, "rb") as keyfile: key_buf = keyfile.read() key = crypto.load_privatekey(crypto.FILETYPE_PEM, key_buf) return key @@ -103,7 +102,7 @@ def get_cert(self): :return: """ - with open(self.cert_path, 'rb') as certfile: + with open(self.cert_path, "rb") as certfile: cert_buf = certfile.read() cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_buf) return cert diff --git a/fedn/common/certificate/certificatemanager.py b/fedn/common/certificate/certificatemanager.py index 3d34fa1ad..ce165d862 100644 --- a/fedn/common/certificate/certificatemanager.py +++ b/fedn/common/certificate/certificatemanager.py @@ -10,7 +10,6 @@ class CertificateManager: """ def __init__(self, directory): - self.directory = directory self.certificates = [] self.allowed = dict() @@ -28,8 +27,7 @@ def get_or_create(self, name): if search: return search else: - cert = Certificate(self.directory, name=name, - cert_name=name + '-cert.pem', key_name=name + '-key.pem') + cert = Certificate(self.directory, name=name, cert_name=name + "-cert.pem", key_name=name + "-key.pem") cert.gen_keypair() self.certificates.append(cert) return cert @@ -53,12 +51,11 @@ def load_all(self): """ for filename in sorted(os.listdir(self.directory)): - if filename.endswith('cert.pem'): - name = filename.split('-')[0] - key_name = name + '-key.pem' + if filename.endswith("cert.pem"): + name = filename.split("-")[0] + key_name = name + "-key.pem" - c = Certificate(self.directory, name=name, - cert_name=filename, key_name=key_name) + c = Certificate(self.directory, name=name, cert_name=filename, key_name=key_name) self.certificates.append(c) def find(self, name): diff --git a/fedn/common/config.py b/fedn/common/config.py index 71f4b1698..4864ce1ef 100644 --- a/fedn/common/config.py +++ b/fedn/common/config.py @@ -24,12 +24,8 @@ def get_environment_config(): global STATESTORE_CONFIG global MODELSTORAGE_CONFIG - STATESTORE_CONFIG = os.environ.get( - "STATESTORE_CONFIG", "/workspaces/fedn/config/settings-reducer.yaml.template" - ) - MODELSTORAGE_CONFIG = os.environ.get( - "MODELSTORAGE_CONFIG", "/workspaces/fedn/config/settings-reducer.yaml.template" - ) + STATESTORE_CONFIG = os.environ.get("STATESTORE_CONFIG", "/workspaces/fedn/config/settings-reducer.yaml.template") + MODELSTORAGE_CONFIG = os.environ.get("MODELSTORAGE_CONFIG", "/workspaces/fedn/config/settings-reducer.yaml.template") def get_statestore_config(file=None): diff --git a/fedn/common/log_config.py b/fedn/common/log_config.py index 0e61a6a83..b8aa1218b 100644 --- a/fedn/common/log_config.py +++ b/fedn/common/log_config.py @@ -5,13 +5,7 @@ import requests import 
urllib3 -log_levels = { - "DEBUG": logging.DEBUG, - "INFO": logging.INFO, - "WARNING": logging.WARNING, - "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL -} +log_levels = {"DEBUG": logging.DEBUG, "INFO": logging.INFO, "WARNING": logging.WARNING, "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL} urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @@ -21,12 +15,12 @@ logger = logging.getLogger("fedn") logger.addHandler(handler) logger.setLevel(logging.DEBUG) -formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') +formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") handler.setFormatter(formatter) class StudioHTTPHandler(logging.handlers.HTTPHandler): - def __init__(self, host, url, method='POST', token=None): + def __init__(self, host, url, method="POST", token=None): super().__init__(host, url, method) self.token = token @@ -34,41 +28,35 @@ def emit(self, record): log_entry = self.mapLogRecord(record) log_entry = { - "msg": log_entry['msg'], - "levelname": log_entry['levelname'], + "msg": log_entry["msg"], + "levelname": log_entry["levelname"], "project": os.environ.get("PROJECT_ID"), - "appinstance": os.environ.get("APP_ID") - + "appinstance": os.environ.get("APP_ID"), } # Setup headers headers = { - 'Content-type': 'application/json', + "Content-type": "application/json", } if self.token: - remote_token_protocol = os.environ.get('FEDN_REMOTE_LOG_TOKEN_PROTOCOL', "Token") - headers['Authorization'] = f"{remote_token_protocol} {self.token}" - if self.method.lower() == 'post': - requests.post(self.host+self.url, json=log_entry, headers=headers) + remote_token_protocol = os.environ.get("FEDN_REMOTE_LOG_TOKEN_PROTOCOL", "Token") + headers["Authorization"] = f"{remote_token_protocol} {self.token}" + if self.method.lower() == "post": + requests.post(self.host + self.url, json=log_entry, headers=headers) else: # No other methods implemented. return # Remote logging can only be configured via environment variables for now. 
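As the comment above states, remote logging is configured only through environment variables; the module-level code below attaches a StudioHTTPHandler when FEDN_REMOTE_LOG_SERVER is set. A hypothetical configuration (all values are illustrative):

    import os

    os.environ["FEDN_REMOTE_LOG_SERVER"] = "https://logs.example.com"
    os.environ["FEDN_REMOTE_LOG_PATH"] = "/api/logs"
    os.environ["FEDN_REMOTE_LOG_LEVEL"] = "ERROR"   # defaults to INFO
    os.environ["FEDN_REMOTE_LOG_TOKEN"] = "abc123"  # optional

    import fedn.common.log_config  # handler is attached at import time
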
-REMOTE_LOG_SERVER = os.environ.get('FEDN_REMOTE_LOG_SERVER', False) -REMOTE_LOG_PATH = os.environ.get('FEDN_REMOTE_LOG_PATH', False) -REMOTE_LOG_LEVEL = os.environ.get('FEDN_REMOTE_LOG_LEVEL', 'INFO') +REMOTE_LOG_SERVER = os.environ.get("FEDN_REMOTE_LOG_SERVER", False) +REMOTE_LOG_PATH = os.environ.get("FEDN_REMOTE_LOG_PATH", False) +REMOTE_LOG_LEVEL = os.environ.get("FEDN_REMOTE_LOG_LEVEL", "INFO") if REMOTE_LOG_SERVER: rloglevel = log_levels.get(REMOTE_LOG_LEVEL, logging.INFO) - remote_token = os.environ.get('FEDN_REMOTE_LOG_TOKEN', None) - - http_handler = StudioHTTPHandler( - host=REMOTE_LOG_SERVER, - url=REMOTE_LOG_PATH, - method='POST', - token=remote_token - ) + remote_token = os.environ.get("FEDN_REMOTE_LOG_TOKEN", None) + + http_handler = StudioHTTPHandler(host=REMOTE_LOG_SERVER, url=REMOTE_LOG_PATH, method="POST", token=remote_token) http_handler.setLevel(rloglevel) logger.addHandler(http_handler) @@ -79,11 +67,11 @@ def set_log_level_from_string(level_str): """ # Mapping of string representation to logging constants level_mapping = { - 'CRITICAL': logging.CRITICAL, - 'ERROR': logging.ERROR, - 'WARNING': logging.WARNING, - 'INFO': logging.INFO, - 'DEBUG': logging.DEBUG, + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, } # Get the logging level from the mapping diff --git a/fedn/network/api/auth.py b/fedn/network/api/auth.py index bf43c2f69..e3240e743 100644 --- a/fedn/network/api/auth.py +++ b/fedn/network/api/auth.py @@ -3,16 +3,20 @@ import jwt from flask import jsonify, request -from fedn.common.config import (FEDN_AUTH_SCHEME, - FEDN_AUTH_WHITELIST_URL_PREFIX, - FEDN_JWT_ALGORITHM, FEDN_JWT_CUSTOM_CLAIM_KEY, - FEDN_JWT_CUSTOM_CLAIM_VALUE, SECRET_KEY) +from fedn.common.config import ( + FEDN_AUTH_SCHEME, + FEDN_AUTH_WHITELIST_URL_PREFIX, + FEDN_JWT_ALGORITHM, + FEDN_JWT_CUSTOM_CLAIM_KEY, + FEDN_JWT_CUSTOM_CLAIM_VALUE, + SECRET_KEY, +) def check_role_claims(payload, role): - if 'role' not in payload: + if "role" not in payload: return False - if payload['role'] != role: + if payload["role"] != role: return False return True @@ -41,30 +45,29 @@ def actual_decorator(func): def decorated(*args, **kwargs): if if_whitelisted_url_prefix(request.path): return func(*args, **kwargs) - token = request.headers.get('Authorization') + token = request.headers.get("Authorization") if not token: - return jsonify({'message': 'Missing token'}), 401 + return jsonify({"message": "Missing token"}), 401 # Get token from the header Bearer if token.startswith(FEDN_AUTH_SCHEME): - token = token.split(' ')[1] + token = token.split(" ")[1] else: - return jsonify({'message': - f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}' - }), 401 + return jsonify({"message": f"Invalid token scheme, expected {FEDN_AUTH_SCHEME}"}), 401 try: payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) if not check_role_claims(payload, role): - return jsonify({'message': 'Invalid token'}), 401 + return jsonify({"message": "Invalid token"}), 401 if not check_custom_claims(payload): - return jsonify({'message': 'Invalid token'}), 401 + return jsonify({"message": "Invalid token"}), 401 except jwt.ExpiredSignatureError: - return jsonify({'message': 'Token expired'}), 401 + return jsonify({"message": "Token expired"}), 401 except jwt.InvalidTokenError: - return jsonify({'message': 'Invalid token'}), 401 + return jsonify({"message": "Invalid token"}), 401 return func(*args, **kwargs) return decorated + return 
actual_decorator diff --git a/fedn/network/api/client.py b/fedn/network/api/client.py index 60b4f3cd9..43678fbcf 100644 --- a/fedn/network/api/client.py +++ b/fedn/network/api/client.py @@ -2,11 +2,11 @@ import requests -__all__ = ['APIClient'] +__all__ = ["APIClient"] class APIClient: - """ An API client for interacting with the statestore and controller. + """An API client for interacting with the statestore and controller. :param host: The host of the api server. :type host: str @@ -37,34 +37,34 @@ def __init__(self, host, port=None, secure=False, verify=False, token=None, auth def _get_url(self, endpoint): if self.secure: - protocol = 'https' + protocol = "https" else: - protocol = 'http' + protocol = "http" if self.port: - return f'{protocol}://{self.host}:{self.port}/{endpoint}' - return f'{protocol}://{self.host}/{endpoint}' + return f"{protocol}://{self.host}:{self.port}/{endpoint}" + return f"{protocol}://{self.host}/{endpoint}" def _get_url_api_v1(self, endpoint): - return self._get_url(f'api/v1/{endpoint}') + return self._get_url(f"api/v1/{endpoint}") # --- Clients --- # def get_client(self, id: str): - """ Get a client from the statestore. + """Get a client from the statestore. :param id: The client id to get. :type id: str :return: Client. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'clients/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"clients/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_clients(self, n_max: int = None): - """ Get clients from the statestore. + """Get clients from the statestore. :param n_max: The maximum number of clients to get (If none all will be fetched). :type n_max: int @@ -74,28 +74,28 @@ def get_clients(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('clients/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("clients/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_clients_count(self): - """ Get the number of clients in the statestore. + """Get the number of clients in the statestore. :return: The number of clients. :rtype: dict """ - response = requests.get(self._get_url_api_v1('clients/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("clients/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_client_config(self, checksum=True): - """ Get client config from controller. Optionally include the checksum. + """Get client config from controller. Optionally include the checksum. The config is used for clients to connect to the controller and ask for combiner assignment. :param checksum: Whether to include the checksum of the package. @@ -103,18 +103,16 @@ def get_client_config(self, checksum=True): :return: The client configuration. 
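Returning to the role-checking decorator in the fedn/network/api/auth.py hunk above: it expects a JWT signed with SECRET_KEY using FEDN_JWT_ALGORITHM and carrying a matching "role" claim. A hypothetical sketch of minting a compatible token with PyJWT (the key, algorithm, and role values are assumptions):

    import jwt  # PyJWT

    token = jwt.encode({"role": "admin"}, "my-secret-key", algorithm="HS256")
    # Sent as an HTTP header:  Authorization: <FEDN_AUTH_SCHEME> <token>
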
:rtype: dict """ - _params = { - 'checksum': "true" if checksum else "false" - } + _params = {"checksum": "true" if checksum else "false"} - response = requests.get(self._get_url('get_client_config'), params=_params, verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("get_client_config"), params=_params, verify=self.verify, headers=self.headers) _json = response.json() return _json def get_active_clients(self, combiner_id: str = None, n_max: int = None): - """ Get active clients from the statestore. + """Get active clients from the statestore. :param combiner_id: The combiner id to get active clients for. :type combiner_id: str @@ -126,14 +124,14 @@ def get_active_clients(self, combiner_id: str = None, n_max: int = None): _params = {"status": "online"} if combiner_id: - _params['combiner'] = combiner_id + _params["combiner"] = combiner_id _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('clients/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("clients/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() @@ -142,21 +140,21 @@ def get_active_clients(self, combiner_id: str = None, n_max: int = None): # --- Combiners --- # def get_combiner(self, id: str): - """ Get a combiner from the statestore. + """Get a combiner from the statestore. :param id: The combiner id to get. :type id: str :return: Combiner. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'combiners/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"combiners/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_combiners(self, n_max: int = None): - """ Get combiners in the network. + """Get combiners in the network. :param n_max: The maximum number of combiners to get (If none all will be fetched). :type n_max: int @@ -166,21 +164,21 @@ def get_combiners(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('combiners/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("combiners/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_combiners_count(self): - """ Get the number of combiners in the statestore. + """Get the number of combiners in the statestore. :return: The number of combiners. :rtype: dict """ - response = requests.get(self._get_url_api_v1('combiners/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("combiners/count"), verify=self.verify, headers=self.headers) _json = response.json() @@ -189,12 +187,12 @@ def get_combiners_count(self): # --- Controllers --- # def get_controller_status(self): - """ Get the status of the controller. + """Get the status of the controller. :return: The status of the controller. :rtype: dict """ - response = requests.get(self._get_url('get_controller_status'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("get_controller_status"), verify=self.verify, headers=self.headers) _json = response.json() @@ -203,21 +201,21 @@ def get_controller_status(self): # --- Models --- # def get_model(self, id: str): - """ Get a model from the statestore. + """Get a model from the statestore. 
:param id: The id (or model property) of the model to get. :type id: str :return: Model. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'models/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"models/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_models(self, session_id: str = None, n_max: int = None): - """ Get models from the statestore. + """Get models from the statestore. :param session_id: The session id to get models for. (optional) :type session_id: str @@ -229,41 +227,41 @@ def get_models(self, session_id: str = None, n_max: int = None): _params = {} if session_id: - _params['session_id'] = session_id + _params["session_id"] = session_id _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('models/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("models/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json def get_models_count(self): - """ Get the number of models in the statestore. + """Get the number of models in the statestore. :return: The number of models. :rtype: dict """ - response = requests.get(self._get_url_api_v1('models/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("models/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_active_model(self): - """ Get the latest model from the statestore. + """Get the latest model from the statestore. :return: The latest model. :rtype: dict """ _headers = self.headers.copy() - _headers['X-Limit'] = "1" + _headers["X-Limit"] = "1" - response = requests.get(self._get_url_api_v1('models/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("models/"), verify=self.verify, headers=_headers) _json = response.json() if "result" in _json and len(_json["result"]) > 0: @@ -272,7 +270,7 @@ def get_active_model(self): return _json def get_model_trail(self, id: str = None, include_self: bool = True, reverse: bool = True, n_max: int = None): - """ Get the model trail. + """Get the model trail. :param id: The id (or model property) of the model to start the trail from. (optional) :type id: str @@ -291,18 +289,18 @@ def get_model_trail(self, id: str = None, include_self: bool = True, reverse: bo _headers = self.headers.copy() _count: int = n_max if n_max else self.get_models_count() - _headers['X-Limit'] = str(_count) - _headers['X-Reverse'] = "true" if reverse else "false" + _headers["X-Limit"] = str(_count) + _headers["X-Reverse"] = "true" if reverse else "false" _include_self_str: str = "true" if include_self else "false" - response = requests.get(self._get_url_api_v1(f'models/{id}/ancestors?include_self={_include_self_str}'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1(f"models/{id}/ancestors?include_self={_include_self_str}"), verify=self.verify, headers=_headers) _json = response.json() return _json def download_model(self, id: str, path: str): - """ Download the model with id id. + """Download the model with id id. :param id: The id (or model property) of the model to download. :type id: str @@ -311,46 +309,45 @@ def download_model(self, id: str, path: str): :return: Message with success or failure. 
:rtype: dict """ - response = requests.get(self._get_url_api_v1(f'models/{id}/download'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"models/{id}/download"), verify=self.verify, headers=self.headers) if response.status_code == 200: - - with open(path, 'wb') as file: + with open(path, "wb") as file: file.write(response.content) - return {'success': True, 'message': 'Model downloaded successfully.'} + return {"success": True, "message": "Model downloaded successfully."} else: - return {'success': False, 'message': 'Failed to download model.'} + return {"success": False, "message": "Failed to download model."} def set_active_model(self, path): - """ Set the initial model in the statestore and upload to model repository. + """Set the initial model in the statestore and upload to model repository. :param path: The file path of the initial model to set. :type path: str :return: A dict with success or failure message. :rtype: dict """ - with open(path, 'rb') as file: - response = requests.post(self._get_url('set_initial_model'), files={'file': file}, verify=self.verify, headers=self.headers) + with open(path, "rb") as file: + response = requests.post(self._get_url("set_initial_model"), files={"file": file}, verify=self.verify, headers=self.headers) return response.json() # --- Packages --- # def get_package(self, id: str): - """ Get a compute package from the statestore. + """Get a compute package from the statestore. :param id: The id of the compute package to get. :type id: str :return: Package. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'packages/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"packages/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_packages(self, n_max: int = None): - """ Get compute packages from the statestore. + """Get compute packages from the statestore. :param n_max: The maximum number of packages to get (If none all will be fetched). :type n_max: int @@ -360,68 +357,68 @@ def get_packages(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('packages/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("packages/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_packages_count(self): - """ Get the number of compute packages in the statestore. + """Get the number of compute packages in the statestore. :return: The number of packages. :rtype: dict """ - response = requests.get(self._get_url_api_v1('packages/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("packages/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_active_package(self): - """ Get the (active) compute package from the statestore. + """Get the (active) compute package from the statestore. :return: Package. :rtype: dict """ - response = requests.get(self._get_url_api_v1('packages/active'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("packages/active"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_package_checksum(self): - """ Get the checksum of the compute package. + """Get the checksum of the compute package. :return: The checksum. 
:rtype: dict """ - response = requests.get(self._get_url('get_package_checksum'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("get_package_checksum"), verify=self.verify, headers=self.headers) _json = response.json() return _json def download_package(self, path: str): - """ Download the compute package. + """Download the compute package. :param path: The path to download the compute package to. :type path: str :return: Message with success or failure. :rtype: dict """ - response = requests.get(self._get_url('download_package'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("download_package"), verify=self.verify, headers=self.headers) if response.status_code == 200: - with open(path, 'wb') as file: + with open(path, "wb") as file: file.write(response.content) - return {'success': True, 'message': 'Package downloaded successfully.'} + return {"success": True, "message": "Package downloaded successfully."} else: - return {'success': False, 'message': 'Failed to download package.'} + return {"success": False, "message": "Failed to download package."} def set_active_package(self, path: str, helper: str, name: str = None, description: str = None): - """ Set the compute package in the statestore. + """Set the compute package in the statestore. :param path: The file path of the compute package to set. :type path: str @@ -430,9 +427,14 @@ def set_active_package(self, path: str, helper: str, name: str = None, descripti :return: A dict with success or failure message. :rtype: dict """ - with open(path, 'rb') as file: - response = requests.post(self._get_url('set_package'), files={'file': file}, data={ - 'helper': helper, 'name': name, 'description': description}, verify=self.verify, headers=self.headers) + with open(path, "rb") as file: + response = requests.post( + self._get_url("set_package"), + files={"file": file}, + data={"helper": helper, "name": name, "description": description}, + verify=self.verify, + headers=self.headers, + ) _json = response.json() @@ -441,21 +443,21 @@ def set_active_package(self, path: str, helper: str, name: str = None, descripti # --- Rounds --- # def get_round(self, id: str): - """ Get a round from the statestore. + """Get a round from the statestore. :param round_id: The round id to get. :type round_id: str :return: Round (config and metrics). :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'rounds/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"rounds/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_rounds(self, n_max: int = None): - """ Get all rounds from the statestore. + """Get all rounds from the statestore. :param n_max: The maximum number of rounds to get (If none all will be fetched). :type n_max: int @@ -465,21 +467,21 @@ def get_rounds(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('rounds/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("rounds/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_rounds_count(self): - """ Get the number of rounds in the statestore. + """Get the number of rounds in the statestore. :return: The number of rounds. 
:rtype: dict """ - response = requests.get(self._get_url_api_v1('rounds/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("rounds/count"), verify=self.verify, headers=self.headers) _json = response.json() @@ -488,7 +490,7 @@ def get_rounds_count(self): # --- Sessions --- # def get_session(self, id: str): - """ Get a session from the statestore. + """Get a session from the statestore. :param id: The session id to get. :type id: str @@ -496,14 +498,14 @@ def get_session(self, id: str): :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'sessions/{id}'), self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"sessions/{id}"), self.verify, headers=self.headers) _json = response.json() return _json def get_sessions(self, n_max: int = None): - """ Get sessions from the statestore. + """Get sessions from the statestore. :param n_max: The maximum number of sessions to get (If none all will be fetched). :type n_max: int @@ -513,28 +515,28 @@ def get_sessions(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('sessions/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("sessions/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_sessions_count(self): - """ Get the number of sessions in the statestore. + """Get the number of sessions in the statestore. :return: The number of sessions. :rtype: dict """ - response = requests.get(self._get_url_api_v1('sessions/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("sessions/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_session_status(self, id: str): - """ Get the status of a session. + """Get the status of a session. :param id: The id of the session to get. :type id: str @@ -549,7 +551,7 @@ def get_session_status(self, id: str): return "Could not retrieve session status." def session_is_finished(self, id: str): - """ Check if a session with id has finished. + """Check if a session with id has finished. :param id: The id of the session to get. :type id: str @@ -560,18 +562,21 @@ def session_is_finished(self, id: str): return status and status.lower() == "finished" def start_session( - self, - id: str = None, - aggregator: str = 'fedavg', - aggregator_kwargs: dict = None, - model_id: str = None, - round_timeout: int = 180, - rounds: int = 5, - round_buffer_size: int = -1, - delete_models: bool = True, - validate: bool = True, helper: str = 'numpyhelper', min_clients: int = 1, requested_clients: int = 8 + self, + id: str = None, + aggregator: str = "fedavg", + aggregator_kwargs: dict = None, + model_id: str = None, + round_timeout: int = 180, + rounds: int = 5, + round_buffer_size: int = -1, + delete_models: bool = True, + validate: bool = True, + helper: str = "numpyhelper", + min_clients: int = 1, + requested_clients: int = 8, ): - """ Start a new session. + """Start a new session. :param id: The session id to start. :type id: str @@ -598,20 +603,25 @@ def start_session( :return: A dict with success or failure message and session config. 
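Taken together, the APIClient methods refactored in this file support a simple end-to-end flow. A minimal sketch using only methods shown in this patch (host, port, file names, and session id are hypothetical):

    import time
    from fedn.network.api.client import APIClient

    client = APIClient(host="localhost", port=8092)
    client.set_active_package("package.tgz", helper="numpyhelper")
    client.set_active_model("seed.npz")
    client.start_session(id="my-session", rounds=5)
    while not client.session_is_finished("my-session"):
        time.sleep(5)
    models = client.get_models(session_id="my-session")
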
:rtype: dict """ - response = requests.post(self._get_url('start_session'), json={ - 'session_id': id, - 'aggregator': aggregator, - 'aggregator_kwargs': aggregator_kwargs, - 'model_id': model_id, - 'round_timeout': round_timeout, - 'rounds': rounds, - 'round_buffer_size': round_buffer_size, - 'delete_models': delete_models, - 'validate': validate, - 'helper': helper, - 'min_clients': min_clients, - 'requested_clients': requested_clients - }, verify=self.verify, headers=self.headers) + response = requests.post( + self._get_url("start_session"), + json={ + "session_id": id, + "aggregator": aggregator, + "aggregator_kwargs": aggregator_kwargs, + "model_id": model_id, + "round_timeout": round_timeout, + "rounds": rounds, + "round_buffer_size": round_buffer_size, + "delete_models": delete_models, + "validate": validate, + "helper": helper, + "min_clients": min_clients, + "requested_clients": requested_clients, + }, + verify=self.verify, + headers=self.headers, + ) _json = response.json() @@ -620,21 +630,21 @@ def start_session( # --- Statuses --- # def get_status(self, id: str): - """ Get a status object (event) from the statestore. + """Get a status object (event) from the statestore. :param id: The id of the status to get. :type id: str :return: Status. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'statuses/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"statuses/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_statuses(self, session_id: str = None, event_type: str = None, sender_name: str = None, sender_role: str = None, n_max: int = None): - """ Get statuses from the statestore. Filter by input parameters + """Get statuses from the statestore. Filter by input parameters :param session_id: The session id to get statuses for. :type session_id: str @@ -665,21 +675,21 @@ def get_statuses(self, session_id: str = None, event_type: str = None, sender_na _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('statuses/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("statuses/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json def get_statuses_count(self): - """ Get the number of statuses in the statestore. + """Get the number of statuses in the statestore. :return: The number of statuses. :rtype: dict """ - response = requests.get(self._get_url_api_v1('statuses/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("statuses/count"), verify=self.verify, headers=self.headers) _json = response.json() @@ -688,14 +698,14 @@ def get_statuses_count(self): # --- Validations --- # def get_validation(self, id: str): - """ Get a validation from the statestore. + """Get a validation from the statestore. :param id: The id of the validation to get. :type id: str :return: Validation. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'validations/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"validations/{id}"), verify=self.verify, headers=self.headers) _json = response.json() @@ -710,9 +720,9 @@ def get_validations( sender_role: str = None, receiver_name: str = None, receiver_role: str = None, - n_max: int = None + n_max: int = None, ): - """ Get validations from the statestore. 
Filter by input parameters. + """Get validations from the statestore. Filter by input parameters. :param session_id: The session id to get validations for. :type session_id: str @@ -759,21 +769,21 @@ def get_validations( _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('validations/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("validations/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json def get_validations_count(self): - """ Get the number of validations in the statestore. + """Get the number of validations in the statestore. :return: The number of validations. :rtype: dict """ - response = requests.get(self._get_url_api_v1('validations/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("validations/count"), verify=self.verify, headers=self.headers) _json = response.json() diff --git a/fedn/network/api/interface.py b/fedn/network/api/interface.py index 38d6c8bf6..718cd8a18 100644 --- a/fedn/network/api/interface.py +++ b/fedn/network/api/interface.py @@ -10,8 +10,7 @@ from fedn.common.config import get_controller_config, get_network_config from fedn.common.log_config import logger -from fedn.network.combiner.interfaces import (CombinerInterface, - CombinerUnavailableError) +from fedn.network.combiner.interfaces import CombinerInterface, CombinerUnavailableError from fedn.network.state import ReducerState, ReducerStateToString from fedn.utils.checksum import sha from fedn.utils.plots import Plot @@ -36,9 +35,7 @@ def _to_dict(self): data = {"name": self.name} return data - def _allowed_file_extension( - self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"} - ): + def _allowed_file_extension(self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"}): """Check if file extension is allowed. :param filename: The filename to check. @@ -170,11 +167,10 @@ def get_session(self, session_id): info = session_object["session_config"][0] status = session_object["status"] payload[id] = info - payload['status'] = status + payload["status"] = status return jsonify(payload) def set_active_compute_package(self, id: str): - success = self.statestore.set_active_compute_package(id) if not success: @@ -199,16 +195,12 @@ def set_compute_package(self, file, helper_type: str, name: str = None, descript :rtype: :class:`flask.Response` """ - if ( - self.control.state() == ReducerState.instructing - or self.control.state() == ReducerState.monitoring - ): + if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: return ( jsonify( { "success": False, - "message": "Reducer is in instructing or monitoring state." - "Cannot set compute package.", + "message": "Reducer is in instructing or monitoring state." 
"Cannot set compute package.", } ), 400, @@ -288,9 +280,7 @@ def get_compute_package(self): result = self.statestore.get_compute_package() if result is None: return ( - jsonify( - {"success": False, "message": "No compute package found."} - ), + jsonify({"success": False, "message": "No compute package found."}), 404, ) @@ -327,9 +317,7 @@ def list_compute_packages(self, limit: str = None, skip: str = None, include_act result = self.statestore.list_compute_packages(limit, skip) if result is None: return ( - jsonify( - {"success": False, "message": "No compute packages found."} - ), + jsonify({"success": False, "message": "No compute packages found."}), 404, ) @@ -386,9 +374,7 @@ def download_compute_package(self, name): mutex = threading.Lock() mutex.acquire() # TODO: make configurable, perhaps in config.py or package.py - return send_from_directory( - "/app/client/package/", name, as_attachment=True - ) + return send_from_directory("/app/client/package/", name, as_attachment=True) except Exception: try: data = self.control.get_compute_package(name) @@ -397,9 +383,7 @@ def download_compute_package(self, name): with open(file_path, "wb") as fh: fh.write(data) # TODO: make configurable, perhaps in config.py or package.py - return send_from_directory( - "/app/client/package/", name, as_attachment=True - ) + return send_from_directory("/app/client/package/", name, as_attachment=True) except Exception: raise finally: @@ -418,9 +402,7 @@ def _create_checksum(self, name=None): name, message = self._get_compute_package_name() if name is None: return False, message, "" - file_path = os.path.join( - "/app/client/package/", name - ) # TODO: make configurable, perhaps in config.py or package.py + file_path = os.path.join("/app/client/package/", name) # TODO: make configurable, perhaps in config.py or package.py try: sum = str(sha(file_path)) except FileNotFoundError: @@ -505,9 +487,7 @@ def get_all_validations(self, **kwargs): payload[id] = info return jsonify(payload) - def add_combiner( - self, combiner_id, secure_grpc, address, remote_addr, fqdn, port - ): + def add_combiner(self, combiner_id, secure_grpc, address, remote_addr, fqdn, port): """Add a combiner to the network. :param combiner_id: The combiner id to add. 
@@ -540,9 +520,7 @@ def add_combiner( combiner = self.control.network.get_combiner(combiner_id) if not combiner: if secure_grpc == "True": - certificate, key = self.certificate_manager.get_or_create( - address - ).get_keypair_raw() + certificate, key = self.certificate_manager.get_or_create(address).get_keypair_raw() _ = base64.b64encode(certificate) _ = base64.b64encode(key) @@ -566,9 +544,7 @@ def add_combiner( # Check combiner now exists combiner = self.control.network.get_combiner(combiner_id) if not combiner: - return jsonify( - {"success": False, "message": "Combiner not added."} - ) + return jsonify({"success": False, "message": "Combiner not added."}) payload = { "success": True, @@ -623,9 +599,7 @@ def add_client(self, client_id, preferred_combiner, remote_addr): combiner = self.control.network.find_available_combiner() if combiner is None: return ( - jsonify( - {"success": False, "message": "No combiner available."} - ), + jsonify({"success": False, "message": "No combiner available."}), 400, ) @@ -691,9 +665,7 @@ def set_initial_model(self, file): logger.debug(e) return jsonify({"success": False, "message": e}) - return jsonify( - {"success": True, "message": "Initial model added successfully."} - ) + return jsonify({"success": True, "message": "Initial model added successfully."}) def get_latest_model(self): """Get the latest model from the statestore. @@ -706,9 +678,7 @@ def get_latest_model(self): payload = {"model_id": model_id} return jsonify(payload) else: - return jsonify( - {"success": False, "message": "No initial model set."} - ) + return jsonify({"success": False, "message": "No initial model set."}) def set_current_model(self, model_id: str): """Set the active model in the statestore. @@ -745,7 +715,6 @@ def get_models(self, session_id: str = None, limit: str = None, skip: str = None include_active: bool = include_active == "true" if include_active: - latest_model = self.statestore.get_latest_model() arr = [ @@ -801,9 +770,7 @@ def get_model_trail(self): if model_info: return jsonify(model_info) else: - return jsonify( - {"success": False, "message": "No model trail available."} - ) + return jsonify({"success": False, "message": "No model trail available."}) def get_model_ancestors(self, model_id: str, limit: str = None): """Get the model ancestors for a given model. @@ -816,15 +783,12 @@ def get_model_ancestors(self, model_id: str, limit: str = None): :rtype: :class:`flask.Response` """ if model_id is None: - return jsonify( - {"success": False, "message": "No model id provided."} - ) + return jsonify({"success": False, "message": "No model id provided."}) limit: int = int(limit) if limit is not None else 10 # if limit is None, default to 10 response = self.statestore.get_model_ancestors(model_id, limit) if response: - arr: list = [] for element in response: @@ -840,9 +804,7 @@ def get_model_ancestors(self, model_id: str, limit: str = None): return jsonify(result) else: - return jsonify( - {"success": False, "message": "No model ancestors available."} - ) + return jsonify({"success": False, "message": "No model ancestors available."}) def get_model_descendants(self, model_id: str, limit: str = None): """Get the model descendants for a given model. 
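The add_combiner hunk above base64-encodes the generated keypair before it is embedded in the JSON payload, and Network.get_combiners (later in this patch) decodes the same fields when rebuilding CombinerInterface objects. A round-trip sketch of that transport pattern; encode_keypair and decode_keypair are illustrative names, not FEDn API:

import base64


def encode_keypair(certificate: bytes, key: bytes) -> dict:
    # Raw PEM bytes are not JSON-safe, so both halves travel base64-encoded.
    return {
        "certificate": base64.b64encode(certificate).decode("utf-8"),
        "key": base64.b64encode(key).decode("utf-8"),
    }


def decode_keypair(record: dict) -> tuple:
    # Mirrors Network.get_combiners: combiners registered without TLS carry
    # empty certificate fields and come back as (None, None).
    if record.get("certificate"):
        return base64.b64decode(record["certificate"]), base64.b64decode(record["key"])
    return None, None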
@@ -856,16 +818,13 @@ def get_model_descendants(self, model_id: str, limit: str = None): """ if model_id is None: - return jsonify( - {"success": False, "message": "No model id provided."} - ) + return jsonify({"success": False, "message": "No model id provided."}) limit: int = int(limit) if limit is not None else 10 response: list = self.statestore.get_model_descendants(model_id, limit) if response: - arr: list = [] for element in response: @@ -881,9 +840,7 @@ def get_model_descendants(self, model_id: str, limit: str = None): return jsonify(result) else: - return jsonify( - {"success": False, "message": "No model descendants available."} - ) + return jsonify({"success": False, "message": "No model descendants available."}) def get_all_rounds(self): """Get all rounds. @@ -926,8 +883,8 @@ def get_round(self, round_id): if round_object is None: return jsonify({"success": False, "message": "Round not found."}) payload = { - 'round_id': round_object['round_id'], - 'combiners': round_object['combiners'], + "round_id": round_object["round_id"], + "combiners": round_object["combiners"], } return jsonify(payload) @@ -992,7 +949,6 @@ def list_combiners_data(self, combiners): # order list by combiner name for element in response: - obj = { "combiner": element["_id"], "count": element["count"], @@ -1007,7 +963,7 @@ def list_combiners_data(self, combiners): def start_session( self, session_id, - aggregator='fedavg', + aggregator="fedavg", aggregator_kwargs=None, model_id=None, rounds=5, @@ -1047,15 +1003,11 @@ def start_session( # Check if session already exists session = self.statestore.get_session(session_id) if session: - return jsonify( - {"success": False, "message": "Session already exists."} - ) + return jsonify({"success": False, "message": "Session already exists."}) # Check if session is running if self.control.state() == ReducerState.monitoring: - return jsonify( - {"success": False, "message": "A session is already running."} - ) + return jsonify({"success": False, "message": "A session is already running."}) # Check if compute package is set if not self.statestore.get_compute_package(): @@ -1123,9 +1075,7 @@ def start_session( } # Start session - threading.Thread( - target=self.control.session, args=(session_config,) - ).start() + threading.Thread(target=self.control.session, args=(session_config,)).start() # Return success response return jsonify( diff --git a/fedn/network/api/network.py b/fedn/network/api/network.py index c6a3c3838..cb105f10a 100644 --- a/fedn/network/api/network.py +++ b/fedn/network/api/network.py @@ -4,14 +4,14 @@ from fedn.network.combiner.interfaces import CombinerInterface from fedn.network.loadbalancer.leastpacked import LeastPacked -__all__ = 'Network', +__all__ = ("Network",) class Network: - """ FEDn network interface. This class is used to interact with the network. - Note: This class contain redundant code, which is not used in the current version of FEDn. - Some methods has been moved to :class:`fedn.network.api.interface.API`. - """ + """FEDn network interface. This class is used to interact with the network. + Note: This class contain redundant code, which is not used in the current version of FEDn. + Some methods has been moved to :class:`fedn.network.api.interface.API`. + """ def __init__(self, control, statestore, load_balancer=None): """ """ @@ -25,7 +25,7 @@ def __init__(self, control, statestore, load_balancer=None): self.load_balancer = load_balancer def get_combiner(self, name): - """ Get combiner by name. + """Get combiner by name. 
:param name: name of combiner :type name: str @@ -39,7 +39,7 @@ def get_combiner(self, name): return None def get_combiners(self): - """ Get all combiners in the network. + """Get all combiners in the network. :return: list of combiners objects :rtype: list(:class:`fedn.network.combiner.interfaces.CombinerInterface`) @@ -47,21 +47,19 @@ def get_combiners(self): data = self.statestore.get_combiners() combiners = [] for c in data["result"]: - if c['certificate']: - cert = base64.b64decode(c['certificate']) - key = base64.b64decode(c['key']) + if c["certificate"]: + cert = base64.b64decode(c["certificate"]) + key = base64.b64decode(c["key"]) else: cert = None key = None - combiners.append( - CombinerInterface(c['parent'], c['name'], c['address'], c['fqdn'], c['port'], - certificate=cert, key=key, ip=c['ip'])) + combiners.append(CombinerInterface(c["parent"], c["name"], c["address"], c["fqdn"], c["port"], certificate=cert, key=key, ip=c["ip"])) return combiners def add_combiner(self, combiner): - """ Add a new combiner to the network. + """Add a new combiner to the network. :param combiner: The combiner instance object :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` @@ -78,7 +76,7 @@ def add_combiner(self, combiner): self.statestore.set_combiner(combiner.to_dict()) def remove_combiner(self, combiner): - """ Remove a combiner from the network. + """Remove a combiner from the network. :param combiner: The combiner instance object :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` @@ -90,7 +88,7 @@ def remove_combiner(self, combiner): self.statestore.delete_combiner(combiner.name) def find_available_combiner(self): - """ Find an available combiner in the network. + """Find an available combiner in the network. :return: The combiner instance object :rtype: :class:`fedn.network.combiner.interfaces.CombinerInterface` @@ -99,32 +97,31 @@ def find_available_combiner(self): return combiner def handle_unavailable_combiner(self, combiner): - """ This callback is triggered if a combiner is found to be unresponsive. + """This callback is triggered if a combiner is found to be unresponsive. :param combiner: The combiner instance object :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` :return: None """ # TODO: Implement strategy to handle an unavailable combiner. - logger.warning("REDUCER CONTROL: Combiner {} unavailable.".format( - combiner.name)) + logger.warning("REDUCER CONTROL: Combiner {} unavailable.".format(combiner.name)) def add_client(self, client): - """ Add a new client to the network. + """Add a new client to the network. :param client: The client instance object :type client: dict :return: None """ - if self.get_client(client['name']): + if self.get_client(client["name"]): return - logger.info("adding client {}".format(client['name'])) + logger.info("adding client {}".format(client["name"])) self.statestore.set_client(client) def get_client(self, name): - """ Get client by name. + """Get client by name. :param name: name of client :type name: str @@ -135,7 +132,7 @@ def get_client(self, name): return ret def update_client_data(self, client_data, status, role): - """ Update client status in statestore. + """Update client status in statestore. 
:param client_data: The client instance object :type client_data: dict @@ -148,7 +145,7 @@ def update_client_data(self, client_data, status, role): self.statestore.update_client_status(client_data, status, role) def get_client_info(self): - """ list available client in statestore. + """List available client in statestore. :return: list of client objects :rtype: list(ObjectId) diff --git a/fedn/network/api/server.py b/fedn/network/api/server.py index 45cf410ce..5f645e4e2 100644 --- a/fedn/network/api/server.py +++ b/fedn/network/api/server.py @@ -2,8 +2,7 @@ from flask import Flask, jsonify, request -from fedn.common.config import (get_controller_config, get_modelstorage_config, - get_network_config, get_statestore_config) +from fedn.common.config import get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config from fedn.network.api.auth import jwt_auth_required from fedn.network.api.interface import API from fedn.network.api.v1 import _routes @@ -23,14 +22,12 @@ for bp in _routes: app.register_blueprint(bp) if custom_url_prefix: - app.register_blueprint(bp, - name=f"{bp.name}_custom", - url_prefix=f"{custom_url_prefix}{bp.url_prefix}") + app.register_blueprint(bp, name=f"{bp.name}_custom", url_prefix=f"{custom_url_prefix}{bp.url_prefix}") -@app.route('/health', methods=["GET"]) +@app.route("/health", methods=["GET"]) def health_check(): - return 'OK', 200 + return "OK", 200 if custom_url_prefix: @@ -364,9 +361,7 @@ def set_package(): file = request.files["file"] except KeyError: return jsonify({"success": False, "message": "Missing file."}), 400 - return api.set_compute_package( - file=file, helper_type=helper_type, name=name, description=description - ) + return api.set_compute_package(file=file, helper_type=helper_type, name=name, description=description) if custom_url_prefix: @@ -399,9 +394,7 @@ def list_compute_packages(): skip = request.args.get("skip", None) include_active = request.args.get("include_active", None) - return api.list_compute_packages( - limit=limit, skip=skip, include_active=include_active - ) + return api.list_compute_packages(limit=limit, skip=skip, include_active=include_active) if custom_url_prefix: diff --git a/fedn/network/api/v1/client_routes.py b/fedn/network/api/v1/client_routes.py index 30322a9b7..d5ccc58ee 100644 --- a/fedn/network/api/v1/client_routes.py +++ b/fedn/network/api/v1/client_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.client_store import ClientStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -115,9 +114,7 @@ def get_clients(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - clients = client_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + clients = client_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = clients["result"] @@ -202,9 +199,7 @@ def list_clients(): try: limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - clients = client_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + clients = client_store.list(limit, 
skip, sort_key, sort_order, use_typing=False, **kwargs) result = clients["result"] diff --git a/fedn/network/api/v1/combiner_routes.py b/fedn/network/api/v1/combiner_routes.py index 7d1761bee..1f9360461 100644 --- a/fedn/network/api/v1/combiner_routes.py +++ b/fedn/network/api/v1/combiner_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.combiner_store import CombinerStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -107,9 +106,7 @@ def get_combiners(): kwargs = request.args.to_dict() - combiners = combiner_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + combiners = combiner_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = combiners["result"] @@ -192,9 +189,7 @@ def list_combiners(): kwargs = get_post_data_to_kwargs(request) - combiners = combiner_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + combiners = combiner_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = combiners["result"] diff --git a/fedn/network/api/v1/model_routes.py b/fedn/network/api/v1/model_routes.py index 8e9308408..f9708a149 100644 --- a/fedn/network/api/v1/model_routes.py +++ b/fedn/network/api/v1/model_routes.py @@ -4,10 +4,7 @@ from flask import Blueprint, jsonify, request, send_file from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_limit, - get_post_data_to_kwargs, get_reverse, - get_typed_list_headers, mdb, - modelstorage_config) +from fedn.network.api.v1.shared import api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers, mdb, modelstorage_config from fedn.network.storage.s3.base import RepositoryBase from fedn.network.storage.s3.miniorepository import MINIORepository from fedn.network.storage.statestore.stores.model_store import ModelStore @@ -112,9 +109,7 @@ def get_models(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - models = model_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + models = model_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = models["result"] @@ -199,9 +194,7 @@ def list_models(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - models = model_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + models = model_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = models["result"] @@ -466,7 +459,7 @@ def get_ancestors(id: str): try: limit = get_limit(request.headers) reverse = get_reverse(request.headers) - include_self_param: str = request.args.get('include_self') + include_self_param: str = request.args.get("include_self") include_self: bool = include_self_param and include_self_param.lower() == "true" diff --git a/fedn/network/api/v1/package_routes.py b/fedn/network/api/v1/package_routes.py index 30ac4d51e..65783f54b 100644 --- a/fedn/network/api/v1/package_routes.py +++ b/fedn/network/api/v1/package_routes.py @@ -1,9 +1,7 @@ from flask import 
Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, get_use_typing, - mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb from fedn.network.storage.statestore.stores.package_store import PackageStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -120,9 +118,7 @@ def get_packages(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - packages = package_store.list( - limit, skip, sort_key, sort_order, use_typing=True, **kwargs - ) + packages = package_store.list(limit, skip, sort_key, sort_order, use_typing=True, **kwargs) result = [package.__dict__ for package in packages["result"]] @@ -210,9 +206,7 @@ def list_packages(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - packages = package_store.list( - limit, skip, sort_key, sort_order, use_typing=True, **kwargs - ) + packages = package_store.list(limit, skip, sort_key, sort_order, use_typing=True, **kwargs) result = [package.__dict__ for package in packages["result"]] diff --git a/fedn/network/api/v1/round_routes.py b/fedn/network/api/v1/round_routes.py index 8890c510a..4c2eb0c44 100644 --- a/fedn/network/api/v1/round_routes.py +++ b/fedn/network/api/v1/round_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.round_store import RoundStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -95,9 +94,7 @@ def get_rounds(): kwargs = request.args.to_dict() - rounds = round_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + rounds = round_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = rounds["result"] @@ -176,9 +173,7 @@ def list_rounds(): kwargs = get_post_data_to_kwargs(request) - rounds = round_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + rounds = round_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = rounds["result"] diff --git a/fedn/network/api/v1/session_routes.py b/fedn/network/api/v1/session_routes.py index 99c52d8db..ccfde590a 100644 --- a/fedn/network/api/v1/session_routes.py +++ b/fedn/network/api/v1/session_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.session_store import SessionStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -87,9 +86,7 @@ def get_sessions(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - sessions = session_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + sessions = session_store.list(limit, skip, 
sort_key, sort_order, use_typing=False, **kwargs) result = sessions["result"] @@ -167,9 +164,7 @@ def list_sessions(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - sessions = session_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + sessions = session_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = sessions["result"] diff --git a/fedn/network/api/v1/shared.py b/fedn/network/api/v1/shared.py index 753414324..2fb6063c0 100644 --- a/fedn/network/api/v1/shared.py +++ b/fedn/network/api/v1/shared.py @@ -3,8 +3,7 @@ import pymongo from pymongo.database import Database -from fedn.common.config import (get_modelstorage_config, get_network_config, - get_statestore_config) +from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config api_version = "v1" @@ -58,9 +57,7 @@ def get_typed_list_headers( use_typing: bool = get_use_typing(headers) if sort_order is not None: - sort_order = ( - pymongo.ASCENDING if sort_order.lower() == "asc" else pymongo.DESCENDING - ) + sort_order = pymongo.ASCENDING if sort_order.lower() == "asc" else pymongo.DESCENDING else: sort_order = pymongo.DESCENDING diff --git a/fedn/network/api/v1/status_routes.py b/fedn/network/api/v1/status_routes.py index e78c18533..b88772b01 100644 --- a/fedn/network/api/v1/status_routes.py +++ b/fedn/network/api/v1/status_routes.py @@ -1,9 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, get_use_typing, - mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb from fedn.network.storage.statestore.stores.shared import EntityNotFound from fedn.network.storage.statestore.stores.status_store import StatusStore @@ -123,20 +121,12 @@ def get_statuses(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - statuses = status_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [status.__dict__ for status in statuses["result"]] - if use_typing - else statuses["result"] - ) + result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"] response = {"count": statuses["count"], "result": result} @@ -226,20 +216,12 @@ def list_statuses(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - statuses = status_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [status.__dict__ for status in statuses["result"]] - if use_typing - else statuses["result"] - ) + result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"] response = {"count": statuses["count"], "result": result} diff --git 
a/fedn/network/api/v1/validation_routes.py b/fedn/network/api/v1/validation_routes.py index 96fbac55c..59767e3e8 100644 --- a/fedn/network/api/v1/validation_routes.py +++ b/fedn/network/api/v1/validation_routes.py @@ -1,12 +1,9 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, get_use_typing, - mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb from fedn.network.storage.statestore.stores.shared import EntityNotFound -from fedn.network.storage.statestore.stores.validation_store import \ - ValidationStore +from fedn.network.storage.statestore.stores.validation_store import ValidationStore bp = Blueprint("validation", __name__, url_prefix=f"/api/{api_version}/validations") @@ -131,20 +128,12 @@ def get_validations(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - validations = validation_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + validations = validation_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [validation.__dict__ for validation in validations["result"]] - if use_typing - else validations["result"] - ) + result = [validation.__dict__ for validation in validations["result"]] if use_typing else validations["result"] response = {"count": validations["count"], "result": result} @@ -237,20 +226,12 @@ def list_validations(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - validations = validation_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + validations = validation_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [validation.__dict__ for validation in validations["result"]] - if use_typing - else validations["result"] - ) + result = [validation.__dict__ for validation in validations["result"]] if use_typing else validations["result"] response = {"count": validations["count"], "result": result} diff --git a/fedn/network/clients/client.py b/fedn/network/clients/client.py index e830006e4..c8a5afc4f 100644 --- a/fedn/network/clients/client.py +++ b/fedn/network/clients/client.py @@ -22,18 +22,16 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_PACKAGE_EXTRACT_DIR -from fedn.common.log_config import (logger, set_log_level_from_string, - set_log_stream) +from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream from fedn.network.clients.connect import ConnectorClient, Status from fedn.network.clients.package import PackageRuntime from fedn.network.clients.state import ClientState, ClientStateToString -from fedn.network.combiner.modelservice import (get_tmp_path, - upload_request_generator) +from fedn.network.combiner.modelservice import get_tmp_path, upload_request_generator from fedn.utils.dispatcher import Dispatcher from fedn.utils.helpers.helpers import get_helper CHUNK_SIZE = 
1024 * 1024 -VALID_NAME_REGEX = '^[a-zA-Z0-9_-]*$' +VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$" class GrpcAuth(grpc.AuthMetadataPlugin): @@ -41,7 +39,7 @@ def __init__(self, key): self._key = key def __call__(self, context, callback): - callback((('authorization', f'{FEDN_AUTH_SCHEME} {self._key}'),), None) + callback((("authorization", f"{FEDN_AUTH_SCHEME} {self._key}"),), None) class Client: @@ -61,30 +59,32 @@ def __init__(self, config): self._missed_heartbeat = 0 self.config = config self.trace_attribs = False - set_log_level_from_string(config.get('verbosity', "INFO")) - set_log_stream(config.get('logfile', None)) - - self.connector = ConnectorClient(host=config['discover_host'], - port=config['discover_port'], - token=config['token'], - name=config['name'], - remote_package=config['remote_compute_context'], - force_ssl=config['force_ssl'], - verify=config['verify'], - combiner=config['preferred_combiner'], - id=config['client_id']) + set_log_level_from_string(config.get("verbosity", "INFO")) + set_log_stream(config.get("logfile", None)) + + self.connector = ConnectorClient( + host=config["discover_host"], + port=config["discover_port"], + token=config["token"], + name=config["name"], + remote_package=config["remote_compute_context"], + force_ssl=config["force_ssl"], + verify=config["verify"], + combiner=config["preferred_combiner"], + id=config["client_id"], + ) # Validate client name - match = re.search(VALID_NAME_REGEX, config['name']) + match = re.search(VALID_NAME_REGEX, config["name"]) if not match: - raise ValueError('Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.') + raise ValueError("Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.") # Folder where the client will store downloaded compute package and logs - self.name = config['name'] + self.name = config["name"] if FEDN_PACKAGE_EXTRACT_DIR: self.run_path = os.path.join(os.getcwd(), FEDN_PACKAGE_EXTRACT_DIR) else: - dirname = self.name+"-"+time.strftime("%Y%m%d-%H%M%S") + dirname = self.name + "-" + time.strftime("%Y%m%d-%H%M%S") self.run_path = os.path.join(os.getcwd(), dirname) if not os.path.exists(self.run_path): os.mkdir(self.run_path) @@ -102,8 +102,7 @@ def __init__(self, config): self._initialize_helper(combiner_config) if not self.helper: - logger.warning("Failed to retrieve helper class settings: {}".format( - combiner_config)) + logger.warning("Failed to retrieve helper class settings: {}".format(combiner_config)) self._subscribe_to_combiner(self.config) @@ -147,14 +146,14 @@ def _add_grpc_metadata(self, key, value): :type value: str """ # Check if metadata exists and add if not - if not hasattr(self, 'metadata'): + if not hasattr(self, "metadata"): self.metadata = () # Check if metadata key already exists and replace value if so for i, (k, v) in enumerate(self.metadata): if k == key: # Replace value - self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1:] + self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1 :] return # Set metadata using tuple concatenation @@ -186,39 +185,38 @@ def connect(self, combiner_config): return None # TODO use the combiner_config['certificate'] for setting up secure comms' - host = combiner_config['host'] + host = combiner_config["host"] # Add host to gRPC metadata - self._add_grpc_metadata('grpc-server', host) + self._add_grpc_metadata("grpc-server", host) logger.debug("Client using metadata: {}.".format(self.metadata)) - port = combiner_config['port'] + port = combiner_config["port"] secure = 
False - if combiner_config['fqdn'] is not None: - host = combiner_config['fqdn'] + if combiner_config["fqdn"] is not None: + host = combiner_config["fqdn"] # assuming https if fqdn is used port = 443 logger.info(f"Initiating connection to combiner host at: {host}:{port}") - if combiner_config['certificate']: + if combiner_config["certificate"]: logger.info("Utilizing CA certificate for GRPC channel authentication.") secure = True - cert = base64.b64decode( - combiner_config['certificate']) # .decode('utf-8') + cert = base64.b64decode(combiner_config["certificate"]) # .decode('utf-8') credentials = grpc.ssl_channel_credentials(root_certificates=cert) channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) elif os.getenv("FEDN_GRPC_ROOT_CERT_PATH"): secure = True logger.info("Using root certificate from environment variable for GRPC channel.") - with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], 'rb') as f: + with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], "rb") as f: credentials = grpc.ssl_channel_credentials(f.read()) channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) - elif self.config['secure']: + elif self.config["secure"]: secure = True logger.info("Using CA certificate for GRPC channel.") cert = self._get_ssl_certificate(host, port=port) - credentials = grpc.ssl_channel_credentials(cert.encode('utf-8')) - if self.config['token']: - token = self.config['token'] + credentials = grpc.ssl_channel_credentials(cert.encode("utf-8")) + if self.config["token"]: + token = self.config["token"] auth_creds = grpc.metadata_call_credentials(GrpcAuth(token)) channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) else: @@ -227,9 +225,7 @@ def connect(self, combiner_config): logger.info("Using insecure GRPC channel.") if port == 443: port = 80 - channel = grpc.insecure_channel("{}:{}".format( - host, - str(port))) + channel = grpc.insecure_channel("{}:{}".format(host, str(port))) self.channel = channel @@ -237,12 +233,9 @@ def connect(self, combiner_config): self.combinerStub = rpc.CombinerStub(channel) self.modelStub = rpc.ModelServiceStub(channel) - logger.info("Successfully established {} connection to {}:{}".format("secure" if secure else "insecure", - host, - port)) + logger.info("Successfully established {} connection to {}:{}".format("secure" if secure else "insecure", host, port)) - logger.info("Using {} compute package.".format( - combiner_config["package"])) + logger.info("Using {} compute package.".format(combiner_config["package"])) self._connected = True @@ -265,8 +258,8 @@ def _initialize_helper(self, combiner_config): :return: """ - if 'helper_type' in combiner_config.keys(): - self.helper = get_helper(combiner_config['helper_type']) + if "helper_type" in combiner_config.keys(): + self.helper = get_helper(combiner_config["helper_type"]) def _subscribe_to_combiner(self, config): """Listen to combiner message stream and start all processing threads. @@ -277,12 +270,10 @@ def _subscribe_to_combiner(self, config): """ # Start sending heartbeats to the combiner. 
- threading.Thread(target=self._send_heartbeat, kwargs={ - 'update_frequency': config['heartbeat_interval']}, daemon=True).start() + threading.Thread(target=self._send_heartbeat, kwargs={"update_frequency": config["heartbeat_interval"]}, daemon=True).start() # Start listening for combiner training and validation messages - threading.Thread( - target=self._listen_to_task_stream, daemon=True).start() + threading.Thread(target=self._listen_to_task_stream, daemon=True).start() self._connected = True # Start processing the client message inbox @@ -294,7 +285,7 @@ def untar_package(self, package_runtime): return package_runpath def _initialize_dispatcher(self, config): - """ Initialize the dispatcher for the client. + """Initialize the dispatcher for the client. :param config: A configuration dictionary containing connection information for | the discovery service (controller) and settings governing e.g. @@ -310,11 +301,7 @@ def _initialize_dispatcher(self, config): while tries > 0: retval = pr.download( - host=config['discover_host'], - port=config['discover_port'], - token=config['token'], - force_ssl=config['force_ssl'], - secure=config['secure'] + host=config["discover_host"], port=config["discover_port"], token=config["token"], force_ssl=config["force_ssl"], secure=config["secure"] ) if retval: break @@ -323,10 +310,10 @@ def _initialize_dispatcher(self, config): tries -= 1 if retval: - if 'checksum' not in config: + if "checksum" not in config: logger.warning("Bypassing validation of package checksum. Ensure the package source is trusted.") else: - checks_out = pr.validate(config['checksum']) + checks_out = pr.validate(config["checksum"]) if not checks_out: logger.critical("Validation of local package failed. Client terminating.") self.error_state = True @@ -348,11 +335,14 @@ def _initialize_dispatcher(self, config): else: # TODO: Deprecate - dispatch_config = {'entry_points': - {'predict': {'command': 'python3 predict.py'}, - 'train': {'command': 'python3 train.py'}, - 'validate': {'command': 'python3 validate.py'}}} - from_path = os.path.join(os.getcwd(), 'client') + dispatch_config = { + "entry_points": { + "predict": {"command": "python3 predict.py"}, + "train": {"command": "python3 train.py"}, + "validate": {"command": "python3 validate.py"}, + } + } + from_path = os.path.join(os.getcwd(), "client") copy_tree(from_path, self.run_path) self.dispatcher = Dispatcher(dispatch_config, self.run_path) @@ -378,7 +368,6 @@ def get_model_from_combiner(self, id, timeout=20): try: for part in self.modelStub.Download(request, metadata=self.metadata): - if part.status == fedn.ModelStatus.IN_PROGRESS: data.write(part.data) @@ -436,7 +425,7 @@ def _listen_to_task_stream(self): r.sender.name = self.name r.sender.role = fedn.WORKER # Add client to metadata - self._add_grpc_metadata('client', self.name) + self._add_grpc_metadata("client", self.name) while self._connected: try: @@ -445,14 +434,19 @@ def _listen_to_task_stream(self): logger.debug("Received model update request from combiner: {}.".format(request)) if request.sender.role == fedn.COMBINER: # Process training request - self.send_status("Received model update request.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request, sesssion_id=request.session_id) + self.send_status( + "Received model update request.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE_REQUEST, + request=request, + sesssion_id=request.session_id, + ) logger.info("Received model update request of type {} for model_id 
{}".format(request.type, request.model_id)) - if request.type == fedn.StatusType.MODEL_UPDATE and self.config['trainer']: - self.inbox.put(('train', request)) - elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config['validator']: - self.inbox.put(('validate', request)) + if request.type == fedn.StatusType.MODEL_UPDATE and self.config["trainer"]: + self.inbox.put(("train", request)) + elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]: + self.inbox.put(("validate", request)) else: logger.error("Unknown request type: {}".format(request.type)) @@ -465,7 +459,7 @@ def _listen_to_task_stream(self): time.sleep(5) if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() - if details == 'Token expired': + if details == "Token expired": logger.warning("GRPC TaskStream: Token expired. Reconnecting.") self.detach() @@ -496,8 +490,7 @@ def _process_training_request(self, model_id: str, session_id: str = None): :rtype: tuple """ - self.send_status( - "\t Starting processing of training request for model_id {}".format(model_id), sesssion_id=session_id) + self.send_status("\t Starting processing of training request for model_id {}".format(model_id), sesssion_id=session_id) self.state = ClientState.training try: @@ -507,10 +500,10 @@ def _process_training_request(self, model_id: str, session_id: str = None): if mdl is None: logger.error("Could not retrieve model from combiner. Aborting training request.") return None, None - meta['fetch_model'] = time.time() - tic + meta["fetch_model"] = time.time() - tic inpath = self.helper.get_tmp_path() - with open(inpath, 'wb') as fh: + with open(inpath, "wb") as fh: fh.write(mdl.getbuffer()) outpath = self.helper.get_tmp_path() @@ -519,7 +512,7 @@ def _process_training_request(self, model_id: str, session_id: str = None): self.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) - meta['exec_training'] = time.time() - tic + meta["exec_training"] = time.time() - tic tic = time.time() out_model = None @@ -530,21 +523,21 @@ def _process_training_request(self, model_id: str, session_id: str = None): # Stream model update to combiner server updated_model_id = uuid.uuid4() self.send_model_to_combiner(out_model, str(updated_model_id)) - meta['upload_model'] = time.time() - tic + meta["upload_model"] = time.time() - tic # Read the metadata file - with open(outpath+'-metadata', 'r') as fh: + with open(outpath + "-metadata", "r") as fh: training_metadata = json.loads(fh.read()) - meta['training_metadata'] = training_metadata + meta["training_metadata"] = training_metadata os.unlink(inpath) os.unlink(outpath) - os.unlink(outpath+'-metadata') + os.unlink(outpath + "-metadata") except Exception as e: logger.error("Could not process training request due to error: {}".format(e)) updated_model_id = None - meta = {'status': 'failed', 'error': str(e)} + meta = {"status": "failed", "error": str(e)} self.state = ClientState.idle @@ -564,12 +557,11 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session """ # Figure out cmd if is_inference: - cmd = 'infer' + cmd = "infer" else: - cmd = 'validate' + cmd = "validate" - self.send_status( - f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) + self.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) self.state = ClientState.validating try: model = self.get_model_from_combiner(str(model_id)) @@ -599,25 +591,22 @@ def _process_validation_request(self, model_id: str, is_inference: bool, 
session return validation def process_request(self): - """Process training and validation tasks. """ + """Process training and validation tasks.""" while True: - if not self._connected: return try: (task_type, request) = self.inbox.get(timeout=1.0) - if task_type == 'train': - + if task_type == "train": tic = time.time() self.state = ClientState.training - model_id, meta = self._process_training_request( - request.model_id, session_id=request.session_id) + model_id, meta = self._process_training_request(request.model_id, session_id=request.session_id) if meta is not None: - processing_time = time.time()-tic - meta['processing_time'] = processing_time - meta['config'] = request.data + processing_time = time.time() - tic + meta["processing_time"] = processing_time + meta["config"] = request.data if model_id is not None: # Send model update to combiner @@ -634,28 +623,31 @@ def process_request(self): try: _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) - self.send_status("Model update completed.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE, request=update, sesssion_id=request.session_id) + self.send_status( + "Model update completed.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE, + request=update, + sesssion_id=request.session_id, + ) except grpc.RpcError as e: status_code = e.code() - logger.error("GRPC error, {}.".format( - status_code.name)) + logger.error("GRPC error, {}.".format(status_code.name)) logger.debug(e) except ValueError as e: logger.error("GRPC error, RPC channel closed. {}".format(e)) logger.debug(e) else: - self.send_status("Client {} failed to complete model update.", - log_level=fedn.Status.WARNING, - request=request, sesssion_id=request.session_id) + self.send_status( + "Client {} failed to complete model update.", log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id + ) self.state = ClientState.idle self.inbox.task_done() - elif task_type == 'validate': + elif task_type == "validate": self.state = ClientState.validating - metrics = self._process_validation_request( - request.model_id, False, request.session_id) + metrics = self._process_validation_request(request.model_id, False, request.session_id) if metrics is not None: # Send validation @@ -671,23 +663,26 @@ def process_request(self): validation.session_id = request.session_id try: - _ = self.combinerStub.SendModelValidation( - validation, metadata=self.metadata) + _ = self.combinerStub.SendModelValidation(validation, metadata=self.metadata) status_type = fedn.StatusType.MODEL_VALIDATION - self.send_status("Model validation completed.", log_level=fedn.Status.AUDIT, - type=status_type, request=validation, sesssion_id=request.session_id) + self.send_status( + "Model validation completed.", log_level=fedn.Status.AUDIT, type=status_type, request=validation, sesssion_id=request.session_id + ) except grpc.RpcError as e: status_code = e.code() - logger.error("GRPC error, {}.".format( - status_code.name)) + logger.error("GRPC error, {}.".format(status_code.name)) logger.debug(e) except ValueError as e: logger.error("GRPC error, RPC channel closed. 
{}".format(e)) logger.debug(e) else: - self.send_status("Client {} failed to complete model validation.".format(self.name), - log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id) + self.send_status( + "Client {} failed to complete model validation.".format(self.name), + log_level=fedn.Status.WARNING, + request=request, + sesssion_id=request.session_id, + ) self.state = ClientState.idle self.inbox.task_done() @@ -705,8 +700,7 @@ def _send_heartbeat(self, update_frequency=2.0): :rtype: None """ while True: - heartbeat = fedn.Heartbeat(sender=fedn.Client( - name=self.name, role=fedn.WORKER)) + heartbeat = fedn.Heartbeat(sender=fedn.Client(name=self.name, role=fedn.WORKER)) try: self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata) self._missed_heartbeat = 0 @@ -714,13 +708,16 @@ def _send_heartbeat(self, update_frequency=2.0): status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: self._missed_heartbeat += 1 - logger.error("GRPC hearbeat: combiner unavailable, retrying (attempt {}/{}).".format(self._missed_heartbeat, - self.config['reconnect_after_missed_heartbeat'])) - if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']: + logger.error( + "GRPC hearbeat: combiner unavailable, retrying (attempt {}/{}).".format( + self._missed_heartbeat, self.config["reconnect_after_missed_heartbeat"] + ) + ) + if self._missed_heartbeat > self.config["reconnect_after_missed_heartbeat"]: self.disconnect() if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() - if details == 'Token expired': + if details == "Token expired": logger.error("GRPC hearbeat: Token expired. Disconnecting.") self.disconnect() sys.exit("Unauthorized. Token expired. Please obtain a new token.") @@ -761,9 +758,7 @@ def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, if request is not None: status.data = MessageToJson(request) - self.logs.append( - "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, - status.status)) + self.logs.append("{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) try: _ = self.connectorStub.SendStatus(status, metadata=self.metadata) except grpc.RpcError as e: @@ -772,11 +767,11 @@ def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, logger.warning("GRPC SendStatus: server unavailable during send status.") if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() - if details == 'Token expired': + if details == "Token expired": logger.warning("GRPC SendStatus: Token expired.") def run(self): - """ Run the client. 
""" + """Run the client.""" try: cnt = 0 old_state = self.state diff --git a/fedn/network/clients/connect.py b/fedn/network/clients/connect.py index bbe2599b8..efeb3d1e9 100644 --- a/fedn/network/clients/connect.py +++ b/fedn/network/clients/connect.py @@ -8,14 +8,13 @@ import requests -from fedn.common.config import (FEDN_AUTH_REFRESH_TOKEN, - FEDN_AUTH_REFRESH_TOKEN_URI, FEDN_AUTH_SCHEME, - FEDN_CUSTOM_URL_PREFIX) +from fedn.common.config import FEDN_AUTH_REFRESH_TOKEN, FEDN_AUTH_REFRESH_TOKEN_URI, FEDN_AUTH_SCHEME, FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger class Status(enum.Enum): - """ Enum for representing the status of a client assignment.""" + """Enum for representing the status of a client assignment.""" + Unassigned = 0 Assigned = 1 TryAgain = 2 @@ -24,7 +23,7 @@ class Status(enum.Enum): class ConnectorClient: - """ Connector for assigning client to a combiner in the FEDn network. + """Connector for assigning client to a combiner in the FEDn network. :param host: host of discovery service :type host: str @@ -46,7 +45,6 @@ class ConnectorClient: """ def __init__(self, host, port, token, name, remote_package, force_ssl=False, verify=False, combiner=None, id=None): - self.host = host self.port = port self.token = token @@ -54,7 +52,7 @@ def __init__(self, host, port, token, name, remote_package, force_ssl=False, ver self.verify = verify self.preferred_combiner = combiner self.id = id - self.package = 'remote' if remote_package else 'local' + self.package = "remote" if remote_package else "local" # for https we assume a an ingress handles permanent redirect (308) if force_ssl: @@ -62,11 +60,9 @@ def __init__(self, host, port, token, name, remote_package, force_ssl=False, ver else: self.prefix = "http://" if self.port: - self.connect_string = "{}{}:{}".format( - self.prefix, self.host, self.port) + self.connect_string = "{}{}:{}".format(self.prefix, self.host, self.port) else: - self.connect_string = "{}{}".format( - self.prefix, self.host) + self.connect_string = "{}{}".format(self.prefix, self.host) logger.info("Setting connection string to {}.".format(self.connect_string)) @@ -79,26 +75,28 @@ def assign(self): """ try: retval = None - payload = {'client_id': self.name, 'preferred_combiner': self.preferred_combiner} - retval = requests.post(self.connect_string + FEDN_CUSTOM_URL_PREFIX + '/add_client', - json=payload, - verify=self.verify, - allow_redirects=True, - headers={'Authorization': f"{FEDN_AUTH_SCHEME} {self.token}"}) + payload = {"client_id": self.name, "preferred_combiner": self.preferred_combiner} + retval = requests.post( + self.connect_string + FEDN_CUSTOM_URL_PREFIX + "/add_client", + json=payload, + verify=self.verify, + allow_redirects=True, + headers={"Authorization": f"{FEDN_AUTH_SCHEME} {self.token}"}, + ) except Exception as e: - logger.debug('***** {}'.format(e)) + logger.debug("***** {}".format(e)) return Status.Unassigned, {} if retval.status_code == 400: # Get error messange from response - reason = retval.json()['message'] + reason = retval.json()["message"] return Status.UnMatchedConfig, reason if retval.status_code == 401: - if 'message' in retval.json(): - reason = retval.json()['message'] + if "message" in retval.json(): + reason = retval.json()["message"] logger.warning(reason) - if reason == 'Token expired': + if reason == "Token expired": status_code = self.refresh_token() if status_code >= 200 and status_code < 204: logger.info("Token refreshed.") @@ -109,19 +107,19 @@ def assign(self): return Status.UnAuthorized, reason if 
retval.status_code >= 200 and retval.status_code < 204: - if retval.json()['status'] == 'retry': - if 'message' in retval.json(): - reason = retval.json()['message'] + if retval.json()["status"] == "retry": + if "message" in retval.json(): + reason = retval.json()["message"] else: reason = "Reducer was not ready. Try again later." return Status.TryAgain, reason - reducer_package = retval.json()['package'] + reducer_package = retval.json()["package"] if reducer_package != self.package: - reason = "Unmatched config of compute package between client and reducer.\n" +\ - "Reducer uses {} package and client uses {}.".format( - reducer_package, self.package) + reason = "Unmatched config of compute package between client and reducer.\n" + "Reducer uses {} package and client uses {}.".format( + reducer_package, self.package + ) return Status.UnMatchedConfig, reason return Status.Assigned, retval.json() @@ -139,9 +137,6 @@ def refresh_token(self): logger.error("No refresh token URI/Token set, cannot refresh token.") return 401 - payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, - verify=self.verify, - allow_redirects=True, - json={'refresh': FEDN_AUTH_REFRESH_TOKEN}) - self.token = payload.json()['access'] + payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, verify=self.verify, allow_redirects=True, json={"refresh": FEDN_AUTH_REFRESH_TOKEN}) + self.token = payload.json()["access"] return payload.status_code diff --git a/fedn/network/clients/package.py b/fedn/network/clients/package.py index c06d79fc5..54f45b883 100644 --- a/fedn/network/clients/package.py +++ b/fedn/network/clients/package.py @@ -14,7 +14,7 @@ class PackageRuntime: - """ PackageRuntime is used to download, validate and unpack compute packages. + """PackageRuntime is used to download, validate and unpack compute packages. 
:param package_path: path to compute package :type package_path: str @@ -23,11 +23,13 @@ class PackageRuntime: """ def __init__(self, package_path): - - self.dispatch_config = {'entry_points': - {'predict': {'command': 'python3 predict.py'}, - 'train': {'command': 'python3 train.py'}, - 'validate': {'command': 'python3 validate.py'}}} + self.dispatch_config = { + "entry_points": { + "predict": {"command": "python3 predict.py"}, + "train": {"command": "python3 train.py"}, + "validate": {"command": "python3 validate.py"}, + } + } self.pkg_path = package_path self.pkg_name = None @@ -35,7 +37,7 @@ def __init__(self, package_path): self.expected_checksum = None def download(self, host, port, token, force_ssl=False, secure=False, name=None): - """ Download compute package from controller + """Download compute package from controller :param host: host of controller :param port: port of controller @@ -56,18 +58,16 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): if name: path = path + "?name={}".format(name) - with requests.get(path, stream=True, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: + with requests.get(path, stream=True, verify=False, headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}) as r: if 200 <= r.status_code < 204: - - params = cgi.parse_header( - r.headers.get('Content-Disposition', ''))[-1] + params = cgi.parse_header(r.headers.get("Content-Disposition", ""))[-1] try: - self.pkg_name = params['filename'] + self.pkg_name = params["filename"] except KeyError: logger.error("No package returned.") return None r.raise_for_status() - with open(os.path.join(self.pkg_path, self.pkg_name), 'wb') as f: + with open(os.path.join(self.pkg_path, self.pkg_name), "wb") as f: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) if port: @@ -77,19 +77,18 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): if name: path = path + "?name={}".format(name) - with requests.get(path, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: + with requests.get(path, verify=False, headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}) as r: if 200 <= r.status_code < 204: - data = r.json() try: - self.checksum = data['checksum'] + self.checksum = data["checksum"] except Exception: logger.error("Could not extract checksum.") return True def validate(self, expected_checksum): - """ Validate the package against the checksum provided by the controller + """Validate the package against the checksum provided by the controller :param expected_checksum: checksum provided by the controller :return: True if checksums match, False otherwise @@ -107,37 +106,27 @@ def validate(self, expected_checksum): return False def unpack(self): - """ Unpack the compute package + """Unpack the compute package :return: True if unpacking was successful, False otherwise :rtype: bool """ if self.pkg_name: f = None - if self.pkg_name.endswith('tar.gz'): - f = tarfile.open(os.path.join( - self.pkg_path, self.pkg_name), 'r:gz') - if self.pkg_name.endswith('.tgz'): - f = tarfile.open(os.path.join( - self.pkg_path, self.pkg_name), 'r:gz') - if self.pkg_name.endswith('tar.bz2'): - f = tarfile.open(os.path.join( - self.pkg_path, self.pkg_name), 'r:bz2') + if self.pkg_name.endswith("tar.gz"): + f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:gz") + if self.pkg_name.endswith(".tgz"): + f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:gz") + if self.pkg_name.endswith("tar.bz2"): + 
f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:bz2") else: - logger.error( - "Failed to unpack compute package, no pkg_name set." - "Has the reducer been configured with a compute package?" - ) + logger.error("Failed to unpack compute package, no pkg_name set." "Has the reducer been configured with a compute package?") return False try: if f: f.extractall(self.pkg_path) - logger.info( - "Successfully extracted compute package content in {}".format( - self.pkg_path - ) - ) + logger.info("Successfully extracted compute package content in {}".format(self.pkg_path)) # delete the tarball logger.info("Deleting temporary package tarball file.") os.remove(os.path.join(self.pkg_path, self.pkg_name)) @@ -157,7 +146,7 @@ def unpack(self): return False, "" def dispatcher(self, run_path): - """ Dispatch the compute package + """Dispatch the compute package :param run_path: path to dispatch the compute package :type run_path: str diff --git a/fedn/network/clients/state.py b/fedn/network/clients/state.py index 262f5862e..a349f846e 100644 --- a/fedn/network/clients/state.py +++ b/fedn/network/clients/state.py @@ -2,14 +2,15 @@ class ClientState(Enum): - """ Enum for representing the state of a client.""" + """Enum for representing the state of a client.""" + idle = 1 training = 2 validating = 3 def ClientStateToString(state): - """ Convert a ClientState to a string representation. + """Convert a ClientState to a string representation. :param state: the state to convert :type state: :class:`fedn.network.clients.state.ClientState` diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index b3b97fd6a..f19674b73 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -12,8 +12,7 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc -from fedn.common.log_config import (logger, set_log_level_from_string, - set_log_stream) +from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream from fedn.network.combiner.connect import ConnectorCombiner, Status from fedn.network.combiner.modelservice import ModelService from fedn.network.combiner.roundhandler import RoundHandler @@ -21,11 +20,12 @@ from fedn.network.storage.s3.repository import Repository from fedn.network.storage.statestore.mongostatestore import MongoStateStore -VALID_NAME_REGEX = '^[a-zA-Z0-9_-]*$' +VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$" class Role(Enum): - """ Enum for combiner roles. """ + """Enum for combiner roles.""" + WORKER = 1 COMBINER = 2 REDUCER = 3 @@ -33,7 +33,7 @@ class Role(Enum): def role_to_proto_role(role): - """ Convert a Role to a proto Role. + """Convert a Role to a proto Role. :param role: the role to convert :type role: :class:`fedn.network.combiner.server.Role` @@ -51,17 +51,17 @@ def role_to_proto_role(role): class Combiner(rpc.CombinerServicer, rpc.ReducerServicer, rpc.ConnectorServicer, rpc.ControlServicer): - """ Combiner gRPC server. + """Combiner gRPC server. 
:param config: configuration for the combiner :type config: dict """ def __init__(self, config): - """ Initialize Combiner server.""" + """Initialize Combiner server.""" - set_log_level_from_string(config.get('verbosity', "INFO")) - set_log_stream(config.get('logfile', None)) + set_log_level_from_string(config.get("verbosity", "INFO")) + set_log_stream(config.get("logfile", None)) # Client queues self.clients = {} @@ -69,24 +69,26 @@ def __init__(self, config): self.modelservice = ModelService() # Validate combiner name - match = re.search(VALID_NAME_REGEX, config['name']) + match = re.search(VALID_NAME_REGEX, config["name"]) if not match: - raise ValueError('Unallowed character in combiner name. Allowed characters: a-z, A-Z, 0-9, _, -.') + raise ValueError("Unallowed character in combiner name. Allowed characters: a-z, A-Z, 0-9, _, -.") - self.id = config['name'] + self.id = config["name"] self.role = Role.COMBINER - self.max_clients = config['max_clients'] + self.max_clients = config["max_clients"] # Connector to announce combiner to discover service (reducer) - announce_client = ConnectorCombiner(host=config['discover_host'], - port=config['discover_port'], - myhost=config['host'], - fqdn=config['fqdn'], - myport=config['port'], - token=config['token'], - name=config['name'], - secure=config['secure'], - verify=config['verify']) + announce_client = ConnectorCombiner( + host=config["discover_host"], + port=config["discover_port"], + myhost=config["host"], + fqdn=config["fqdn"], + myport=config["port"], + token=config["token"], + name=config["name"], + secure=config["secure"], + verify=config["verify"], + ) while True: # Announce combiner to discover service @@ -107,27 +109,20 @@ def __init__(self, config): logger.info("Status.UnMatchedConfig") sys.exit("Exiting: Missing config") - cert = announce_config['certificate'] - key = announce_config['key'] + cert = announce_config["certificate"] + key = announce_config["key"] - if announce_config['certificate']: - cert = base64.b64decode(announce_config['certificate']) # .decode('utf-8') - key = base64.b64decode(announce_config['key']) # .decode('utf-8') + if announce_config["certificate"]: + cert = base64.b64decode(announce_config["certificate"]) # .decode('utf-8') + key = base64.b64decode(announce_config["key"]) # .decode('utf-8') # Set up gRPC server configuration - grpc_config = {'port': config['port'], - 'secure': config['secure'], - 'certificate': cert, - 'key': key} + grpc_config = {"port": config["port"], "secure": config["secure"], "certificate": cert, "key": key} # Set up model repository - self.repository = Repository( - announce_config['storage']['storage_config']) + self.repository = Repository(announce_config["storage"]["storage_config"]) - self.statestore = MongoStateStore( - announce_config['statestore']['network_id'], - announce_config['statestore']['mongo_config'] - ) + self.statestore = MongoStateStore(announce_config["statestore"]["network_id"], announce_config["statestore"]["mongo_config"]) # Create gRPC server self.server = Server(self, self.modelservice, grpc_config) @@ -144,7 +139,7 @@ def __init__(self, config): self.server.start() def __whoami(self, client, instance): - """ Set the client id and role in a proto message. + """Set the client id and role in a proto message. 
:param client: the client to set the id and role for :type client: :class:`fedn.network.grpc.fedn_pb2.Client` @@ -158,7 +153,7 @@ def __whoami(self, client, instance): return client def request_model_update(self, config, clients=[]): - """ Ask clients to update the current global model. + """Ask clients to update the current global model. :param config: the model configuration to send to clients :type config: dict @@ -168,12 +163,12 @@ def request_model_update(self, config, clients=[]): """ # The request to be added to the client queue request = fedn.TaskRequest() - request.model_id = config['model_id'] + request.model_id = config["model_id"] request.correlation_id = str(uuid.uuid4()) request.timestamp = str(datetime.now()) request.data = json.dumps(config) request.type = fedn.StatusType.MODEL_UPDATE - request.session_id = config['session_id'] + request.session_id = config["session_id"] request.sender.name = self.id request.sender.role = fedn.COMBINER @@ -187,14 +182,12 @@ def request_model_update(self, config, clients=[]): self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE) if len(clients) < 20: - logger.info("Sent model update request for model {} to clients {}".format( - request.model_id, clients)) + logger.info("Sent model update request for model {} to clients {}".format(request.model_id, clients)) else: - logger.info("Sent model update request for model {} to {} clients".format( - request.model_id, len(clients))) + logger.info("Sent model update request for model {} to {} clients".format(request.model_id, len(clients))) def request_model_validation(self, model_id, config, clients=[]): - """ Ask clients to validate the current global model. + """Ask clients to validate the current global model. :param model_id: the model id to validate :type model_id: str @@ -225,14 +218,12 @@ def request_model_validation(self, model_id, config, clients=[]): self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE) if len(clients) < 20: - logger.info("Sent model validation request for model {} to clients {}".format( - request.model_id, clients)) + logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients)) else: - logger.info("Sent model validation request for model {} to {} clients".format( - request.model_id, len(clients))) + logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients))) def get_active_trainers(self): - """ Get a list of active trainers. + """Get a list of active trainers. :return: the list of active trainers :rtype: list @@ -241,7 +232,7 @@ def get_active_trainers(self): return trainers def get_active_validators(self): - """ Get a list of active validators. + """Get a list of active validators. :return: the list of active validators :rtype: list @@ -250,7 +241,7 @@ def get_active_validators(self): return validators def nr_active_trainers(self): - """ Get the number of active trainers. + """Get the number of active trainers. :return: the number of active trainers :rtype: int @@ -260,7 +251,7 @@ def nr_active_trainers(self): #################################################################################################################### def __join_client(self, client): - """ Add a client to the list of active clients. + """Add a client to the list of active clients. 
:param client: the client to add :type client: :class:`fedn.network.grpc.fedn_pb2.Client` @@ -270,7 +261,7 @@ def __join_client(self, client): self.clients[client.name] = {"lastseen": datetime.now(), "status": "offline"} def _subscribe_client_to_queue(self, client, queue_name): - """ Subscribe a client to the queue. + """Subscribe a client to the queue. :param client: the client to subscribe :type client: :class:`fedn.network.grpc.fedn_pb2.Client` @@ -282,7 +273,7 @@ def _subscribe_client_to_queue(self, client, queue_name): self.clients[client.name][queue_name] = queue.Queue() def __get_queue(self, client, queue_name): - """ Get the queue for a client. + """Get the queue for a client. :param client: the client to get the queue for :type client: :class:`fedn.network.grpc.fedn_pb2.Client` @@ -299,7 +290,7 @@ def __get_queue(self, client, queue_name): raise def _list_subscribed_clients(self, queue_name): - """ List all clients subscribed to a queue. + """List all clients subscribed to a queue. :param queue_name: the name of the queue :type queue_name: str @@ -313,7 +304,7 @@ def _list_subscribed_clients(self, queue_name): return subscribed_clients def _list_active_clients(self, channel): - """ List all clients that have sent a status message in the last 10 seconds. + """List all clients that have sent a status message in the last 10 seconds. :param channel: the name of the channel :type channel: str @@ -350,14 +341,14 @@ def _list_active_clients(self, channel): return clients["active_clients"] def _deamon_thread_client_status(self, timeout=10): - """ Deamon thread that checks for inactive clients and updates statestore. """ + """Daemon thread that checks for inactive clients and updates statestore.""" while True: time.sleep(timeout) # TODO: Also update validation clients self._list_active_clients(fedn.Queue.TASK_QUEUE) def _put_request_to_client_queue(self, request, queue_name): - """ Get a client specific queue and add a request to it. + """Get a client specific queue and add a request to it. The client is identified by the request.receiver. :param request: the request to send @@ -369,14 +360,11 @@ q = self.__get_queue(request.receiver, queue_name) q.put(request) except Exception as e: - logger.error("Failed to put request to client queue {} for client {}: {}".format( - queue_name, - request.receiver.name, - str(e))) + logger.error("Failed to put request to client queue {} for client {}: {}".format(queue_name, request.receiver.name, str(e))) raise def _send_status(self, status): - """ Report a status to backend db. + """Report a status to backend db. :param status: the status to report :type status: :class:`fedn.network.grpc.fedn_pb2.Status` @@ -406,7 +394,7 @@ def _flush_model_update_queue(self): # Controller Service def Start(self, control: fedn.ControlRequest, context): - """ Start a round of federated learning" + """Start a round of federated learning. :param control: the control request :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest` @@ -434,7 +422,7 @@ def Start(self, control: fedn.ControlRequest, context): return response def SetAggregator(self, control: fedn.ControlRequest, context): - """ Set the active aggregator. + """Set the active aggregator.
:param control: the control request :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest` @@ -451,14 +439,14 @@ response = fedn.ControlResponse() if status: - response.message = 'Success' + response.message = "Success" else: - response.message = 'Failed' + response.message = "Failed" return response def FlushAggregationQueue(self, control: fedn.ControlRequest, context): - """ Flush the queue. + """Flush the queue. :param control: the control request :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest` @@ -472,16 +460,16 @@ response = fedn.ControlResponse() if status: - response.message = 'Success' + response.message = "Success" else: - response.message = 'Failed' + response.message = "Failed" return response ############################################################################## def Stop(self, control: fedn.ControlRequest, context): - """ TODO: Not yet implemented. + """TODO: Not yet implemented. :param control: the control request :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest` @@ -497,7 +485,7 @@ ##################################################################################################################### def SendStatus(self, status: fedn.Status, context): - """ A client RPC endpoint that accepts status messages. + """A client RPC endpoint that accepts status messages. :param status: the status message :type status: :class:`fedn.network.grpc.fedn_pb2.Status` @@ -514,7 +502,7 @@ return response def ListActiveClients(self, request: fedn.ListClientsRequest, context): - """ RPC endpoint that returns a ClientList containing the names of all active clients. + """RPC endpoint that returns a ClientList containing the names of all active clients. An active client has sent a status message / responded to a heartbeat request in the last 10 seconds. @@ -538,7 +526,7 @@ return clients def AcceptingClients(self, request: fedn.ConnectionRequest, context): - """ RPC endpoint that returns a ConnectionResponse indicating whether the server + """RPC endpoint that returns a ConnectionResponse indicating whether the server is accepting clients or not. :param request: the request (unused) @@ -549,8 +537,7 @@ :rtype: :class:`fedn.network.grpc.fedn_pb2.ConnectionResponse` """ response = fedn.ConnectionResponse() - active_clients = self._list_active_clients( - fedn.Queue.TASK_QUEUE) + active_clients = self._list_active_clients(fedn.Queue.TASK_QUEUE) try: requested = int(self.max_clients) @@ -569,7 +556,7 @@ return response def SendHeartbeat(self, heartbeat: fedn.Heartbeat, context): - """ RPC that lets clients send a hearbeat, notifying the server that + """RPC that lets clients send a heartbeat, notifying the server that the client is available. :param heartbeat: the heartbeat @@ -594,7 +581,7 @@ # Combiner Service def TaskStream(self, response, context): - """ A server stream RPC endpoint (Update model). Messages from client stream. + """A server stream RPC endpoint (Update model). Messages from client stream.
:param response: the response :type response: :class:`fedn.network.grpc.fedn_pb2.ModelUpdateRequest` @@ -606,17 +593,15 @@ metadata = context.invocation_metadata() if metadata: metadata = dict(metadata) - logger.info("grpc.Combiner.TaskStream: Client connected: {}\n".format(metadata['client'])) + logger.info("grpc.Combiner.TaskStream: Client connected: {}\n".format(metadata["client"])) - status = fedn.Status( - status="Client {} connecting to TaskStream.".format(client.name)) + status = fedn.Status(status="Client {} connecting to TaskStream.".format(client.name)) status.log_level = fedn.Status.INFO status.timestamp.GetCurrentTime() self.__whoami(status.sender, self) - self._subscribe_client_to_queue( - client, fedn.Queue.TASK_QUEUE) + self._subscribe_client_to_queue(client, fedn.Queue.TASK_QUEUE) q = self.__get_queue(client, fedn.Queue.TASK_QUEUE) self._send_status(status) @@ -637,7 +622,7 @@ def TaskStream(self, response, context): logger.error("Error in ModelUpdateRequestStream: {}".format(e)) def SendModelUpdate(self, request, context): - """ Send a model update response. + """Send a model update response. :param request: the request :type request: :class:`fedn.network.grpc.fedn_pb2.ModelUpdate` @@ -649,8 +634,7 @@ def SendModelUpdate(self, request, context): self.round_handler.aggregator.on_model_update(request) response = fedn.Response() - response.response = "RECEIVED ModelUpdate {} from client {}".format( - response, response.sender.name) + response.response = "RECEIVED ModelUpdate {} from client {}".format(response, response.sender.name) return response # TODO Fill later def register_model_validation(self, validation): @@ -663,7 +647,7 @@ def register_model_validation(self, validation): self.statestore.report_validation(validation) def SendModelValidation(self, request, context): - """ Send a model validation response. + """Send a model validation response. :param request: the request :type request: :class:`fedn.network.grpc.fedn_pb2.ModelValidation` @@ -677,17 +661,15 @@ def SendModelValidation(self, request, context): self.register_model_validation(request) response = fedn.Response() - response.response = "RECEIVED ModelValidation {} from client {}".format( - response, response.sender.name) + response.response = "RECEIVED ModelValidation {} from client {}".format(response, response.sender.name) return response #################################################################################################################### def run(self): - """ Start the server.""" + """Start the server.""" - logger.info("COMBINER: {} started, ready for gRPC requests.".format( - self.id)) + logger.info("COMBINER: {} started, ready for gRPC requests.".format(self.id)) try: while True: signal.pause() diff --git a/fedn/network/combiner/connect.py b/fedn/network/combiner/connect.py index 7dc388261..e144baa94 100644 --- a/fedn/network/combiner/connect.py +++ b/fedn/network/combiner/connect.py @@ -13,7 +13,8 @@ class Status(enum.Enum): - """ Enum for representing the status of a combiner announcement.""" + """Enum for representing the status of a combiner announcement.""" + Unassigned = 0 Assigned = 1 TryAgain = 2 @@ -22,7 +23,7 @@ class Status(enum.Enum): class ConnectorCombiner: - """ Connector for annnouncing combiner to the FEDn network. + """Connector for announcing combiner to the FEDn network.
:param host: host of discovery service :type host: str @@ -45,7 +46,7 @@ class ConnectorCombiner: """ def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, verify=False): - """ Initialize the ConnectorCombiner. + """Initialize the ConnectorCombiner. :param host: The host of the discovery service. :type host: str @@ -73,22 +74,20 @@ def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, self.myhost = myhost self.myport = myport self.token = token - self.token_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') + self.token_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Bearer") self.name = name self.secure = secure self.verify = verify if not self.token: - self.token = os.environ.get('FEDN_AUTH_TOKEN', None) + self.token = os.environ.get("FEDN_AUTH_TOKEN", None) # for https we assume a an ingress handles permanent redirect (308) self.prefix = "http://" if port: - self.connect_string = "{}{}:{}".format( - self.prefix, self.host, self.port) + self.connect_string = "{}{}:{}".format(self.prefix, self.host, self.port) else: - self.connect_string = "{}{}".format( - self.prefix, self.host) + self.connect_string = "{}{}".format(self.prefix, self.host) logger.info("Setting connection string to {}".format(self.connect_string)) @@ -99,24 +98,21 @@ def announce(self): :return: Tuple with announcement Status, FEDn network configuration if sucessful, else None. :rtype: :class:`fedn.network.combiner.connect.Status`, str """ - payload = { - "combiner_id": self.name, - "address": self.myhost, - "fqdn": self.fqdn, - "port": self.myport, - "secure_grpc": self.secure - } - url_prefix = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') + payload = {"combiner_id": self.name, "address": self.myhost, "fqdn": self.fqdn, "port": self.myport, "secure_grpc": self.secure} + url_prefix = os.environ.get("FEDN_CUSTOM_URL_PREFIX", "") try: - retval = requests.post(self.connect_string + url_prefix + '/add_combiner', json=payload, - verify=self.verify, - headers={'Authorization': f'{self.token_scheme} {self.token}'}) + retval = requests.post( + self.connect_string + url_prefix + "/add_combiner", + json=payload, + verify=self.verify, + headers={"Authorization": f"{self.token_scheme} {self.token}"}, + ) except Exception: return Status.Unassigned, {} if retval.status_code == 400: # Get error messange from response - reason = retval.json()['message'] + reason = retval.json()["message"] return Status.UnMatchedConfig, reason if retval.status_code == 401: @@ -124,8 +120,8 @@ def announce(self): return Status.UnAuthorized, reason if retval.status_code >= 200 and retval.status_code < 204: - if retval.json()['status'] == 'retry': - reason = retval.json()['message'] + if retval.json()["status"] == "retry": + reason = retval.json()["message"] return Status.TryAgain, reason return Status.Assigned, retval.json() diff --git a/fedn/network/combiner/interfaces.py b/fedn/network/combiner/interfaces.py index 69b987be0..f247c2bc1 100644 --- a/fedn/network/combiner/interfaces.py +++ b/fedn/network/combiner/interfaces.py @@ -15,7 +15,7 @@ class CombinerUnavailableError(Exception): class Channel: - """ Wrapper for a gRPC channel. + """Wrapper for a gRPC channel. :param address: The address for the gRPC server. :type address: str @@ -26,7 +26,7 @@ class Channel: """ def __init__(self, address, port, certificate=None): - """ Create a channel. + """Create a channel. If a valid certificate is given, a secure channel is created, else insecure. 
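[Illustration, not part of this patch] A minimal sketch of how the ConnectorCombiner above is used to announce a combiner to the discovery service. The host names, ports, combiner name, and token below are placeholder assumptions; the constructor signature and the (Status, payload) return shape follow the code in this diff.

from fedn.network.combiner.connect import ConnectorCombiner, Status

connector = ConnectorCombiner(
    host="api-server",          # discovery service host (assumed)
    port=8092,                  # discovery service REST port (assumed)
    myhost="combiner",          # address this combiner announces (assumed)
    fqdn="combiner.example",    # assumed FQDN
    myport=12080,               # gRPC port this combiner listens on (assumed)
    token=None,                 # falls back to the FEDN_AUTH_TOKEN env var
    name="combiner1",           # must satisfy VALID_NAME_REGEX above
    secure=False,
    verify=False,
)

status, result = connector.announce()
if status == Status.Assigned:
    print("Assigned; received network configuration.")
elif status == Status.TryAgain:
    print("Discovery service asked us to retry:", result)
else:
    print("Announcement did not succeed:", status, result)

This mirrors the loop in Combiner.__init__ above, which keeps calling announce() until the discovery service returns Status.Assigned.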
@@ -42,16 +42,13 @@ def __init__(self, address, port, certificate=None): self.certificate = certificate if self.certificate: - credentials = grpc.ssl_channel_credentials( - root_certificates=copy.deepcopy(certificate)) - self.channel = grpc.secure_channel('{}:{}'.format( - self.address, str(self.port)), credentials) + credentials = grpc.ssl_channel_credentials(root_certificates=copy.deepcopy(certificate)) + self.channel = grpc.secure_channel("{}:{}".format(self.address, str(self.port)), credentials) else: - self.channel = grpc.insecure_channel( - '{}:{}'.format(self.address, str(self.port))) + self.channel = grpc.insecure_channel("{}:{}".format(self.address, str(self.port))) def get_channel(self): - """ Get a channel. + """Get a channel. :return: An instance of a gRPC channel :rtype: :class:`grpc.Channel` @@ -60,7 +57,7 @@ def get_channel(self): class CombinerInterface: - """ Interface for the Combiner (aggregation server). + """Interface for the Combiner (aggregation server). Abstraction on top of the gRPC server servicer. :param parent: The parent combiner (controller) @@ -84,7 +81,7 @@ class CombinerInterface: """ def __init__(self, parent, name, address, fqdn, port, certificate=None, key=None, ip=None, config=None): - """ Initialize the combiner interface.""" + """Initialize the combiner interface.""" self.parent = parent self.name = name self.address = address @@ -95,15 +92,13 @@ def __init__(self, parent, name, address, fqdn, port, certificate=None, key=None self.ip = ip if not config: - self.config = { - 'max_clients': 8 - } + self.config = {"max_clients": 8} else: self.config = config @classmethod def from_json(combiner_config): - """ Initialize the combiner config from a json document. + """Initialize the combiner config from a json document. :parameter combiner_config: The combiner configuration. :type combiner_config: dict @@ -113,34 +108,34 @@ def from_json(combiner_config): return CombinerInterface(**combiner_config) def to_dict(self): - """ Export combiner configuration to a dictionary. + """Export combiner configuration to a dictionary. :return: A dictionary with the combiner configuration. :rtype: dict """ data = { - 'parent': self.parent, - 'name': self.name, - 'address': self.address, - 'fqdn': self.fqdn, - 'port': self.port, - 'ip': self.ip, - 'certificate': None, - 'key': None, - 'config': self.config + "parent": self.parent, + "name": self.name, + "address": self.address, + "fqdn": self.fqdn, + "port": self.port, + "ip": self.ip, + "certificate": None, + "key": None, + "config": self.config, } if self.certificate: cert_b64 = base64.b64encode(self.certificate) key_b64 = base64.b64encode(self.key) - data['certificate'] = str(cert_b64).split('\'')[1] - data['key'] = str(key_b64).split('\'')[1] + data["certificate"] = str(cert_b64).split("'")[1] + data["key"] = str(key_b64).split("'")[1] return data def to_json(self): - """ Export combiner configuration to json. + """Export combiner configuration to json. :return: A json document with the combiner configuration. :rtype: str @@ -148,34 +143,33 @@ def to_json(self): return json.dumps(self.to_dict()) def get_certificate(self): - """ Get combiner certificate. + """Get combiner certificate. :return: The combiner certificate. :rtype: str, None if no certificate is set. """ if self.certificate: cert_b64 = base64.b64encode(self.certificate) - return str(cert_b64).split('\'')[1] + return str(cert_b64).split("'")[1] else: return None def get_key(self): - """ Get combiner key. + """Get combiner key. :return: The combiner key. 
:rtype: str, None if no key is set. """ if self.key: key_b64 = base64.b64encode(self.key) - return str(key_b64).split('\'')[1] + return str(key_b64).split("'")[1] else: return None def flush_model_update_queue(self): - """ Reset the model update queue on the combiner. """ + """Reset the model update queue on the combiner.""" - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ControlStub(channel) request = fedn.ControlRequest() @@ -189,14 +183,13 @@ def flush_model_update_queue(self): raise def set_aggregator(self, aggregator): - """ Set the active aggregator module. + """Set the active aggregator module. :param aggregator: The name of the aggregator module. :type config: str """ - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ControlStub(channel) request = fedn.ControlRequest() @@ -213,15 +206,14 @@ def set_aggregator(self, aggregator): raise def submit(self, config): - """ Submit a compute plan to the combiner. + """Submit a compute plan to the combiner. :param config: The job configuration. :type config: dict :return: Server ControlResponse object. :rtype: :class:`fedn.network.grpc.fedn_pb2.ControlResponse` """ - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ControlStub(channel) request = fedn.ControlRequest() request.command = fedn.Command.START @@ -241,7 +233,7 @@ def submit(self, config): return response def get_model(self, id, timeout=10): - """ Download a model from the combiner server. + """Download a model from the combiner server. :param id: The model id. :type id: str @@ -249,8 +241,7 @@ def get_model(self, id, timeout=10): :rtype: :class:`io.BytesIO`, None if the model is not available. """ - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() modelservice = rpc.ModelServiceStub(channel) data = BytesIO() @@ -276,13 +267,12 @@ def get_model(self, id, timeout=10): continue def allowing_clients(self): - """ Check if the combiner is allowing additional client connections. + """Check if the combiner is allowing additional client connections. :return: True if accepting, else False. :rtype: bool """ - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() connector = rpc.ConnectorStub(channel) request = fedn.ConnectionRequest() @@ -303,7 +293,7 @@ def allowing_clients(self): return False def list_active_clients(self, queue=1): - """ List active clients. + """List active clients. :param queue: The channel (queue) to use (optional). Default is 1 = MODEL_UPDATE_REQUESTS channel. see :class:`fedn.network.grpc.fedn_pb2.Channel` @@ -311,8 +301,7 @@ def list_active_clients(self, queue=1): :return: A list of active clients. 
:rtype: json """ - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ConnectorStub(channel) request = fedn.ListClientsRequest() request.channel = queue diff --git a/fedn/network/combiner/modelservice.py b/fedn/network/combiner/modelservice.py index 59c63108b..0b50edbc7 100644 --- a/fedn/network/combiner/modelservice.py +++ b/fedn/network/combiner/modelservice.py @@ -21,11 +21,9 @@ def upload_request_generator(mdl, id): while True: b = mdl.read(CHUNK_SIZE) if b: - result = fedn.ModelRequest( - data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) + result = fedn.ModelRequest(data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) else: - result = fedn.ModelRequest( - id=id, data=None, status=fedn.ModelStatus.OK) + result = fedn.ModelRequest(id=id, data=None, status=fedn.ModelStatus.OK) yield result if not b: break @@ -47,14 +45,14 @@ def model_as_bytesIO(model): def get_tmp_path(): - """ Return a temporary output path compatible with save_model, load_model. """ + """Return a temporary output path compatible with save_model, load_model.""" fd, path = tempfile.mkstemp() os.close(fd) return path def load_model_from_BytesIO(model_bytesio, helper): - """ Load a model from a BytesIO object. + """Load a model from a BytesIO object. :param model_bytesio: A BytesIO object containing the model. :type model_bytesio: :class:`io.BytesIO` :param helper: The helper object for the model. @@ -63,7 +61,7 @@ def load_model_from_BytesIO(model_bytesio, helper): :rtype: return type of helper.load """ path = get_tmp_path() - with open(path, 'wb') as fh: + with open(path, "wb") as fh: fh.write(model_bytesio) fh.flush() model = helper.load(path) @@ -72,7 +70,7 @@ def load_model_from_BytesIO(model_bytesio, helper): def serialize_model_to_BytesIO(model, helper): - """ Serialize a model to a BytesIO object. + """Serialize a model to a BytesIO object. :param model: The model object. :type model: return type of helper.load @@ -85,7 +83,7 @@ def serialize_model_to_BytesIO(model, helper): a = BytesIO() a.seek(0, 0) - with open(outfile_name, 'rb') as f: + with open(outfile_name, "rb") as f: a.write(f.read()) a.seek(0) os.unlink(outfile_name) @@ -93,15 +91,13 @@ def serialize_model_to_BytesIO(model, helper): class ModelService(rpc.ModelServiceServicer): - """ Service for handling download and upload of models to the server. - - """ + """Service for handling download and upload of models to the server.""" def __init__(self): self.temp_model_storage = TempModelStorage() def exist(self, model_id): - """ Check if a model exists on the server. + """Check if a model exists on the server. :param model_id: The model id. :return: True if the model exists, else False. @@ -109,7 +105,7 @@ def exist(self, model_id): return self.temp_model_storage.exist(model_id) def get_model(self, id): - """ Download model with id 'id' from server. + """Download model with id 'id' from server. :param id: The model id. :type id: str @@ -131,7 +127,7 @@ def get_model(self, id): return None def set_model(self, model, id): - """ Upload model to server. + """Upload model to server. :param model: A model object (BytesIO) :type model: :class:`io.BytesIO` @@ -144,7 +140,7 @@ def set_model(self, model, id): # Model Service def Upload(self, request_iterator, context): - """ RPC endpoints for uploading a model. + """RPC endpoints for uploading a model. :param request_iterator: The model request iterator. 
:type request_iterator: :class:`fedn.network.grpc.fedn_pb2.ModelRequest` @@ -161,8 +157,7 @@ def Upload(self, request_iterator, context): self.temp_model_storage.set_model_metadata(request.id, fedn.ModelStatus.IN_PROGRESS) if request.status == fedn.ModelStatus.OK and not request.data: - result = fedn.ModelResponse(id=request.id, status=fedn.ModelStatus.OK, - message="Got model successfully.") + result = fedn.ModelResponse(id=request.id, status=fedn.ModelStatus.OK, message="Got model successfully.") # self.temp_model_storage_metadata.update({request.id: fedn.ModelStatus.OK}) self.temp_model_storage.set_model_metadata(request.id, fedn.ModelStatus.OK) self.temp_model_storage.get_ptr(request.id).flush() @@ -170,7 +165,7 @@ def Upload(self, request_iterator, context): return result def Download(self, request, context): - """ RPC endpoints for downloading a model. + """RPC endpoints for downloading a model. :param request: The model request object. :type request: :class:`fedn.network.grpc.fedn_pb2.ModelRequest` @@ -179,11 +174,11 @@ def Download(self, request, context): :return: A model response iterator. :rtype: :class:`fedn.network.grpc.fedn_pb2.ModelResponse` """ - logger.info(f'grpc.ModelService.Download: {request.sender.role}:{request.sender.name} requested model {request.id}') + logger.info(f"grpc.ModelService.Download: {request.sender.role}:{request.sender.name} requested model {request.id}") try: status = self.temp_model_storage.get_model_metadata(request.id) if status != fedn.ModelStatus.OK: - logger.error(f'model file is not ready: {request.id}, status: {status}') + logger.error(f"model file is not ready: {request.id}, status: {status}") yield fedn.ModelResponse(id=request.id, data=None, status=status) except Exception: logger.error("Error file does not exist: {}".format(request.id)) @@ -192,7 +187,7 @@ def Download(self, request, context): try: obj = self.temp_model_storage.get(request.id) if obj is None: - raise Exception(f'File not found: {request.id}') + raise Exception(f"File not found: {request.id}") with obj as f: while True: piece = f.read(CHUNK_SIZE) diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index 848142efc..2a8436e01 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -7,8 +7,7 @@ from fedn.common.log_config import logger from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator -from fedn.network.combiner.modelservice import (load_model_from_BytesIO, - serialize_model_to_BytesIO) +from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO from fedn.utils.helpers.helpers import get_helper from fedn.utils.parameters import Parameters @@ -18,7 +17,7 @@ class ModelUpdateError(Exception): class RoundHandler: - """ Round handler. + """Round handler. The round handler processes requests from the global controller to produce model updates and perform model validations. 
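[Illustration, not part of this patch] A sketch of the round trip through the two serialization helpers above. The "numpyhelper" helper type and the list-of-arrays model format are assumptions about the configured helper plugin; adjust to whatever helper your network uses.

import numpy as np

from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO
from fedn.utils.helpers.helpers import get_helper

helper = get_helper("numpyhelper")       # assumed helper plugin name
model = [np.zeros((3, 3)), np.ones(5)]   # toy parameter list

# Serialize to an in-memory buffer: the helper writes a temporary file
# that is read back into a BytesIO object.
buffer = serialize_model_to_BytesIO(model, helper)

# Restore it the way the round handler does, passing the raw buffer.
restored = load_model_from_BytesIO(buffer.getbuffer(), helper)

The same pair is used by the RoundHandler below: load_model_update() deserializes incoming client updates, and execute_training_round() serializes the aggregated model before writing it to the repository.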
@@ -34,7 +33,7 @@ class RoundHandler: """ def __init__(self, storage, server, modelservice): - """ Initialize the RoundHandler.""" + """Initialize the RoundHandler.""" self.round_configs = queue.Queue() self.storage = storage @@ -53,12 +52,12 @@ def push_round_config(self, round_config): :rtype: str """ try: - round_config['_job_id'] = str(uuid.uuid4()) + round_config["_job_id"] = str(uuid.uuid4()) self.round_configs.put(round_config) except Exception: logger.error("Failed to push round config.") raise - return round_config['_job_id'] + return round_config["_job_id"] def load_model_update(self, helper, model_id): """Load model update with id model_id into its memory representation. @@ -74,8 +73,7 @@ def load_model_update(self, helper, model_id): try: model = load_model_from_BytesIO(model_str.getbuffer(), helper) except IOError: - logger.warning( - "AGGREGATOR({}): Failed to load model!".format(self.name)) + logger.warning("AGGREGATOR({}): Failed to load model!".format(self.name)) else: raise ModelUpdateError("Failed to load model.") @@ -108,7 +106,7 @@ def load_model_update_str(self, model_id, retry=3): return model_str def waitforit(self, config, buffer_size=100, polling_interval=0.1): - """ Defines the policy for how long the server should wait before starting to aggregate models. + """Defines the policy for how long the server should wait before starting to aggregate models. The policy is as follows: 1. Wait a maximum of time_window time until the round times out. @@ -122,7 +120,7 @@ def waitforit(self, config, buffer_size=100, polling_interval=0.1): :type polling_interval: float, optional """ - time_window = float(config['round_timeout']) + time_window = float(config["round_timeout"]) tt = 0.0 while tt < time_window: @@ -143,22 +141,21 @@ def _training_round(self, config, clients): :rtype: model, dict """ - logger.info( - "ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients)) + logger.info("ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients)) meta = {} - meta['nr_expected_updates'] = len(clients) - meta['nr_required_updates'] = int(config['clients_required']) - meta['timeout'] = float(config['round_timeout']) + meta["nr_expected_updates"] = len(clients) + meta["nr_required_updates"] = int(config["clients_required"]) + meta["timeout"] = float(config["round_timeout"]) # Request model updates from all active clients. self.server.request_model_update(config, clients=clients) # If buffer_size is -1 (default), the round terminates when/if all clients have completed. - if int(config['buffer_size']) == -1: + if int(config["buffer_size"]) == -1: buffer_size = len(clients) else: - buffer_size = int(config['buffer_size']) + buffer_size = int(config["buffer_size"]) # Wait / block until the round termination policy has been met. 
self.waitforit(config, buffer_size=buffer_size) @@ -168,27 +165,25 @@ def _training_round(self, config, clients): data = None try: - helper = get_helper(config['helper_type']) - logger.info("Config delete_models_storage: {}".format(config['delete_models_storage'])) - if config['delete_models_storage'] == 'True': + helper = get_helper(config["helper_type"]) + logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"])) + if config["delete_models_storage"] == "True": delete_models = True else: delete_models = False if "aggregator_kwargs" in config.keys(): - dict_parameters = ast.literal_eval(config['aggregator_kwargs']) + dict_parameters = ast.literal_eval(config["aggregator_kwargs"]) parameters = Parameters(dict_parameters) else: parameters = None - model, data = self.aggregator.combine_models(helper=helper, - delete_models=delete_models, - parameters=parameters) + model, data = self.aggregator.combine_models(helper=helper, delete_models=delete_models, parameters=parameters) except Exception as e: logger.warning("AGGREGATION FAILED AT COMBINER! {}".format(e)) - meta['time_combination'] = time.time() - tic - meta['aggregation_time'] = data + meta["time_combination"] = time.time() - tic + meta["aggregation_time"] = data return model, meta def _validation_round(self, config, clients, model_id): @@ -237,7 +232,7 @@ def stage_model(self, model_id, timeout_retry=3, retry=2): self.modelservice.set_model(model, model_id) def _assign_round_clients(self, n, type="trainers"): - """ Obtain a list of clients(trainers or validators) to ask for updates in this round. + """Obtain a list of clients(trainers or validators) to ask for updates in this round. :param n: Size of a random set taken from active trainers(clients), if n > "active trainers" all is used :type n: int @@ -276,30 +271,27 @@ def _check_nr_round_clients(self, config): """ active = self.server.nr_active_trainers() - if active >= int(config['clients_required']): - logger.info("Number of clients required ({0}) to start round met {1}.".format( - config['clients_required'], active)) + if active >= int(config["clients_required"]): + logger.info("Number of clients required ({0}) to start round met {1}.".format(config["clients_required"], active)) return True else: logger.info("Too few clients to start round.") return False def execute_validation_round(self, round_config): - """ Coordinate validation rounds as specified in config. + """Coordinate validation rounds as specified in config. :param round_config: The round config object. :type round_config: dict """ - model_id = round_config['model_id'] - logger.info( - "COMBINER orchestrating validation of model {}".format(model_id)) + model_id = round_config["model_id"] + logger.info("COMBINER orchestrating validation of model {}".format(model_id)) self.stage_model(model_id) - validators = self._assign_round_clients( - self.server.max_clients, type="validators") + validators = self._assign_round_clients(self.server.max_clients, type="validators") self._validation_round(round_config, validators, model_id) def execute_training_round(self, config): - """ Coordinates clients to execute training tasks. + """Coordinates clients to execute training tasks. :param config: The round config object. 
:type config: dict @@ -307,39 +299,37 @@ def execute_training_round(self, config): :rtype: dict """ - logger.info("Processing training round, job_id {}".format(config['_job_id'])) + logger.info("Processing training round, job_id {}".format(config["_job_id"])) data = {} - data['config'] = config - data['round_id'] = config['round_id'] + data["config"] = config + data["round_id"] = config["round_id"] # Download model to update and set in temp storage. - self.stage_model(config['model_id']) + self.stage_model(config["model_id"]) clients = self._assign_round_clients(self.server.max_clients) model, meta = self._training_round(config, clients) - data['data'] = meta + data["data"] = meta if model is None: - logger.warning( - "\t Failed to update global model in round {0}!".format(config['round_id'])) + logger.warning("\t Failed to update global model in round {0}!".format(config["round_id"])) if model is not None: - helper = get_helper(config['helper_type']) + helper = get_helper(config["helper_type"]) a = serialize_model_to_BytesIO(model, helper) model_id = self.storage.set_model(a.read(), is_file=False) a.close() - data['model_id'] = model_id + data["model_id"] = model_id - logger.info( - "TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config['_job_id'])) + logger.info("TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config["_job_id"])) # Delete temp model - self.modelservice.temp_model_storage.delete(config['model_id']) + self.modelservice.temp_model_storage.delete(config["model_id"]) return data def run(self, polling_interval=1.0): - """ Main control loop. Execute rounds based on round config on the queue. + """Main control loop. Execute rounds based on round config on the queue. :param polling_interval: The polling interval in seconds for checking if a new job/config is available. :type polling_interval: float @@ -354,23 +344,22 @@ def run(self, polling_interval=1.0): round_meta = {} if ready: - if round_config['task'] == 'training': + if round_config["task"] == "training": tic = time.time() round_meta = self.execute_training_round(round_config) - round_meta['time_exec_training'] = time.time() - \ - tic - round_meta['status'] = "Success" - round_meta['name'] = self.server.id + round_meta["time_exec_training"] = time.time() - tic + round_meta["status"] = "Success" + round_meta["name"] = self.server.id self.server.statestore.set_round_combiner_data(round_meta) - elif round_config['task'] == 'validation' or round_config['task'] == 'inference': + elif round_config["task"] == "validation" or round_config["task"] == "inference": self.execute_validation_round(round_config) else: logger.warning("config contains unkown task type.") else: round_meta = {} - round_meta['status'] = "Failed" - round_meta['reason'] = "Failed to meet client allocation requirements for this round config." - logger.warning("{0}".format(round_meta['reason'])) + round_meta["status"] = "Failed" + round_meta["reason"] = "Failed to meet client allocation requirements for this round config." + logger.warning("{0}".format(round_meta["reason"])) self.round_configs.task_done() except queue.Empty: diff --git a/fedn/network/config.py b/fedn/network/config.py index a9e8773f4..0c32949b8 100644 --- a/fedn/network/config.py +++ b/fedn/network/config.py @@ -6,16 +6,14 @@ class Config(ABC): class ReducerConfig(Config): - """ Configuration for the Reducer component. 
""" + """Configuration for the Reducer component.""" + compute_bundle_dir = None models_dir = None initial_model = None - storage_backend = { - 'type': 's3', 'settings': - {'bucket': 'models'} - } + storage_backend = {"type": "s3", "settings": {"bucket": "models"}} def __init__(self): pass diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index 2b64098cf..99a59469c 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -3,8 +3,7 @@ import time import uuid -from tenacity import (retry, retry_if_exception_type, stop_after_delay, - wait_random) +from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_random from fedn.common.log_config import logger from fedn.network.combiner.interfaces import CombinerUnavailableError @@ -54,10 +53,10 @@ def __init__(self, message): class CombinersNotDoneException(Exception): - """ Exception class for when model is None """ + """Exception class for when model is None""" def __init__(self, message): - """ Constructor method. + """Constructor method. :param message: The exception message. :type message: str @@ -108,9 +107,9 @@ def session(self, config): last_round = int(self.get_latest_round_id()) for combiner in self.network.get_combiners(): - combiner.set_aggregator(config['aggregator']) + combiner.set_aggregator(config["aggregator"]) - self.set_session_status(config['session_id'], 'Started') + self.set_session_status(config["session_id"], "Started") # Execute the rounds in this session for round in range(1, int(config["rounds"] + 1)): # Increment the round number @@ -129,11 +128,11 @@ def session(self, config): config["model_id"] = self.statestore.get_latest_model() # TODO: Report completion of session - self.set_session_status(config['session_id'], 'Finished') + self.set_session_status(config["session_id"], "Finished") self._state = ReducerState.idle def round(self, session_config, round_id): - """ Execute one global round. + """Execute one global round. : param session_config: The session config. 
: type session_config: dict @@ -142,11 +141,11 @@ def round(self, session_config, round_id): """ - self.create_round({'round_id': round_id, 'status': "Pending"}) + self.create_round({"round_id": round_id, "status": "Pending"}) if len(self.network.get_combiners()) < 1: logger.warning("Round cannot start, no combiners connected!") - self.set_round_status(round_id, 'Failed') + self.set_round_status(round_id, "Failed") return None, self.statestore.get_round(round_id) # Assemble round config for this global round @@ -165,11 +164,10 @@ def round(self, session_config, round_id): round_start = self.evaluate_round_start_policy(participating_combiners) if round_start: - logger.info("round start policy met, {} participating combiners.".format( - len(participating_combiners))) + logger.info("round start policy met, {} participating combiners.".format(len(participating_combiners))) else: logger.warning("Round start policy not met, skipping round!") - self.set_round_status(round_id, 'Failed') + self.set_round_status(round_id, "Failed") return None, self.statestore.get_round(round_id) # Ask participating combiners to coordinate model updates @@ -181,18 +179,19 @@ def round(self, session_config, round_id): def do_if_round_times_out(result): logger.warning("Round timed out!") - @ retry(wait=wait_random(min=1.0, max=2.0), - stop=stop_after_delay(session_config['round_timeout']), - retry_error_callback=do_if_round_times_out, - retry=retry_if_exception_type(CombinersNotDoneException)) + @retry( + wait=wait_random(min=1.0, max=2.0), + stop=stop_after_delay(session_config["round_timeout"]), + retry_error_callback=do_if_round_times_out, + retry=retry_if_exception_type(CombinersNotDoneException), + ) def combiners_done(): - round = self.statestore.get_round(round_id) - if 'combiners' not in round: + if "combiners" not in round: logger.info("Waiting for combiners to update model...") raise CombinersNotDoneException("Combiners have not yet reported.") - if len(round['combiners']) < len(participating_combiners): + if len(round["combiners"]) < len(participating_combiners): logger.info("Waiting for combiners to update model...") raise CombinersNotDoneException("All combiners have not yet reported.") @@ -203,11 +202,10 @@ def combiners_done(): # Due to the distributed nature of the computation, there might be a # delay before combiners have reported the round data to the db, # so we need some robustness here. - @ retry(wait=wait_random(min=0.1, max=1.0), - retry=retry_if_exception_type(KeyError)) + @retry(wait=wait_random(min=0.1, max=1.0), retry=retry_if_exception_type(KeyError)) def check_combiners_done_reporting(): round = self.statestore.get_round(round_id) - combiners = round['combiners'] + combiners = round["combiners"] return combiners _ = check_combiners_done_reporting() @@ -216,7 +214,7 @@ def check_combiners_done_reporting(): round_valid = self.evaluate_round_validity_policy(round) if not round_valid: logger.error("Round failed. 
Invalid - evaluate_round_validity_policy: False") - self.set_round_status(round_id, 'Failed') + self.set_round_status(round_id, "Failed") return None, self.statestore.get_round(round_id) logger.info("Reducing combiner level models...") @@ -224,12 +222,12 @@ def check_combiners_done_reporting(): round_data = {} try: round = self.statestore.get_round(round_id) - model, data = self.reduce(round['combiners']) - round_data['reduce'] = data + model, data = self.reduce(round["combiners"]) + round_data["reduce"] = data logger.info("Done reducing models from combiners!") except Exception as e: logger.error("Failed to reduce models from combiners, reason: {}".format(e)) - self.set_round_status(round_id, 'Failed') + self.set_round_status(round_id, "Failed") return None, self.statestore.get_round(round_id) # Commit the new global model to the model trail @@ -237,20 +235,16 @@ def check_combiners_done_reporting(): logger.info("Committing global model to model trail...") tic = time.time() model_id = uuid.uuid4() - session_id = ( - session_config["session_id"] - if "session_id" in session_config - else None - ) + session_id = session_config["session_id"] if "session_id" in session_config else None self.commit(model_id, model, session_id) round_data["time_commit"] = time.time() - tic logger.info("Done committing global model to model trail.") else: logger.error("Failed to commit model to global model trail.") - self.set_round_status(round_id, 'Failed') + self.set_round_status(round_id, "Failed") return None, self.statestore.get_round(round_id) - self.set_round_status(round_id, 'Success') + self.set_round_status(round_id, "Success") # 4. Trigger participating combiner nodes to execute a validation round for the current model validate = session_config["validate"] @@ -261,8 +255,7 @@ def check_combiners_done_reporting(): combiner_config["task"] = "validation" combiner_config["helper_type"] = self.statestore.get_helper() - validating_combiners = self.get_participating_combiners( - combiner_config) + validating_combiners = self.get_participating_combiners(combiner_config) for combiner, combiner_config in validating_combiners: try: @@ -273,7 +266,7 @@ def check_combiners_done_reporting(): pass self.set_round_data(round_id, round_data) - self.set_round_status(round_id, 'Finished') + self.set_round_status(round_id, "Finished") return model_id, self.statestore.get_round(round_id) def reduce(self, combiners): @@ -292,15 +285,15 @@ def reduce(self, combiners): model = None for combiner in combiners: - name = combiner['name'] - model_id = combiner['model_id'] + name = combiner["name"] + model_id = combiner["model_id"] logger.info("Fetching model ({}) from model repository".format(model_id)) try: tic = time.time() data = self.model_repository.get_model(model_id) - meta['time_fetch_model'] += (time.time() - tic) + meta["time_fetch_model"] += time.time() - tic except Exception as e: logger.error("Failed to fetch model from model repository {}: {}".format(name, e)) data = None @@ -373,8 +366,7 @@ def inference_round(self, config): combiner_config["helper_type"] = self.statestore.get_framework() # Select combiners - validating_combiners = self.get_participating_combiners( - combiner_config) + validating_combiners = self.get_participating_combiners(combiner_config) # Test round start policy round_start = self.check_round_start_policy(validating_combiners) diff --git a/fedn/network/controller/controlbase.py b/fedn/network/controller/controlbase.py index e825e8e8b..d99bae40a 100644 --- 
a/fedn/network/controller/controlbase.py +++ b/fedn/network/controller/controlbase.py @@ -61,9 +61,7 @@ def __init__(self, statestore): raise MisconfiguredStorageBackend() if storage_config["storage_type"] == "S3": - self.model_repository = Repository( - storage_config["storage_config"] - ) + self.model_repository = Repository(storage_config["storage_config"]) else: logger.error("Unsupported storage backend, exiting.") raise UnsupportedStorageBackend() @@ -92,11 +90,7 @@ def get_helper(self): helper_type = self.statestore.get_helper() helper = fedn.utils.helpers.helpers.get_helper(helper_type) if not helper: - raise MisconfiguredHelper( - "Unsupported helper type {}, please configure compute_package.helper !".format( - helper_type - ) - ) + raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type)) return helper def get_state(self): @@ -177,8 +171,8 @@ def get_compute_package(self, compute_package=""): else: return None - def create_session(self, config, status='Initialized'): - """ Initialize a new session in backend db. """ + def create_session(self, config, status="Initialized"): + """Initialize a new session in backend db.""" if "session_id" not in config.keys(): session_id = uuid.uuid4() @@ -191,7 +185,7 @@ def create_session(self, config, status='Initialized'): self.statestore.set_session_status(session_id, status) def set_session_status(self, session_id, status): - """ Set the round round stats. + """Set the session status. :param round_id: The round unique identifier :type round_id: str @@ -201,12 +195,12 @@ def set_session_status(self, session_id, status): self.statestore.set_session_status(session_id, status) def create_round(self, round_data): - """Initialize a new round in backend db. """ + """Initialize a new round in backend db.""" self.statestore.create_round(round_data) def set_round_data(self, round_id, round_data): - """ Set round data. + """Set round data. :param round_id: The round unique identifier :type round_id: str @@ -216,7 +210,7 @@ def set_round_data(self, round_id, round_data): self.statestore.set_round_data(round_id, round_data) def set_round_status(self, round_id, status): - """ Set the round round stats. + """Set the round status. :param round_id: The round unique identifier :type round_id: str @@ -226,7 +220,7 @@ def set_round_status(self, round_id, status): self.statestore.set_round_status(round_id, status) def set_round_config(self, round_id, round_config): - """ Upate round in backend db. + """Update round in backend db.
:param round_id: The round unique identifier :type round_id: str @@ -263,9 +257,7 @@ def commit(self, model_id, model=None, session_id=None): logger.info("Saving model file temporarily to disk...") outfile_name = helper.save(model) logger.info("CONTROL: Uploading model to Minio...") - model_id = self.model_repository.set_model( - outfile_name, is_file=True - ) + model_id = self.model_repository.set_model(outfile_name, is_file=True) logger.info("CONTROL: Deleting temporary model file...") os.unlink(outfile_name) @@ -292,16 +284,12 @@ def get_participating_combiners(self, combiner_round_config): self._handle_unavailable_combiner(combiner) continue - is_participating = self.evaluate_round_participation_policy( - combiner_round_config, nr_active_clients - ) + is_participating = self.evaluate_round_participation_policy(combiner_round_config, nr_active_clients) if is_participating: combiners.append((combiner, combiner_round_config)) return combiners - def evaluate_round_participation_policy( - self, compute_plan, nr_active_clients - ): + def evaluate_round_participation_policy(self, compute_plan, nr_active_clients): """Evaluate policy for combiner round-participation. A combiner participates if it is responsive and reports enough active clients to participate in the round. @@ -325,7 +313,7 @@ def evaluate_round_start_policy(self, combiners): return False def evaluate_round_validity_policy(self, round): - """ Check if the round is valid. + """Check if the round is valid. At the end of the round, before committing a model to the global model trail, we check if the round validity policy has been met. This can involve @@ -338,9 +326,9 @@ def evaluate_round_validity_policy(self, round): :rtype: bool """ model_ids = [] - for combiner in round['combiners']: + for combiner in round["combiners"]: try: - model_ids.append(combiner['model_id']) + model_ids.append(combiner["model_id"]) except KeyError: pass @@ -350,7 +338,7 @@ def evaluate_round_validity_policy(self, round): return True def state(self): - """ Get the current state of the controller. + """Get the current state of the controller. 
:return: The state :rype: str diff --git a/fedn/network/grpc/auth.py b/fedn/network/grpc/auth.py index d879cd812..e57926cd7 100644 --- a/fedn/network/grpc/auth.py +++ b/fedn/network/grpc/auth.py @@ -6,35 +6,33 @@ from fedn.network.api.auth import check_custom_claims ENDPOINT_ROLES_MAPPING = { - '/fedn.Combiner/TaskStream': ['client'], - '/fedn.Combiner/SendModelUpdate': ['client'], - '/fedn.Combiner/SendModelValidation': ['client'], - '/fedn.Connector/SendHeartbeat': ['client'], - '/fedn.Connector/SendStatus': ['client'], - '/fedn.ModelService/Download': ['client'], - '/fedn.ModelService/Upload': ['client'], - '/fedn.Control/Start': ['controller'], - '/fedn.Control/Stop': ['controller'], - '/fedn.Control/FlushAggregationQueue': ['controller'], - '/fedn.Control/SetAggregator': ['controller'], + "/fedn.Combiner/TaskStream": ["client"], + "/fedn.Combiner/SendModelUpdate": ["client"], + "/fedn.Combiner/SendModelValidation": ["client"], + "/fedn.Connector/SendHeartbeat": ["client"], + "/fedn.Connector/SendStatus": ["client"], + "/fedn.ModelService/Download": ["client"], + "/fedn.ModelService/Upload": ["client"], + "/fedn.Control/Start": ["controller"], + "/fedn.Control/Stop": ["controller"], + "/fedn.Control/FlushAggregationQueue": ["controller"], + "/fedn.Control/SetAggregator": ["controller"], } ENDPOINT_WHITELIST = [ - '/fedn.Connector/AcceptingClients', - '/fedn.Connector/ListActiveClients', - '/fedn.Control/Start', - '/fedn.Control/Stop', - '/fedn.Control/FlushAggregationQueue', - '/fedn.Control/SetAggregator', + "/fedn.Connector/AcceptingClients", + "/fedn.Connector/ListActiveClients", + "/fedn.Control/Start", + "/fedn.Control/Stop", + "/fedn.Control/FlushAggregationQueue", + "/fedn.Control/SetAggregator", ] -USER_AGENT_WHITELIST = [ - 'grpc_health_probe' -] +USER_AGENT_WHITELIST = ["grpc_health_probe"] def check_role_claims(payload, endpoint): - user_role = payload.get('role', '') + user_role = payload.get("role", "") # Perform endpoint-specific RBAC check allowed_roles = ENDPOINT_ROLES_MAPPING.get(endpoint) @@ -63,33 +61,33 @@ def intercept_service(self, continuation, handler_call_details): if handler_call_details.method in ENDPOINT_WHITELIST: return continuation(handler_call_details) # Pass if the request comes from whitelisted user agents - user_agent = metadata.get('user-agent').split(' ')[0] + user_agent = metadata.get("user-agent").split(" ")[0] if user_agent in USER_AGENT_WHITELIST: return continuation(handler_call_details) - token = metadata.get('authorization') + token = metadata.get("authorization") if token is None: - return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token is missing') + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, "Token is missing") if not token.startswith(FEDN_AUTH_SCHEME): - return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}') + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f"Invalid token scheme, expected {FEDN_AUTH_SCHEME}") - token = token.split(' ')[1] + token = token.split(" ")[1] try: payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) if not check_role_claims(payload, handler_call_details.method): - return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') + return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, "Insufficient permissions") if not check_custom_claims(payload): - return 
_unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') + return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, "Insufficient permissions") return continuation(handler_call_details) except jwt.InvalidTokenError: - return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Invalid token') + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, "Invalid token") except jwt.ExpiredSignatureError: - return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token expired') + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, "Token expired") except Exception as e: logger.error(str(e)) return _unary_unary_rpc_terminator(grpc.StatusCode.UNKNOWN, str(e)) diff --git a/fedn/network/loadbalancer/firstavailable.py b/fedn/network/loadbalancer/firstavailable.py index 9d44d3fbd..13dc766b2 100644 --- a/fedn/network/loadbalancer/firstavailable.py +++ b/fedn/network/loadbalancer/firstavailable.py @@ -2,7 +2,7 @@ class LeastPacked(LoadBalancerBase): - """ Load balancer that selects the first available combiner. + """Load balancer that selects the first available combiner. :param network: A handle to the network. :type network: class: `fedn.network.api.network.Network` @@ -12,7 +12,7 @@ def __init__(self, network): super().__init__(network) def find_combiner(self): - """ Find the first available combiner. """ + """Find the first available combiner.""" for combiner in self.network.get_combiners(): if combiner.allowing_clients(): diff --git a/fedn/network/loadbalancer/leastpacked.py b/fedn/network/loadbalancer/leastpacked.py index cac7bba54..a762701b0 100644 --- a/fedn/network/loadbalancer/leastpacked.py +++ b/fedn/network/loadbalancer/leastpacked.py @@ -3,7 +3,7 @@ class LeastPacked(LoadBalancerBase): - """ Load balancer that selects the combiner with the least number of attached training clients. + """Load balancer that selects the combiner with the least number of attached training clients. :param network: A handle to the network. :type network: class: `fedn.network.api.network.Network` @@ -14,7 +14,7 @@ def __init__(self, network): def find_combiner(self): """ - Find the combiner with the least number of attached clients. + Find the combiner with the least number of attached clients. """ min_clients = None diff --git a/fedn/network/loadbalancer/loadbalancerbase.py b/fedn/network/loadbalancer/loadbalancerbase.py index ff1edfa9b..ddbfa2231 100644 --- a/fedn/network/loadbalancer/loadbalancerbase.py +++ b/fedn/network/loadbalancer/loadbalancerbase.py @@ -2,7 +2,7 @@ class LoadBalancerBase(ABC): - """ Abstract base class for load balancers. + """Abstract base class for load balancers. :param network: A handle to the network. :type network: class: `fedn.network.api.network.Network` @@ -14,5 +14,5 @@ def __init__(self, network): @abstractmethod def find_combiner(self): - """ Find a combiner to connect to. """ + """Find a combiner to connect to.""" pass diff --git a/fedn/network/state.py b/fedn/network/state.py index 9d18bc924..fcbd391eb 100644 --- a/fedn/network/state.py +++ b/fedn/network/state.py @@ -2,7 +2,8 @@ class ReducerState(Enum): - """ Enum for representing the state of a reducer.""" + """Enum for representing the state of a reducer.""" + setup = 1 idle = 2 instructing = 3 @@ -10,7 +11,7 @@ class ReducerState(Enum): def ReducerStateToString(state): - """ Convert ReducerState to string. + """Convert ReducerState to string. :param state: The state. 
:type state: :class:`fedn.network.state.ReducerState` @@ -30,7 +31,7 @@ def ReducerStateToString(state): def StringToReducerState(state): - """ Convert string to ReducerState. + """Convert string to ReducerState. :param state: The state as string. :type state: str diff --git a/fedn/network/storage/models/memorymodelstorage.py b/fedn/network/storage/models/memorymodelstorage.py index e6ef9b07f..6a40a7ae0 100644 --- a/fedn/network/storage/models/memorymodelstorage.py +++ b/fedn/network/storage/models/memorymodelstorage.py @@ -8,25 +8,22 @@ class MemoryModelStorage(ModelStorage): - """ Class for in-memory storage of model artifacts. + """Class for in-memory storage of model artifacts. Models are stored as BytesIO objects in a dictionary. """ def __init__(self): - self.models = defaultdict(io.BytesIO) self.models_metadata = {} def exist(self, model_id): - if model_id in self.models.keys(): return True return False def get(self, model_id): - obj = self.models[model_id] obj.seek(0, 0) # Have to copy object to not mix up the file pointers when sending... fix in better way. @@ -42,9 +39,7 @@ def get_ptr(self, model_id): return self.models[model_id] def get_model_metadata(self, model_id): - return self.models_metadata[model_id] def set_model_metadata(self, model_id, model_metadata): - self.models_metadata.update({model_id: model_metadata}) diff --git a/fedn/network/storage/models/modelstorage.py b/fedn/network/storage/models/modelstorage.py index 4945ab37b..3062db36e 100644 --- a/fedn/network/storage/models/modelstorage.py +++ b/fedn/network/storage/models/modelstorage.py @@ -2,10 +2,9 @@ class ModelStorage(ABC): - @abstractmethod def exist(self, model_id): - """ Check if model exists in storage + """Check if model exists in storage :param model_id: The model id :type model_id: str @@ -16,7 +15,7 @@ def exist(self, model_id): @abstractmethod def get(self, model_id): - """ Get model from storage + """Get model from storage :param model_id: The model id :type model_id: str @@ -27,7 +26,7 @@ def get(self, model_id): @abstractmethod def get_model_metadata(self, model_id): - """ Get model metadata from storage + """Get model metadata from storage :param model_id: The model id :type model_id: str @@ -38,7 +37,7 @@ def get_model_metadata(self, model_id): @abstractmethod def set_model_metadata(self, model_id, model_metadata): - """ Set model metadata in storage + """Set model metadata in storage :param model_id: The model id :type model_id: str @@ -51,7 +50,7 @@ def set_model_metadata(self, model_id, model_metadata): @abstractmethod def delete(self, model_id): - """ Delete model from storage + """Delete model from storage :param model_id: The model id :type model_id: str @@ -62,7 +61,7 @@ def delete(self, model_id): @abstractmethod def delete_all(self): - """ Delete all models from storage + """Delete all models from storage :return: True if successful, False otherwise :rtype: bool diff --git a/fedn/network/storage/models/tempmodelstorage.py b/fedn/network/storage/models/tempmodelstorage.py index 492cee5ab..214fac4d7 100644 --- a/fedn/network/storage/models/tempmodelstorage.py +++ b/fedn/network/storage/models/tempmodelstorage.py @@ -9,12 +9,10 @@ class TempModelStorage(ModelStorage): - """ Class for managing local temporary models on file on combiners.""" + """Class for managing local temporary models on file on combiners.""" def __init__(self): - - self.default_dir = os.environ.get( - 'FEDN_MODEL_DIR', '/tmp/models') # set default to tmp + self.default_dir = os.environ.get("FEDN_MODEL_DIR", 
"/tmp/models") # set default to tmp if not os.path.exists(self.default_dir): os.makedirs(self.default_dir) @@ -22,13 +20,11 @@ def __init__(self): self.models_metadata = {} def exist(self, model_id): - if model_id in self.models.keys(): return True return False def get(self, model_id): - try: if self.models_metadata[model_id] != fedn.ModelStatus.OK: logger.warning("File not ready! Try again") @@ -38,7 +34,7 @@ def get(self, model_id): return None obj = BytesIO() - with open(os.path.join(self.default_dir, str(model_id)), 'rb') as f: + with open(os.path.join(self.default_dir, str(model_id)), "rb") as f: obj.write(f.read()) obj.seek(0, 0) @@ -51,15 +47,14 @@ def get_ptr(self, model_id): :return: """ try: - f = self.models[model_id]['file'] + f = self.models[model_id]["file"] except KeyError: f = open(os.path.join(self.default_dir, str(model_id)), "wb") - self.models[model_id] = {'file': f} - return self.models[model_id]['file'] + self.models[model_id] = {"file": f} + return self.models[model_id]["file"] def get_model_metadata(self, model_id): - try: status = self.models_metadata[model_id] except KeyError: @@ -67,12 +62,10 @@ def get_model_metadata(self, model_id): return status def set_model_metadata(self, model_id, model_metadata): - self.models_metadata.update({model_id: model_metadata}) # Delete model from disk def delete(self, model_id): - try: os.remove(os.path.join(self.default_dir, str(model_id))) logger.info("TEMPMODELSTORAGE: Deleted model with id: {}".format(model_id)) @@ -86,7 +79,6 @@ def delete(self, model_id): # Delete all models from disk def delete_all(self): - ids_pop = [] for model_id in self.models.keys(): try: diff --git a/fedn/network/storage/s3/base.py b/fedn/network/storage/s3/base.py index 8f22a2485..671b75d7d 100644 --- a/fedn/network/storage/s3/base.py +++ b/fedn/network/storage/s3/base.py @@ -6,7 +6,7 @@ class RepositoryBase(object): @abc.abstractmethod def set_artifact(self, instance_name, instance, bucket): - """ Set object with name object_name + """Set object with name object_name :param instance_name: The name of the object :tyep insance_name: str @@ -18,7 +18,7 @@ def set_artifact(self, instance_name, instance, bucket): @abc.abstractmethod def get_artifact(self, instance_name, bucket): - """ Retrive object with name instance_name. + """Retrive object with name instance_name. :param instance_name: The name of the object to retrieve :type instance_name: str @@ -29,7 +29,7 @@ def get_artifact(self, instance_name, bucket): @abc.abstractmethod def get_artifact_stream(self, instance_name, bucket): - """ Return a stream handler for object with name instance_name. + """Return a stream handler for object with name instance_name. :param instance_name: The name if the object :type instance_name: str diff --git a/fedn/network/storage/s3/miniorepository.py b/fedn/network/storage/s3/miniorepository.py index f6082b158..9c86b8997 100644 --- a/fedn/network/storage/s3/miniorepository.py +++ b/fedn/network/storage/s3/miniorepository.py @@ -9,12 +9,12 @@ class MINIORepository(RepositoryBase): - """ Class implementing Repository for MinIO. """ + """Class implementing Repository for MinIO.""" client = None def __init__(self, config): - """ Initialize object. + """Initialize object. :param config: Dictionary containing configuration for credentials and bucket names. 
:type config: dict @@ -23,34 +23,35 @@ def __init__(self, config): super().__init__() self.name = "MINIORepository" - if config['storage_secure_mode']: - manager = PoolManager( - num_pools=100, cert_reqs='CERT_NONE', assert_hostname=False) - self.client = Minio("{0}:{1}".format(config['storage_hostname'], config['storage_port']), - access_key=config['storage_access_key'], - secret_key=config['storage_secret_key'], - secure=config['storage_secure_mode'], http_client=manager) + if config["storage_secure_mode"]: + manager = PoolManager(num_pools=100, cert_reqs="CERT_NONE", assert_hostname=False) + self.client = Minio( + "{0}:{1}".format(config["storage_hostname"], config["storage_port"]), + access_key=config["storage_access_key"], + secret_key=config["storage_secret_key"], + secure=config["storage_secure_mode"], + http_client=manager, + ) else: - self.client = Minio("{0}:{1}".format(config['storage_hostname'], config['storage_port']), - access_key=config['storage_access_key'], - secret_key=config['storage_secret_key'], - secure=config['storage_secure_mode']) + self.client = Minio( + "{0}:{1}".format(config["storage_hostname"], config["storage_port"]), + access_key=config["storage_access_key"], + secret_key=config["storage_secret_key"], + secure=config["storage_secure_mode"], + ) def set_artifact(self, instance_name, instance, bucket, is_file=False): - if is_file: self.client.fput_object(bucket, instance_name, instance) else: try: - self.client.put_object( - bucket, instance_name, io.BytesIO(instance), len(instance)) + self.client.put_object(bucket, instance_name, io.BytesIO(instance), len(instance)) except Exception as e: raise Exception("Could not load data into bytes {}".format(e)) return True def get_artifact(self, instance_name, bucket): - try: data = self.client.get_object(bucket, instance_name) return data.read() @@ -61,7 +62,6 @@ def get_artifact(self, instance_name, bucket): data.release_conn() def get_artifact_stream(self, instance_name, bucket): - try: data = self.client.get_object(bucket, instance_name) return data @@ -69,7 +69,7 @@ def get_artifact_stream(self, instance_name, bucket): raise Exception("Could not fetch data from bucket, {}".format(e)) def list_artifacts(self, bucket): - """ List all objects in bucket. + """List all objects in bucket. :param bucket: Name of the bucket :type bucket: str @@ -81,12 +81,11 @@ def list_artifacts(self, bucket): for obj in objs: objects.append(obj.object_name) except Exception: - raise Exception( - "Could not list models in bucket {}".format(bucket)) + raise Exception("Could not list models in bucket {}".format(bucket)) return objects def delete_artifact(self, instance_name, bucket): - """ Delete object with name instance_name from buckets. + """Delete object with name instance_name from buckets. :param instance_name: The object name :param bucket: Buckets to delete from @@ -96,11 +95,11 @@ def delete_artifact(self, instance_name, bucket): try: self.client.remove_object(bucket, instance_name) except InvalidResponseError as err: - logger.error('Could not delete artifact: {0} err: {1}'.format(instance_name, err)) + logger.error("Could not delete artifact: {0} err: {1}".format(instance_name, err)) pass def create_bucket(self, bucket_name): - """ Create a new bucket. If bucket exists, do nothing. + """Create a new bucket. If bucket exists, do nothing. 
:param bucket_name: The name of the bucket :type bucket_name: str diff --git a/fedn/network/storage/s3/repository.py b/fedn/network/storage/s3/repository.py index d7d455341..18d36cdbb 100644 --- a/fedn/network/storage/s3/repository.py +++ b/fedn/network/storage/s3/repository.py @@ -5,12 +5,11 @@ class Repository: - """ Interface for storing model objects and compute packages in S3 compatible storage. """ + """Interface for storing model objects and compute packages in S3 compatible storage.""" def __init__(self, config): - - self.model_bucket = config['storage_bucket'] - self.context_bucket = config['context_bucket'] + self.model_bucket = config["storage_bucket"] + self.context_bucket = config["context_bucket"] # TODO: Make a plug-in solution self.client = MINIORepository(config) @@ -19,27 +18,25 @@ def __init__(self, config): self.client.create_bucket(self.model_bucket) def get_model(self, model_id): - """ Retrieve a model with id model_id. + """Retrieve a model with id model_id. :param model_id: Unique identifier for model to retrive. :return: The model object """ - logger.info("Client {} trying to get model with id: {}".format( - self.client.name, model_id)) + logger.info("Client {} trying to get model with id: {}".format(self.client.name, model_id)) return self.client.get_artifact(model_id, self.model_bucket) def get_model_stream(self, model_id): - """ Retrieve a stream handle to model with id model_id. + """Retrieve a stream handle to model with id model_id. :param model_id: :return: Handle to model object """ - logger.info("Client {} trying to get model with id: {}".format( - self.client.name, model_id)) + logger.info("Client {} trying to get model with id: {}".format(self.client.name, model_id)) return self.client.get_artifact_stream(model_id, self.model_bucket) def set_model(self, model, is_file=True): - """ Upload model object. + """Upload model object. :param model: The model object :type model: BytesIO or str file name. @@ -49,15 +46,14 @@ def set_model(self, model, is_file=True): model_id = uuid.uuid4() try: - self.client.set_artifact(str(model_id), model, - bucket=self.model_bucket, is_file=is_file) + self.client.set_artifact(str(model_id), model, bucket=self.model_bucket, is_file=is_file) except Exception: logger.error("Failed to upload model with ID {} to repository.".format(model_id)) raise return str(model_id) def delete_model(self, model_id): - """ Delete model. + """Delete model. :param model_id: The id of the model to delete :type model_id: str @@ -69,7 +65,7 @@ def delete_model(self, model_id): raise def set_compute_package(self, name, compute_package, is_file=True): - """ Upload compute package. + """Upload compute package. :param name: The name of the compute package. :type name: str @@ -79,14 +75,13 @@ def set_compute_package(self, name, compute_package, is_file=True): """ try: - self.client.set_artifact(str(name), compute_package, - bucket=self.context_bucket, is_file=is_file) + self.client.set_artifact(str(name), compute_package, bucket=self.context_bucket, is_file=is_file) except Exception: logger.error("Failed to write compute_package to repository.") raise def get_compute_package(self, compute_package): - """ Retrieve compute package from object store. + """Retrieve compute package from object store. :param compute_package: The name of the compute package. :type compute_pacakge: str @@ -100,7 +95,7 @@ def get_compute_package(self, compute_package): return data def delete_compute_package(self, compute_package): - """ Delete a compute package from storage. 
+ """Delete a compute package from storage. :param compute_package: The name of the compute_package :type compute_package: str diff --git a/fedn/network/storage/statestore/mongostatestore.py b/fedn/network/storage/statestore/mongostatestore.py index c4fdf3fe7..a53a6e4d5 100644 --- a/fedn/network/storage/statestore/mongostatestore.py +++ b/fedn/network/storage/statestore/mongostatestore.py @@ -61,7 +61,7 @@ def __init__(self, network_id, config): self.init_index() def connect(self): - """ Establish client connection to MongoDB. + """Establish client connection to MongoDB. :param config: Dictionary containing connection strings and security credentials. :type config: dict @@ -125,11 +125,7 @@ def transition(self, state): True, ) else: - logger.info( - "Not updating state, already in {}".format( - ReducerStateToString(state) - ) - ) + logger.info("Not updating state, already in {}".format(ReducerStateToString(state))) def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo.DESCENDING): """Get all sessions. @@ -151,13 +147,9 @@ def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo limit = int(limit) skip = int(skip) - result = self.sessions.find().limit(limit).skip(skip).sort( - sort_key, sort_order - ) + result = self.sessions.find().limit(limit).skip(skip).sort(sort_key, sort_order) else: - result = self.sessions.find().sort( - sort_key, sort_order - ) + result = self.sessions.find().sort(sort_key, sort_order) count = self.sessions.count_documents({}) @@ -204,9 +196,7 @@ def set_latest_model(self, model_id, session_id=None): } ) - self.model.update_one( - {"key": "current_model"}, {"$set": {"model": model_id}}, True - ) + self.model.update_one({"key": "current_model"}, {"$set": {"model": model_id}}, True) self.model.update_one( {"key": "model_trail"}, { @@ -225,9 +215,7 @@ def get_initial_model(self): :rtype: str """ - result = self.model.find_one( - {"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)] - ) + result = self.model.find_one({"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)]) if result is None: return None @@ -266,16 +254,12 @@ def set_current_model(self, model_id: str): """ try: - committed_at = datetime.now() existing_model = self.model.find_one({"key": "models", "model": model_id}) if existing_model is not None: - - self.model.update_one( - {"key": "current_model"}, {"$set": {"model": model_id, "committed_at": committed_at, "session_id": None}}, True - ) + self.model.update_one({"key": "current_model"}, {"$set": {"model": model_id, "committed_at": committed_at, "session_id": None}}, True) return True except Exception as e: @@ -334,7 +318,6 @@ def set_active_compute_package(self, id: str): """ try: - find = {"id": id} projection = {"_id": False, "key": False} @@ -345,9 +328,7 @@ def set_active_compute_package(self, id: str): doc["key"] = "active" - self.control.package.replace_one( - {"key": "active"}, doc - ) + self.control.package.replace_one({"key": "active"}, doc) except Exception as e: logger.error("ERROR: {}".format(e)) @@ -376,9 +357,7 @@ def set_compute_package(self, file_name: str, storage_file_name: str, helper_typ self.control.package.update_one( {"key": "active"}, - { - "$set": obj - }, + {"$set": obj}, True, ) @@ -395,7 +374,6 @@ def get_compute_package(self): :rtype: ObjectID """ try: - find = {"key": "active"} projection = {"key": False, "_id": False} ret = self.control.package.find_one(find, projection) @@ -449,9 +427,7 @@ def set_helper(self, helper): :type helper: str :return: 
""" - self.control.package.update_one( - {"key": "active"}, {"$set": {"helper": helper}}, True - ) + self.control.package.update_one({"key": "active"}, {"$set": {"helper": helper}}, True) def get_helper(self): """Get the active helper package. @@ -466,9 +442,7 @@ def get_helper(self): # ret = self.control.config.find_one({'key': 'round_config'}) try: retcheck = ret["helper"] - if ( - retcheck == "" or retcheck == " " - ): # ugly check for empty string + if retcheck == "" or retcheck == " ": # ugly check for empty string return None return retcheck except (KeyError, IndexError): @@ -495,11 +469,7 @@ def list_models( """ result = None - find_option = ( - {"key": "models"} - if session_id is None - else {"key": "models", "session_id": session_id} - ) + find_option = {"key": "models"} if session_id is None else {"key": "models", "session_id": session_id} projection = {"_id": False, "key": False} @@ -507,17 +477,10 @@ def list_models( limit = int(limit) skip = int(skip) - result = ( - self.model.find(find_option, projection) - .limit(limit) - .skip(skip) - .sort(sort_key, sort_order) - ) + result = self.model.find(find_option, projection).limit(limit).skip(skip).sort(sort_key, sort_order) else: - result = self.model.find(find_option, projection).sort( - sort_key, sort_order - ) + result = self.model.find(find_option, projection).sort(sort_key, sort_order) count = self.model.count_documents(find_option) @@ -625,9 +588,7 @@ def get_events(self, **kwargs): projection = {"_id": False} if not kwargs: - result = self.control.status.find({}, projection).sort( - "timestamp", pymongo.DESCENDING - ) + result = self.control.status.find({}, projection).sort("timestamp", pymongo.DESCENDING) count = self.control.status.count_documents({}) else: limit = kwargs.pop("limit", None) @@ -636,16 +597,9 @@ def get_events(self, **kwargs): if limit is not None and skip is not None: limit = int(limit) skip = int(skip) - result = ( - self.control.status.find(kwargs, projection) - .sort("timestamp", pymongo.DESCENDING) - .limit(limit) - .skip(skip) - ) + result = self.control.status.find(kwargs, projection).sort("timestamp", pymongo.DESCENDING).limit(limit).skip(skip) else: - result = self.control.status.find(kwargs, projection).sort( - "timestamp", pymongo.DESCENDING - ) + result = self.control.status.find(kwargs, projection).sort("timestamp", pymongo.DESCENDING) count = self.control.status.count_documents(kwargs) @@ -661,9 +615,7 @@ def get_storage_backend(self): :rtype: ObjectID """ try: - ret = self.storage.find( - {"status": "enabled"}, projection={"_id": False} - ) + ret = self.storage.find({"status": "enabled"}, projection={"_id": False}) return ret[0] except (KeyError, IndexError): return None @@ -678,9 +630,7 @@ def set_storage_backend(self, config): config = copy.deepcopy(config) config["updated_at"] = str(datetime.now()) config["status"] = "enabled" - self.storage.update_one( - {"storage_type": config["storage_type"]}, {"$set": config}, True - ) + self.storage.update_one({"storage_type": config["storage_type"]}, {"$set": config}, True) def set_reducer(self, reducer_data): """Set the reducer in the statestore. @@ -690,9 +640,7 @@ def set_reducer(self, reducer_data): :return: """ reducer_data["updated_at"] = str(datetime.now()) - self.reducer.update_one( - {"name": reducer_data["name"]}, {"$set": reducer_data}, True - ) + self.reducer.update_one({"name": reducer_data["name"]}, {"$set": reducer_data}, True) def get_reducer(self): """Get reducer.config. 
@@ -767,9 +715,7 @@ def set_combiner(self, combiner_data): """ combiner_data["updated_at"] = str(datetime.now()) - self.combiners.update_one( - {"name": combiner_data["name"]}, {"$set": combiner_data}, True - ) + self.combiners.update_one({"name": combiner_data["name"]}, {"$set": combiner_data}, True) def delete_combiner(self, combiner): """Delete a combiner from statestore. @@ -793,9 +739,7 @@ def set_client(self, client_data): :return: """ client_data["updated_at"] = str(datetime.now()) - self.clients.update_one( - {"name": client_data["name"]}, {"$set": client_data}, True - ) + self.clients.update_one({"name": client_data["name"]}, {"$set": client_data}, True) def get_client(self, name): """Get client by name. @@ -866,15 +810,15 @@ def list_combiners_data(self, combiners, sort_key="count", sort_order=pymongo.DE result = None try: - - pipeline = [ - {"$match": {"combiner": {"$in": combiners}, "status": "online"}}, - {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, - {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} - ] if combiners is not None else [ - {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, - {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} - ] + pipeline = ( + [ + {"$match": {"combiner": {"$in": combiners}, "status": "online"}}, + {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, + {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}}, + ] + if combiners is not None + else [{"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}}] + ) result = self.clients.aggregate(pipeline) @@ -911,7 +855,7 @@ def drop_status(self): self.status.drop() def create_session(self, id=None): - """ Create a new session object. + """Create a new session object. :param id: The ID of the created session. :type id: uuid, str @@ -919,11 +863,11 @@ def create_session(self, id=None): """ if not id: id = uuid.uuid4() - data = {'session_id': str(id)} + data = {"session_id": str(id)} self.sessions.insert_one(data) def create_round(self, round_data): - """ Create a new round. + """Create a new round. :param round_data: Dictionary with round data. :type round_data: dict @@ -939,8 +883,7 @@ def set_session_config(self, id, config): :param config: Session configuration :type config: dict """ - self.sessions.update_one({'session_id': str(id)}, { - '$push': {'session_config': config}}, True) + self.sessions.update_one({"session_id": str(id)}, {"$push": {"session_config": config}}, True) def set_session_status(self, id, status): """Set session status. @@ -949,8 +892,7 @@ def set_session_status(self, id, status): :type round_id: str :param round_status: The status of the session. """ - self.sessions.update_one({'session_id': str(id)}, { - '$set': {'status': status}}, True) + self.sessions.update_one({"session_id": str(id)}, {"$set": {"status": status}}, True) def set_round_combiner_data(self, data): """Set combiner round controller data. @@ -958,8 +900,7 @@ def set_round_combiner_data(self, data): :param data: The combiner data :type data: dict """ - self.rounds.update_one({'round_id': str(data['round_id'])}, { - '$push': {'combiners': data}}, True) + self.rounds.update_one({"round_id": str(data["round_id"])}, {"$push": {"combiners": data}}, True) def set_round_config(self, round_id, round_config): """Set round configuration. 
@@ -969,8 +910,7 @@ def set_round_config(self, round_id, round_config): :param round_config: The round configuration :type round_config: dict """ - self.rounds.update_one({'round_id': round_id}, { - '$set': {'round_config': round_config}}, True) + self.rounds.update_one({"round_id": round_id}, {"$set": {"round_config": round_config}}, True) def set_round_status(self, round_id, round_status): """Set round status. @@ -979,8 +919,7 @@ def set_round_status(self, round_id, round_status): :type round_id: str :param round_status: The status of the round. """ - self.rounds.update_one({'round_id': round_id}, { - '$set': {'status': round_status}}, True) + self.rounds.update_one({"round_id": round_id}, {"$set": {"status": round_status}}, True) def set_round_data(self, round_id, round_data): """Update round metadata @@ -990,11 +929,10 @@ def set_round_data(self, round_id, round_data): :param round_data: The round metadata :type round_data: dict """ - self.rounds.update_one({'round_id': round_id}, { - '$set': {'round_data': round_data}}, True) + self.rounds.update_one({"round_id": round_id}, {"$set": {"round_data": round_data}}, True) def update_client_status(self, clients, status): - """ Update client status in statestore. + """Update client status in statestore. :param client_name: The client name :type client_name: str :param status: The client status diff --git a/fedn/network/storage/statestore/statestorebase.py b/fedn/network/storage/statestore/statestorebase.py index f41e3c025..7c6681682 100644 --- a/fedn/network/storage/statestore/statestorebase.py +++ b/fedn/network/storage/statestore/statestorebase.py @@ -2,22 +2,19 @@ class StateStoreBase(ABC): - """ - - """ + """ """ def __init__(self): pass @abstractmethod def state(self): - """ Return the current state of the statestore. - """ + """Return the current state of the statestore.""" pass @abstractmethod def transition(self, state): - """ Transition the statestore to a new state. + """Transition the statestore to a new state. :param state: The new state. :type state: str @@ -26,7 +23,7 @@ def transition(self, state): @abstractmethod def set_latest_model(self, model_id): - """ Set the latest model id in the statestore. + """Set the latest model id in the statestore. :param model_id: The model id. :type model_id: str @@ -35,7 +32,7 @@ def set_latest_model(self, model_id): @abstractmethod def get_latest_model(self): - """ Get the latest model id from the statestore. + """Get the latest model id from the statestore. :return: The model object. :rtype: ObjectId @@ -44,7 +41,7 @@ def get_latest_model(self): @abstractmethod def is_inited(self): - """ Check if the statestore is initialized. + """Check if the statestore is initialized. :return: True if initialized, else False. :rtype: bool diff --git a/fedn/utils/checksum.py b/fedn/utils/checksum.py index 3c7bbd3ec..8ca678597 100644 --- a/fedn/utils/checksum.py +++ b/fedn/utils/checksum.py @@ -2,7 +2,7 @@ def sha(fname): - """ Calculate the sha256 checksum of a file. Used for computing checksums of compute packages. + """Calculate the sha256 checksum of a file. Used for computing checksums of compute packages. :param fname: The file path. :type fname: str diff --git a/fedn/utils/dispatcher.py b/fedn/utils/dispatcher.py index 77f357249..5d00021b1 100644 --- a/fedn/utils/dispatcher.py +++ b/fedn/utils/dispatcher.py @@ -72,10 +72,7 @@ def _validate_virtualenv_is_available(): on how to install virtualenv. """ if not _is_virtualenv_available(): - raise Exception( - "Could not find the virtualenv binary. 
Run `pip install virtualenv` to install " - "virtualenv." - ) + raise Exception("Could not find the virtualenv binary. Run `pip install virtualenv` to install " "virtualenv.") def _get_virtualenv_extra_env_vars(env_root_dir=None): @@ -95,9 +92,7 @@ def _get_python_env(python_env_file): return _PythonEnv.from_yaml(python_env_file) -def _create_virtualenv( - python_bin_path, env_dir, python_env, extra_env=None, capture_output=False -): +def _create_virtualenv(python_bin_path, env_dir, python_env, extra_env=None, capture_output=False): # Created a command to activate the environment paths = ("bin", "activate") if _IS_UNIX else ("Scripts", "activate.bat") activate_cmd = env_dir.joinpath(*paths) @@ -110,8 +105,7 @@ def _create_virtualenv( with remove_on_error( env_dir, onerror=lambda e: logger.warning( - "Encountered an unexpected error: %s while creating a virtualenv environment in %s, " - "removing the environment directory...", + "Encountered an unexpected error: %s while creating a virtualenv environment in %s, " "removing the environment directory...", repr(e), env_dir, ), @@ -127,9 +121,7 @@ def _create_virtualenv( with tempfile.TemporaryDirectory() as tmpdir: tmp_req_file = f"requirements.{uuid.uuid4().hex}.txt" Path(tmpdir).joinpath(tmp_req_file).write_text("\n".join(deps)) - cmd = _join_commands( - activate_cmd, f"python -m pip install -r {tmp_req_file}" - ) + cmd = _join_commands(activate_cmd, f"python -m pip install -r {tmp_req_file}") _exec_cmd(cmd, capture_output=capture_output, cwd=tmpdir, extra_env=extra_env) return activate_cmd @@ -138,20 +130,17 @@ def _create_virtualenv( def _read_yaml_file(file_path): try: cfg = None - with open(file_path, 'rb') as config_file: - + with open(file_path, "rb") as config_file: cfg = yaml.safe_load(config_file.read()) except Exception as e: - logger.error( - f"Error trying to read yaml file: {file_path}" - ) + logger.error(f"Error trying to read yaml file: {file_path}") raise e return cfg class Dispatcher: - """ Dispatcher class for compute packages. + """Dispatcher class for compute packages. :param config: The configuration. :type config: dict @@ -160,7 +149,7 @@ class Dispatcher: """ def __init__(self, config, project_dir): - """ Initialize the dispatcher.""" + """Initialize the dispatcher.""" self.config = config self.project_dir = project_dir self.activate_cmd = "" @@ -174,10 +163,7 @@ def _get_or_create_python_env(self, capture_output=False, pip_requirements_overr else: python_env_path = os.path.join(self.project_dir, python_env) if not os.path.exists(python_env_path): - raise Exception( - "Compute package specified python_env file %s, but no such " - "file was found." % python_env_path - ) + raise Exception("Compute package specified python_env file %s, but no such " "file was found." 
% python_env_path) python_env = _get_python_env(python_env_path) extra_env = _get_virtualenv_extra_env_vars() @@ -201,10 +187,7 @@ def _get_or_create_python_env(self, capture_output=False, pip_requirements_overr ) # Install additional dependencies specified by `requirements_override` if pip_requirements_override: - logger.info( - "Installing additional dependencies specified by " - f"pip_requirements_override: {pip_requirements_override}" - ) + logger.info("Installing additional dependencies specified by " f"pip_requirements_override: {pip_requirements_override}") cmd = _join_commands( activate_cmd, f"python -m pip install --quiet -U {' '.join(pip_requirements_override)}", @@ -222,29 +205,23 @@ def _get_or_create_python_env(self, capture_output=False, pip_requirements_overr raise - def run_cmd(self, - cmd_type, - capture_output=False, - extra_env=None, - synchronous=True, - stream_output=False - ): - """ Run a command. + def run_cmd(self, cmd_type, capture_output=False, extra_env=None, synchronous=True, stream_output=False): + """Run a command. :param cmd_type: The command type. :type cmd_type: str :return: """ try: - cmdsandargs = cmd_type.split(' ') + cmdsandargs = cmd_type.split(" ") - entry_point = self.config['entry_points'][cmdsandargs[0]]['command'] + entry_point = self.config["entry_points"][cmdsandargs[0]]["command"] # remove the first element, that is not a file but a command args = cmdsandargs[1:] # Join entry point and arguments into a single command as a string - entry_point_args = ' '.join(args) + entry_point_args = " ".join(args) entry_point = f"{entry_point} {entry_point_args}" if self.activate_cmd: @@ -252,7 +229,7 @@ def run_cmd(self, else: cmd = _join_commands(entry_point) - logger.info('Running command: {}'.format(cmd)) + logger.info("Running command: {}".format(cmd)) _exec_cmd( cmd, cwd=self.project_dir, @@ -263,7 +240,7 @@ def run_cmd(self, stream_output=stream_output, ) - logger.info('Done executing {}'.format(cmd_type)) + logger.info("Done executing {}".format(cmd_type)) except IndexError: message = "No such argument or configuration to run." 
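            # Illustrative note (an assumption, not part of this patch): the
            # entry-point lookup above expects the compute package config,
            # typically loaded from the package's fedn.yaml, to be shaped
            # roughly like the hypothetical example below:
            #
            #   entry_points:
            #     train:
            #       command: python train.py
            #     validate:
            #       command: python validate.py
            #
            # A cmd_type whose first token cannot be resolved against that
            # mapping ends up in this error path.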
logger.error(message) diff --git a/fedn/utils/flowercompat/client_app_adapter.py b/fedn/utils/flowercompat/client_app_adapter.py index 610e5ca9f..15de3ff6b 100644 --- a/fedn/utils/flowercompat/client_app_adapter.py +++ b/fedn/utils/flowercompat/client_app_adapter.py @@ -1,16 +1,27 @@ from typing import Tuple from flwr.client import ClientApp -from flwr.common import (Context, EvaluateIns, FitIns, GetParametersIns, - Message, MessageType, MessageTypeLegacy, Metadata, - NDArrays, ndarrays_to_parameters, - parameters_to_ndarrays) -from flwr.common.recordset_compat import (evaluateins_to_recordset, - fitins_to_recordset, - getparametersins_to_recordset, - recordset_to_evaluateres, - recordset_to_fitres, - recordset_to_getparametersres) +from flwr.common import ( + Context, + EvaluateIns, + FitIns, + GetParametersIns, + Message, + MessageType, + MessageTypeLegacy, + Metadata, + NDArrays, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.recordset_compat import ( + evaluateins_to_recordset, + fitins_to_recordset, + getparametersins_to_recordset, + recordset_to_evaluateres, + recordset_to_fitres, + recordset_to_getparametersres, +) class FlwrClientAppAdapter: @@ -21,23 +32,21 @@ def __init__(self, app: ClientApp) -> None: def init_parameters(self, partition_id: int, config: dict = {}): # Construct a get_parameters message for the ClientApp - message, context = self._construct_message( - MessageTypeLegacy.GET_PARAMETERS, [], partition_id, config - ) + message, context = self._construct_message(MessageTypeLegacy.GET_PARAMETERS, [], partition_id, config) # Call client app with train message client_return_message = self.app(message, context) # return NDArrays of clients parameters parameters = self._parse_get_parameters_message(client_return_message) if len(parameters) == 0: - raise ValueError("The 'parameters' list is empty. Ensure your flower \ - client has implemented a get_parameters() function.") + raise ValueError( + "The 'parameters' list is empty. Ensure your flower \ + client has implemented a get_parameters() function." + ) return parameters def train(self, parameters: NDArrays, partition_id: int, config: dict = {}): # Construct a train message for the ClientApp with given parameters - message, context = self._construct_message( - MessageType.TRAIN, parameters, partition_id, config - ) + message, context = self._construct_message(MessageType.TRAIN, parameters, partition_id, config) # Call client app with train message client_return_message = self.app(message, context) # Parse return message @@ -46,9 +55,7 @@ def train(self, parameters: NDArrays, partition_id: int, config: dict = {}): def evaluate(self, parameters: NDArrays, partition_id: int, config: dict = {}): # Construct an evaluate message for the ClientApp with given parameters - message, context = self._construct_message( - MessageType.EVALUATE, parameters, partition_id, config - ) + message, context = self._construct_message(MessageType.EVALUATE, parameters, partition_id, config) # Call client app with evaluate message client_return_message = self.app(message, context) # Parse return message diff --git a/fedn/utils/helpers/helperbase.py b/fedn/utils/helpers/helperbase.py index a59ee49d8..3377d0336 100644 --- a/fedn/utils/helpers/helperbase.py +++ b/fedn/utils/helpers/helperbase.py @@ -4,16 +4,16 @@ class HelperBase(ABC): - """ Abstract class defining helpers. """ + """Abstract class defining helpers.""" def __init__(self): - """ Initialize helper. 
""" + """Initialize helper.""" self.name = self.__class__.__name__ @abstractmethod def increment_average(self, m1, m2, a, W): - """ Compute one increment of incremental weighted averaging. + """Compute one increment of incremental weighted averaging. :param m1: Current model weights in array-like format. :param m2: New model weights in array-like format. @@ -25,7 +25,7 @@ def increment_average(self, m1, m2, a, W): @abstractmethod def save(self, model, path): - """ Serialize weights to file. The serialized model must be a single binary object. + """Serialize weights to file. The serialized model must be a single binary object. :param model: Weights in array-like format. :param path: Path to file. @@ -35,7 +35,7 @@ def save(self, model, path): @abstractmethod def load(self, fh): - """ Load weights from file or filelike. + """Load weights from file or filelike. :param fh: file path, filehandle, filelike. :return: Weights in array-like format. @@ -43,10 +43,10 @@ def load(self, fh): pass def get_tmp_path(self): - """ Return a temporary output path compatible with save_model, load_model. + """Return a temporary output path compatible with save_model, load_model. :return: Path to file. """ - fd, path = tempfile.mkstemp(suffix='.npz') + fd, path = tempfile.mkstemp(suffix=".npz") os.close(fd) return path diff --git a/fedn/utils/helpers/helpers.py b/fedn/utils/helpers/helpers.py index 841af58a0..7fbf83a81 100644 --- a/fedn/utils/helpers/helpers.py +++ b/fedn/utils/helpers/helpers.py @@ -5,7 +5,7 @@ def get_helper(helper_module_name): - """ Return an instance of the helper class. + """Return an instance of the helper class. :param helper_module_name: The name of the helper plugin module. :type helper_module_name: str @@ -18,24 +18,24 @@ def get_helper(helper_module_name): def save_metadata(metadata, filename): - """ Save metadata to file. + """Save metadata to file. :param metadata: The metadata to save. :type metadata: dict :param filename: The name of the file to save to. :type filename: str """ - with open(filename+'-metadata', 'w') as outfile: + with open(filename + "-metadata", "w") as outfile: json.dump(metadata, outfile) def save_metrics(metrics, filename): - """ Save metrics to file. + """Save metrics to file. :param metrics: The metrics to save. :type metrics: dict :param filename: The name of the file to save to. :type filename: str """ - with open(filename, 'w') as outfile: + with open(filename, "w") as outfile: json.dump(metrics, outfile) diff --git a/fedn/utils/helpers/plugins/androidhelper.py b/fedn/utils/helpers/plugins/androidhelper.py index 6a9fc7f9d..119801b8c 100644 --- a/fedn/utils/helpers/plugins/androidhelper.py +++ b/fedn/utils/helpers/plugins/androidhelper.py @@ -18,9 +18,7 @@ def __init__(self): # function to calculate an incremental weighted average of the weights - def increment_average( - self, model, model_next, num_examples, total_examples - ): + def increment_average(self, model, model_next, num_examples, total_examples): """Incremental weighted average of model weights. :param model: Current model weights. @@ -40,9 +38,7 @@ def increment_average( return (1 - w) * model + w * model_next # function to calculate an incremental weighted average of the weights using numpy.add - def increment_average_add( - self, model, model_next, num_examples, total_examples - ): + def increment_average_add(self, model, model_next, num_examples, total_examples): """Incremental weighted average of model weights. :param model: Current model weights. 
@@ -59,9 +55,7 @@ def increment_average_add(
         # Incremental weighted average
         w = np.add(
             model,
-            num_examples
-            * (np.array(model_next) - np.array(model))
-            / total_examples,
+            num_examples * (np.array(model_next) - np.array(model)) / total_examples,
         )
 
         return w
@@ -75,7 +69,7 @@ def save(self, weights, path=None):
         if not path:
             path = self.get_tmp_path()
 
-        byte_array = struct.pack("f"*len(weights), *weights)
+        byte_array = struct.pack("f" * len(weights), *weights)
         with open(path, "wb") as file:
             file.write(byte_array)
 
@@ -94,7 +88,7 @@ def load(self, fh):
         else:
             byte_data = fh.read()
 
-        weights = np.array(struct.unpack(f'{len(byte_data) // 4}f', byte_data))
+        weights = np.array(struct.unpack(f"{len(byte_data) // 4}f", byte_data))
 
         return weights
 
diff --git a/fedn/utils/helpers/plugins/numpyhelper.py b/fedn/utils/helpers/plugins/numpyhelper.py
index 9751e3903..ce6c29420 100644
--- a/fedn/utils/helpers/plugins/numpyhelper.py
+++ b/fedn/utils/helpers/plugins/numpyhelper.py
@@ -1,19 +1,18 @@
-
 import numpy as np
 
 from fedn.utils.helpers.helperbase import HelperBase
 
 
 class Helper(HelperBase):
-    """ FEDn helper class for models weights/parameters that can be transformed to numpy ndarrays. """
+    """FEDn helper class for models weights/parameters that can be transformed to numpy ndarrays."""
 
     def __init__(self):
-        """ Initialize helper. """
+        """Initialize helper."""
        super().__init__()
         self.name = "numpyhelper"
 
     def increment_average(self, m1, m2, n, N):
-        """ Update a weighted incremental average of model weights.
+        """Update a weighted incremental average of model weights.
 
         :param m1: Current parameters.
         :type model: list of numpy ndarray
@@ -27,10 +26,10 @@ def increment_average(self, m1, m2, n, N):
         :rtype: list of numpy ndarray
         """
 
-        return [np.add(x, n*(y-x)/N) for x, y in zip(m1, m2)]
+        return [np.add(x, n * (y - x) / N) for x, y in zip(m1, m2)]
 
     def add(self, m1, m2, a=1.0, b=1.0):
-        """ m1*a + m2*b
+        """m1*a + m2*b
 
         :param model: Current model weights.
         :type model: list of ndarrays
@@ -40,10 +39,10 @@ def add(self, m1, m2, a=1.0, b=1.0):
         :rtype: list of ndarrays
         """
 
-        return [x*a+y*b for x, y in zip(m1, m2)]
+        return [x * a + y * b for x, y in zip(m1, m2)]
 
     def subtract(self, m1, m2, a=1.0, b=1.0):
-        """ m1*a - m2*b.
+        """m1*a - m2*b.
 
         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -55,7 +54,7 @@ def subtract(self, m1, m2, a=1.0, b=1.0):
         return self.add(m1, m2, a, -b)
 
     def divide(self, m1, m2):
-        """ Subtract weights.
+        """Divide m1 by m2, element-wise.
 
         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -68,7 +67,7 @@ def divide(self, m1, m2):
         return [np.divide(x, y) for x, y in zip(m1, m2)]
 
     def multiply(self, m1, m2):
-        """ Multiply m1 by m2.
+        """Multiply m1 by m2.
 
         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -81,7 +80,7 @@ def multiply(self, m1, m2):
         return [np.multiply(x, y) for (x, y) in zip(m1, m2)]
 
     def sqrt(self, m1):
-        """ Sqrt of m1, element-wise.
+        """Sqrt of m1, element-wise.
 
         :param m1: Current model weights.
         :type model: list of ndarrays
@@ -94,7 +93,7 @@ def sqrt(self, m1):
         return [np.sqrt(x) for x in m1]
 
     def power(self, m1, a):
-        """ m1 raised to the power of m2.
+        """m1 raised to the power of a, element-wise.
 
         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -107,7 +106,7 @@ def power(self, m1, a):
         return [np.power(x, a) for x in m1]
 
     def norm(self, m):
-        """ Return the norm (L1) of model weights.
+        """Return the norm (L1) of model weights.
 
         :param m: Current model weights.
:type m: list of ndarrays @@ -120,7 +119,7 @@ def norm(self, m): return n def sign(self, m): - """ Sign of m. + """Sign of m. :param m: Model parameters. :type m: list of ndarrays @@ -131,7 +130,7 @@ def sign(self, m): return [np.sign(x) for x in m] def ones(self, m1, a): - """ Return a list of numpy arrays of the same shape as m1, filled with ones. + """Return a list of numpy arrays of the same shape as m1, filled with ones. :param m1: Current model weights. :type m1: list of ndarrays @@ -143,11 +142,11 @@ def ones(self, m1, a): res = [] for x in m1: - res.append(np.ones(np.shape(x))*a) + res.append(np.ones(np.shape(x)) * a) return res def save(self, weights, path=None): - """ Serialize weights to file. The serialized model must be a single binary object. + """Serialize weights to file. The serialized model must be a single binary object. :param weights: List of weights in numpy format. :param path: Path to file. @@ -165,7 +164,7 @@ def save(self, weights, path=None): return path def load(self, fh): - """ Load weights from file or filelike. + """Load weights from file or filelike. :param fh: file path, filehandle, filelike. :return: List of weights in numpy format. diff --git a/fedn/utils/plots.py b/fedn/utils/plots.py index 8ec81aca2..7901e2374 100644 --- a/fedn/utils/plots.py +++ b/fedn/utils/plots.py @@ -11,17 +11,14 @@ class Plot: - """ - - """ + """ """ def __init__(self, statestore): try: statestore_config = statestore.get_config() - statestore = MongoStateStore( - statestore_config['network_id'], statestore_config['mongo_config']) + statestore = MongoStateStore(statestore_config["network_id"], statestore_config["mongo_config"]) self.mdb = statestore.connect() - self.status = self.mdb['control.status'] + self.status = self.mdb["control.status"] self.round_time = self.mdb["control.round_time"] self.combiner_round_time = self.mdb["control.combiner_round_time"] self.psutil_usage = self.mdb["control.psutil_monitoring"] @@ -34,10 +31,10 @@ def __init__(self, statestore): # plot metrics from DB def _scalar_metrics(self, metrics): - """ Extract all scalar valued metrics from a MODEL_VALIDATON. 
""" + """Extract all scalar valued metrics from a MODEL_VALIDATON.""" - data = json.loads(metrics['data']) - data = json.loads(data['data']) + data = json.loads(metrics["data"]) + data = json.loads(data["data"]) valid_metrics = [] for metric, val in data.items(): @@ -55,18 +52,17 @@ def create_table_plot(self): :return: """ - metrics = self.status.find_one({'type': 'MODEL_VALIDATION'}) + metrics = self.status.find_one({"type": "MODEL_VALIDATION"}) if metrics is None: fig = go.Figure(data=[]) - fig.update_layout( - title_text='No data currently available for table mean metrics') + fig.update_layout(title_text="No data currently available for table mean metrics") table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return False valid_metrics = self._scalar_metrics(metrics) if valid_metrics == []: fig = go.Figure(data=[]) - fig.update_layout(title_text='No scalar metrics found') + fig.update_layout(title_text="No scalar metrics found") table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return False @@ -74,14 +70,12 @@ def create_table_plot(self): models = [] for metric in valid_metrics: validations = {} - for post in self.status.find({'type': 'MODEL_VALIDATION'}): - e = json.loads(post['data']) + for post in self.status.find({"type": "MODEL_VALIDATION"}): + e = json.loads(post["data"]) try: - validations[e['modelId']].append( - float(json.loads(e['data'])[metric])) + validations[e["modelId"]].append(float(json.loads(e["data"])[metric])) except KeyError: - validations[e['modelId']] = [ - float(json.loads(e['data'])[metric])] + validations[e["modelId"]] = [float(json.loads(e["data"])[metric])] vals = [] models = [] @@ -98,19 +92,21 @@ def create_table_plot(self): vals.reverse() values.append(vals) - fig = go.Figure(data=[go.Table( - header=dict(values=['Model ID'] + header_vals, - line_color='darkslategray', - fill_color='lightskyblue', - align='left'), - - cells=dict(values=values, # 2nd column - line_color='darkslategray', - fill_color='lightcyan', - align='left')) - ]) + fig = go.Figure( + data=[ + go.Table( + header=dict(values=["Model ID"] + header_vals, line_color="darkslategray", fill_color="lightskyblue", align="left"), + cells=dict( + values=values, # 2nd column + line_color="darkslategray", + fill_color="lightcyan", + align="left", + ), + ) + ] + ) - fig.update_layout(title_text='Summary: mean metrics') + fig.update_layout(title_text="Summary: mean metrics") table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return table @@ -123,63 +119,67 @@ def create_timeline_plot(self): x = [] y = [] base = [] - for p in self.status.find({'type': 'MODEL_UPDATE_REQUEST'}): - e = json.loads(p['data']) - cid = e['correlationId'] - for cc in self.status.find({'sender': p['sender'], 'type': 'MODEL_UPDATE'}): - da = json.loads(cc['data']) - if da['correlationId'] == cid: + for p in self.status.find({"type": "MODEL_UPDATE_REQUEST"}): + e = json.loads(p["data"]) + cid = e["correlationId"] + for cc in self.status.find({"sender": p["sender"], "type": "MODEL_UPDATE"}): + da = json.loads(cc["data"]) + if da["correlationId"] == cid: cp = cc - cd = json.loads(cp['data']) - tr = datetime.strptime(e['timestamp'], '%Y-%m-%d %H:%M:%S.%f') - tu = datetime.strptime(cd['timestamp'], '%Y-%m-%d %H:%M:%S.%f') + cd = json.loads(cp["data"]) + tr = datetime.strptime(e["timestamp"], "%Y-%m-%d %H:%M:%S.%f") + tu = datetime.strptime(cd["timestamp"], "%Y-%m-%d %H:%M:%S.%f") ts = tu - tr base.append(tr.timestamp()) x.append(ts.total_seconds() / 60.0) - y.append(p['sender']['name']) - - 
trace_data.append(go.Bar( - x=y, - y=x, - marker=dict(color='royalblue'), - name="Training", - )) + y.append(p["sender"]["name"]) + + trace_data.append( + go.Bar( + x=y, + y=x, + marker=dict(color="royalblue"), + name="Training", + ) + ) x = [] y = [] base = [] - for p in self.status.find({'type': 'MODEL_VALIDATION_REQUEST'}): - e = json.loads(p['data']) - cid = e['correlationId'] - for cc in self.status.find({'sender': p['sender'], 'type': 'MODEL_VALIDATION'}): - da = json.loads(cc['data']) - if da['correlationId'] == cid: + for p in self.status.find({"type": "MODEL_VALIDATION_REQUEST"}): + e = json.loads(p["data"]) + cid = e["correlationId"] + for cc in self.status.find({"sender": p["sender"], "type": "MODEL_VALIDATION"}): + da = json.loads(cc["data"]) + if da["correlationId"] == cid: cp = cc - cd = json.loads(cp['data']) - tr = datetime.strptime(e['timestamp'], '%Y-%m-%d %H:%M:%S.%f') - tu = datetime.strptime(cd['timestamp'], '%Y-%m-%d %H:%M:%S.%f') + cd = json.loads(cp["data"]) + tr = datetime.strptime(e["timestamp"], "%Y-%m-%d %H:%M:%S.%f") + tu = datetime.strptime(cd["timestamp"], "%Y-%m-%d %H:%M:%S.%f") ts = tu - tr base.append(tr.timestamp()) x.append(ts.total_seconds() / 60.0) - y.append(p['sender']['name']) - - trace_data.append(go.Bar( - x=y, - y=x, - marker=dict(color='lightskyblue'), - name="Validation", - )) + y.append(p["sender"]["name"]) + + trace_data.append( + go.Bar( + x=y, + y=x, + marker=dict(color="lightskyblue"), + name="Validation", + ) + ) layout = go.Layout( - barmode='stack', + barmode="stack", showlegend=True, ) fig = go.Figure(data=trace_data, layout=layout) - fig.update_xaxes(title_text='Alliance/client') - fig.update_yaxes(title_text='Time (Min)') - fig.update_layout(title_text='Alliance timeline') + fig.update_xaxes(title_text="Alliance/client") + fig.update_yaxes(title_text="Time (Min)") + fig.update_layout(title_text="Alliance timeline") timeline = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return timeline @@ -189,16 +189,15 @@ def create_client_training_distribution(self): :return: """ training = [] - for p in self.status.find({'type': 'MODEL_UPDATE'}): - e = json.loads(p['data']) - meta = json.loads(e['meta']) - training.append(meta['exec_training']) + for p in self.status.find({"type": "MODEL_UPDATE"}): + e = json.loads(p["data"]) + meta = json.loads(e["meta"]) + training.append(meta["exec_training"]) if not training: return False fig = go.Figure(data=go.Histogram(x=training)) - fig.update_layout( - title_text='Client model training time, mean: {}'.format(numpy.mean(training))) + fig.update_layout(title_text="Client model training time, mean: {}".format(numpy.mean(training))) histogram = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return histogram @@ -208,19 +207,18 @@ def create_client_histogram_plot(self): :return: """ training = [] - for p in self.status.find({'type': 'MODEL_UPDATE'}): - e = json.loads(p['data']) - meta = json.loads(e['meta']) - training.append(meta['exec_training']) + for p in self.status.find({"type": "MODEL_UPDATE"}): + e = json.loads(p["data"]) + meta = json.loads(e["meta"]) + training.append(meta["exec_training"]) fig = go.Figure() fig.update_layout( template="simple_white", xaxis=dict(title_text="Time (s)"), - yaxis=dict(title_text='Number of updates'), - title="Mean client training time: {}".format( - numpy.mean(training)), + yaxis=dict(title_text="Number of updates"), + title="Mean client training time: {}".format(numpy.mean(training)), # showlegend=True ) if not training: @@ -240,21 +238,16 @@ def 
create_client_plot(self): upload = [] download = [] training = [] - for p in self.status.find({'type': 'MODEL_UPDATE'}): - e = json.loads(p['data']) - meta = json.loads(e['meta']) - upload.append(meta['upload_model']) - download.append(meta['fetch_model']) - training.append(meta['exec_training']) - processing.append(meta['processing_time']) + for p in self.status.find({"type": "MODEL_UPDATE"}): + e = json.loads(p["data"]) + meta = json.loads(e["meta"]) + upload.append(meta["upload_model"]) + download.append(meta["fetch_model"]) + training.append(meta["exec_training"]) + processing.append(meta["processing_time"]) fig = go.Figure() - fig.update_layout( - template="simple_white", - title="Mean client processing time: {}".format( - numpy.mean(processing)), - showlegend=True - ) + fig.update_layout(template="simple_white", title="Mean client processing time: {}".format(numpy.mean(processing)), showlegend=True) if not processing: return False data = [numpy.mean(training), numpy.mean(upload), numpy.mean(download)] @@ -273,32 +266,25 @@ def create_combiner_plot(self): aggregation = [] model_load = [] combination = [] - for round in self.mdb['control.round'].find(): + for round in self.mdb["control.round"].find(): try: - for combiner in round['combiners']: + for combiner in round["combiners"]: data = combiner - stats = data['local_round']['1'] - ml = stats['aggregation_time']['time_model_load'] - ag = stats['aggregation_time']['time_model_aggregation'] - combination.append(stats['time_combination']) - waiting.append(stats['time_combination'] - ml - ag) + stats = data["local_round"]["1"] + ml = stats["aggregation_time"]["time_model_load"] + ag = stats["aggregation_time"]["time_model_aggregation"] + combination.append(stats["time_combination"]) + waiting.append(stats["time_combination"] - ml - ag) model_load.append(ml) aggregation.append(ag) except Exception: pass - labels = ['Waiting for client updates', - 'Aggregation', 'Loading model updates from disk'] - val = [numpy.mean(waiting), numpy.mean( - aggregation), numpy.mean(model_load)] + labels = ["Waiting for client updates", "Aggregation", "Loading model updates from disk"] + val = [numpy.mean(waiting), numpy.mean(aggregation), numpy.mean(model_load)] fig = go.Figure() - fig.update_layout( - template="simple_white", - title="Mean combiner round time: {}".format( - numpy.mean(combination)), - showlegend=True - ) + fig.update_layout(template="simple_white", title="Mean combiner round time: {}".format(numpy.mean(combination)), showlegend=True) if not combination: return False fig.add_trace(go.Pie(labels=labels, values=val)) @@ -310,7 +296,7 @@ def fetch_valid_metrics(self): :return: """ - metrics = self.status.find_one({'type': 'MODEL_VALIDATION'}) + metrics = self.status.find_one({"type": "MODEL_VALIDATION"}) valid_metrics = self._scalar_metrics(metrics) return valid_metrics @@ -320,34 +306,31 @@ def create_box_plot(self, metric): :param metric: :return: """ - metrics = self.status.find_one({'type': 'MODEL_VALIDATION'}) + metrics = self.status.find_one({"type": "MODEL_VALIDATION"}) if metrics is None: fig = go.Figure(data=[]) - fig.update_layout(title_text='No data currently available for metric distribution over ' - 'participants') + fig.update_layout(title_text="No data currently available for metric distribution over " "participants") box = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return box valid_metrics = self._scalar_metrics(metrics) if valid_metrics == []: fig = go.Figure(data=[]) - fig.update_layout(title_text='No scalar metrics 
found') + fig.update_layout(title_text="No scalar metrics found") box = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return box validations = {} - for post in self.status.find({'type': 'MODEL_VALIDATION'}): - e = json.loads(post['data']) + for post in self.status.find({"type": "MODEL_VALIDATION"}): + e = json.loads(post["data"]) try: - validations[e['modelId']].append( - float(json.loads(e['data'])[metric])) + validations[e["modelId"]].append(float(json.loads(e["data"])[metric])) except KeyError: - validations[e['modelId']] = [ - float(json.loads(e['data'])[metric])] + validations[e["modelId"]] = [float(json.loads(e["data"])[metric])] # Make sure validations are plotted in chronological order - model_trail = self.mdb.control.model.find_one({'key': 'model_trail'}) - model_trail_ids = model_trail['model'] + model_trail = self.mdb.control.model.find_one({"key": "model_trail"}) + model_trail_ids = model_trail["model"] validations_sorted = [] for model_id in model_trail_ids: try: @@ -364,24 +347,16 @@ def create_box_plot(self, metric): # x.append(j) y.append(numpy.mean([float(i) for i in acc])) if len(acc) >= 2: - box.add_trace(go.Box(y=acc, name=str(j), marker_color="royalblue", showlegend=False, - boxpoints=False)) + box.add_trace(go.Box(y=acc, name=str(j), marker_color="royalblue", showlegend=False, boxpoints=False)) else: - box.add_trace(go.Scatter( - x=[str(j)], y=[y[j]], showlegend=False)) + box.add_trace(go.Scatter(x=[str(j)], y=[y[j]], showlegend=False)) rounds = list(range(len(y))) - box.add_trace(go.Scatter( - x=rounds, - y=y, - name='Mean' - )) - - box.update_xaxes(title_text='Rounds') - box.update_yaxes( - tickvals=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) - box.update_layout(title_text='Metric distribution over clients: {}'.format(metric), - margin=dict(l=20, r=20, t=45, b=20)) + box.add_trace(go.Scatter(x=rounds, y=y, name="Mean")) + + box.update_xaxes(title_text="Rounds") + box.update_yaxes(tickvals=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) + box.update_layout(title_text="Metric distribution over clients: {}".format(metric), margin=dict(l=20, r=20, t=45, b=20)) box = json.dumps(box, cls=plotly.utils.PlotlyJSONEncoder) return box @@ -391,38 +366,27 @@ def create_round_plot(self): :return: """ trace_data = [] - metrics = self.round_time.find_one({'key': 'round_time'}) + metrics = self.round_time.find_one({"key": "round_time"}) if metrics is None: fig = go.Figure(data=[]) - fig.update_layout( - title_text='No data currently available for round time') + fig.update_layout(title_text="No data currently available for round time") return False - for post in self.round_time.find({'key': 'round_time'}): - rounds = post['round'] - traces_data = post['round_time'] + for post in self.round_time.find({"key": "round_time"}): + rounds = post["round"] + traces_data = post["round_time"] - trace_data.append(go.Scatter( - x=rounds, - y=traces_data, - mode='lines+markers', - name='Reducer' - )) + trace_data.append(go.Scatter(x=rounds, y=traces_data, mode="lines+markers", name="Reducer")) - for rec in self.combiner_round_time.find({'key': 'combiner_round_time'}): - c_traces_data = rec['round_time'] + for rec in self.combiner_round_time.find({"key": "combiner_round_time"}): + c_traces_data = rec["round_time"] - trace_data.append(go.Scatter( - x=rounds, - y=c_traces_data, - mode='lines+markers', - name='Combiner' - )) + trace_data.append(go.Scatter(x=rounds, y=c_traces_data, mode="lines+markers", name="Combiner")) fig = go.Figure(data=trace_data) - 
fig.update_xaxes(title_text='Round') - fig.update_yaxes(title_text='Time (s)') - fig.update_layout(title_text='Round time') + fig.update_xaxes(title_text="Round") + fig.update_yaxes(title_text="Time (s)") + fig.update_layout(title_text="Round time") round_t = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return round_t @@ -431,46 +395,38 @@ def create_cpu_plot(self): :return: """ - metrics = self.psutil_usage.find_one({'key': 'cpu_mem_usage'}) + metrics = self.psutil_usage.find_one({"key": "cpu_mem_usage"}) if metrics is None: fig = go.Figure(data=[]) - fig.update_layout( - title_text='No data currently available for MEM and CPU usage') + fig.update_layout(title_text="No data currently available for MEM and CPU usage") cpu = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return False - for post in self.psutil_usage.find({'key': 'cpu_mem_usage'}): - cpu = post['cpu'] - mem = post['mem'] - ps_time = post['time'] - round = post['round'] + for post in self.psutil_usage.find({"key": "cpu_mem_usage"}): + cpu = post["cpu"] + mem = post["mem"] + ps_time = post["time"] + round = post["round"] # Create figure with secondary y-axis fig = make_subplots(specs=[[{"secondary_y": True}]]) - fig.add_trace(go.Scatter( - x=ps_time, - y=cpu, - mode='lines+markers', - name='CPU (%)' - )) - - fig.add_trace(go.Scatter( - x=ps_time, - y=mem, - mode='lines+markers', - name='MEM (%)' - )) - - fig.add_trace(go.Scatter( - x=ps_time, - y=round, - mode='lines+markers', - name='Round', - ), secondary_y=True) - - fig.update_xaxes(title_text='Date Time') - fig.update_yaxes(title_text='Percentage (%)') + fig.add_trace(go.Scatter(x=ps_time, y=cpu, mode="lines+markers", name="CPU (%)")) + + fig.add_trace(go.Scatter(x=ps_time, y=mem, mode="lines+markers", name="MEM (%)")) + + fig.add_trace( + go.Scatter( + x=ps_time, + y=round, + mode="lines+markers", + name="Round", + ), + secondary_y=True, + ) + + fig.update_xaxes(title_text="Date Time") + fig.update_yaxes(title_text="Percentage (%)") fig.update_yaxes(title_text="Round", secondary_y=True) - fig.update_layout(title_text='CPU loads and memory usage') + fig.update_layout(title_text="CPU loads and memory usage") cpu = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return cpu From fb76c474f0651b22582d79372290b33e1d831a94 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Wed, 8 May 2024 13:22:48 +0200 Subject: [PATCH 3/5] Bug/SK-841 | Fix docs build after folder refactor #602 --- .readthedocs.yaml | 2 +- docs/conf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index bc45dc53b..a1e30fdef 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -11,5 +11,5 @@ sphinx: python: install: - method: pip - path: ./fedn + path: . 
- requirements: docs/requirements.txt diff --git a/docs/conf.py b/docs/conf.py index 913c35d9c..cdbba39e5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,7 +4,7 @@ import sphinx_rtd_theme # noqa: F401 # Insert path -sys.path.insert(0, os.path.abspath('../fedn')) +sys.path.insert(0, os.path.abspath('..')) # Project info project = 'FEDn' From 0281ad7037a484235448c1a6483763c1d51b5f1d Mon Sep 17 00:00:00 2001 From: Viktor Valadi <42983197+viktorvaladi@users.noreply.github.com> Date: Mon, 13 May 2024 10:43:43 +0200 Subject: [PATCH 4/5] Feature/SK-839 | Remove ruff ignores that can be autofixed (#601) --- docs/conf.py | 66 +++++++++---------- examples/async-clients/client/entrypoint.py | 1 - examples/async-clients/run_clients.py | 1 - examples/flower-client/client/entrypoint.py | 1 - examples/mnist-keras/client/entrypoint.py | 1 - fedn/cli/client_cmd.py | 12 +--- fedn/cli/combiner_cmd.py | 11 +--- fedn/cli/config_cmd.py | 3 +- fedn/cli/main.py | 4 +- fedn/cli/model_cmd.py | 7 +- fedn/cli/package_cmd.py | 7 +- fedn/cli/round_cmd.py | 7 +- fedn/cli/run_cmd.py | 16 ++--- fedn/cli/session_cmd.py | 7 +- fedn/cli/shared.py | 3 +- fedn/cli/status_cmd.py | 7 +- fedn/cli/validation_cmd.py | 7 +- fedn/common/certificate/certificate.py | 22 ++----- fedn/common/certificate/certificatemanager.py | 16 ++--- fedn/common/log_config.py | 6 +- fedn/network/api/client.py | 1 - fedn/network/api/interface.py | 10 +-- fedn/network/api/network.py | 1 - fedn/network/api/server.py | 7 -- fedn/network/clients/client.py | 9 +-- fedn/network/clients/connect.py | 6 +- fedn/network/clients/package.py | 1 - .../combiner/aggregators/aggregatorbase.py | 14 ++-- fedn/network/combiner/aggregators/fedavg.py | 12 ++-- fedn/network/combiner/aggregators/fedopt.py | 61 ++++++++--------- fedn/network/combiner/combiner.py | 14 +--- fedn/network/combiner/connect.py | 4 +- fedn/network/combiner/interfaces.py | 4 -- fedn/network/combiner/modelservice.py | 1 - fedn/network/combiner/roundhandler.py | 8 --- fedn/network/controller/control.py | 6 -- fedn/network/controller/controlbase.py | 19 ++---- fedn/network/grpc/__init__.py | 2 +- fedn/network/grpc/server.py | 15 ++--- fedn/network/loadbalancer/firstavailable.py | 1 - fedn/network/loadbalancer/leastpacked.py | 8 +-- fedn/network/storage/models/__init__.py | 3 +- .../storage/models/memorymodelstorage.py | 4 +- .../storage/models/tempmodelstorage.py | 4 +- fedn/network/storage/s3/__init__.py | 3 +- fedn/network/storage/s3/miniorepository.py | 2 - fedn/network/storage/s3/repository.py | 2 - .../storage/statestore/mongostatestore.py | 16 ----- .../storage/statestore/stores/client_store.py | 20 +++--- .../statestore/stores/combiner_store.py | 32 ++++----- .../storage/statestore/stores/model_store.py | 34 +++++----- .../statestore/stores/package_store.py | 28 ++++---- .../storage/statestore/stores/round_store.py | 16 ++--- .../statestore/stores/session_store.py | 16 ++--- .../storage/statestore/stores/shared.py | 4 +- .../storage/statestore/stores/status_store.py | 26 ++++---- .../storage/statestore/stores/store.py | 4 +- .../statestore/stores/validation_store.py | 24 +++---- fedn/utils/dispatcher.py | 12 ++-- fedn/utils/environment.py | 6 +- fedn/utils/helpers/helperbase.py | 1 - fedn/utils/helpers/plugins/numpyhelper.py | 8 --- fedn/utils/plots.py | 41 +++--------- fedn/utils/process.py | 3 +- pyproject.toml | 15 +---- 65 files changed, 260 insertions(+), 473 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index cdbba39e5..6d74d6539 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ 
-4,31 +4,31 @@ import sphinx_rtd_theme # noqa: F401 # Insert path -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) # Project info -project = 'FEDn' -copyright = '2021, Scaleout Systems AB' -author = 'Scaleout Systems AB' +project = "FEDn" +copyright = "2021, Scaleout Systems AB" +author = "Scaleout Systems AB" # The full version, including alpha/beta/rc tags -release = '0.9.2' +release = "0.9.2" # Add any Sphinx extension module names here, as strings extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', - 'sphinx_rtd_theme', - 'sphinx_code_tabs' + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.ifconfig", + "sphinx.ext.viewcode", + "sphinx_rtd_theme", + "sphinx_code_tabs" ] # The master toctree document. -master_doc = 'index' +master_doc = "index" # Add any paths that contain templates here, relative to this directory. templates_path = [] @@ -39,31 +39,31 @@ exclude_patterns = [] # The theme to use for HTML and HTML Help pages. -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" html_theme_options = { - 'logo_only': True, + "logo_only": True, } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Output file base name for HTML help builder. -htmlhelp_basename = 'fedndocs' +htmlhelp_basename = "fedndocs" # If defined shows an image instead of project name on page top-left (link to index page) -html_logo = '_static/images/scaleout_logo_flat_dark.svg' +html_logo = "_static/images/scaleout_logo_flat_dark.svg" # FEDn logo looks ugly on rtd theme -html_favicon = 'favicon.png' +html_favicon = "favicon.png" # Here we assume that the file is at _static/custom.css html_css_files = [ - 'css/elements.css', - 'css/text.css', - 'css/utilities.css', + "css/elements.css", + "css/text.css", + "css/utilities.css", ] # LaTeX elements @@ -89,14 +89,14 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'fedn.tex', 'FEDn Documentation', - 'Scaleout Systems AB', 'manual'), + (master_doc, "fedn.tex", "FEDn Documentation", + "Scaleout Systems AB", "manual"), ] # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'fedn', 'FEDn Documentation', + (master_doc, "fedn", "FEDn Documentation", [author], 1) ] @@ -104,17 +104,17 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'fedn', 'FEDn Documentation', - author, 'fedn', 'One line description of project.', - 'Miscellaneous'), + (master_doc, "fedn", "FEDn Documentation", + author, "fedn", "One line description of project.", + "Miscellaneous"), ] # Bibliographic Dublin Core info. epub_title = project -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # Example configuration for intersphinx: refer to the Python standard library. 
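# NOTE (editorial addition, not part of this patch): newer Sphinx releases have
# deprecated this unnamed intersphinx form; the named equivalent would be
# intersphinx_mapping = {"python": ("https://docs.python.org/", None)}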
-intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = {"https://docs.python.org/": None} -pygments_style = 'sphinx' +pygments_style = "sphinx" diff --git a/examples/async-clients/client/entrypoint.py b/examples/async-clients/client/entrypoint.py index 220b5299b..0b8e10668 100644 --- a/examples/async-clients/client/entrypoint.py +++ b/examples/async-clients/client/entrypoint.py @@ -79,7 +79,6 @@ def make_data(n_min=50, n_max=100): def train(in_model_path, out_model_path): """Train model.""" - # Load model parameters = load_parameters(in_model_path) model = compile_model() diff --git a/examples/async-clients/run_clients.py b/examples/async-clients/run_clients.py index 82da30ad9..f2ce72291 100644 --- a/examples/async-clients/run_clients.py +++ b/examples/async-clients/run_clients.py @@ -68,7 +68,6 @@ def run_client(online_for=120, name="client"): This is repeated for N_CYCLES. """ - conf = copy.deepcopy(client_config) conf["name"] = name diff --git a/examples/flower-client/client/entrypoint.py b/examples/flower-client/client/entrypoint.py index 1a9a8b8cf..a790644ef 100755 --- a/examples/flower-client/client/entrypoint.py +++ b/examples/flower-client/client/entrypoint.py @@ -14,7 +14,6 @@ def _get_node_id(): """Get client number from environment variable.""" - number = os.environ.get("CLIENT_NUMBER", "0") return int(number) diff --git a/examples/mnist-keras/client/entrypoint.py b/examples/mnist-keras/client/entrypoint.py index 5420a78bb..1ed8f2f77 100755 --- a/examples/mnist-keras/client/entrypoint.py +++ b/examples/mnist-keras/client/entrypoint.py @@ -135,7 +135,6 @@ def validate(in_model_path, out_json_path, data_path=None): :param data_path: The path to the data file. :type data_path: str """ - # Load data x_train, y_train = load_data(data_path) x_test, y_test = load_data(data_path, is_train=False) diff --git a/fedn/cli/client_cmd.py b/fedn/cli/client_cmd.py index e72f29569..80b0b3353 100644 --- a/fedn/cli/client_cmd.py +++ b/fedn/cli/client_cmd.py @@ -15,7 +15,6 @@ def validate_client_config(config): :param config: Client config (dict). 
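Example (editorial sketch, not part of the patch; the host value is illustrative):

    validate_client_config({"discover_host": "api.fedn.example.com", "discover_port": 8092})

A missing or empty "discover_host" raises InvalidClientConfig.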
""" - try: if config["discover_host"] is None or config["discover_host"] == "": raise InvalidClientConfig("Missing required configuration: discover_host") @@ -28,9 +27,7 @@ def validate_client_config(config): @main.group("client") @click.pass_context def client_cmd(ctx): - """ - - :param ctx: + """:param ctx: """ pass @@ -43,8 +40,7 @@ def client_cmd(ctx): @client_cmd.command("list") @click.pass_context def list_clients(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): - """ - Return: + """Return: ------ - count: number of clients - result: list of clients @@ -114,9 +110,7 @@ def client_cmd( reconnect_after_missed_heartbeat, verbosity, ): - """ - - :param ctx: + """:param ctx: :param discoverhost: :param discoverport: :param token: diff --git a/fedn/cli/combiner_cmd.py b/fedn/cli/combiner_cmd.py index 2b4447437..02a797448 100644 --- a/fedn/cli/combiner_cmd.py +++ b/fedn/cli/combiner_cmd.py @@ -12,9 +12,7 @@ @main.group("combiner") @click.pass_context def combiner_cmd(ctx): - """ - - :param ctx: + """:param ctx: """ pass @@ -33,9 +31,7 @@ def combiner_cmd(ctx): @click.option("-in", "--init", required=False, default=None, help="Path to configuration file to (re)init combiner.") @click.pass_context def start_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init): - """ - - :param ctx: + """:param ctx: :param discoverhost: :param discoverport: :param token: @@ -76,8 +72,7 @@ def start_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, se @combiner_cmd.command("list") @click.pass_context def list_combiners(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): - """ - Return: + """Return: ------ - count: number of combiners - result: list of combiners diff --git a/fedn/cli/config_cmd.py b/fedn/cli/config_cmd.py index d5286997f..0b77260c3 100644 --- a/fedn/cli/config_cmd.py +++ b/fedn/cli/config_cmd.py @@ -21,8 +21,7 @@ @main.group("config", invoke_without_command=True) @click.pass_context def config_cmd(ctx): - """ - - Configuration commands for the FEDn CLI. + """- Configuration commands for the FEDn CLI. 
""" if ctx.invoked_subcommand is None: click.echo("\n--- FEDn Cli Configuration ---\n") diff --git a/fedn/cli/main.py b/fedn/cli/main.py index 52276c418..d6f912e62 100644 --- a/fedn/cli/main.py +++ b/fedn/cli/main.py @@ -9,8 +9,6 @@ @click.group(context_settings=CONTEXT_SETTINGS) @click.pass_context def main(ctx): - """ - - :param ctx: + """:param ctx: """ ctx.obj = dict() diff --git a/fedn/cli/model_cmd.py b/fedn/cli/model_cmd.py index e44793a9f..80a8f795e 100644 --- a/fedn/cli/model_cmd.py +++ b/fedn/cli/model_cmd.py @@ -8,9 +8,7 @@ @main.group("model") @click.pass_context def model_cmd(ctx): - """ - - :param ctx: + """:param ctx: """ pass @@ -23,8 +21,7 @@ def model_cmd(ctx): @model_cmd.command("list") @click.pass_context def list_models(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): - """ - Return: + """Return: ------ - count: number of models - result: list of models diff --git a/fedn/cli/package_cmd.py b/fedn/cli/package_cmd.py index 6d503d414..3c78d9944 100644 --- a/fedn/cli/package_cmd.py +++ b/fedn/cli/package_cmd.py @@ -13,9 +13,7 @@ @main.group("package") @click.pass_context def package_cmd(ctx): - """ - - :param ctx: + """:param ctx: """ pass @@ -51,8 +49,7 @@ def create_cmd(ctx, path, name): @package_cmd.command("list") @click.pass_context def list_packages(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): - """ - Return: + """Return: ------ - count: number of packages - result: list of packages diff --git a/fedn/cli/round_cmd.py b/fedn/cli/round_cmd.py index ca23cafe7..ac42f43ef 100644 --- a/fedn/cli/round_cmd.py +++ b/fedn/cli/round_cmd.py @@ -8,9 +8,7 @@ @main.group("round") @click.pass_context def round_cmd(ctx): - """ - - :param ctx: + """:param ctx: """ pass @@ -23,8 +21,7 @@ def round_cmd(ctx): @round_cmd.command("list") @click.pass_context def list_rounds(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): - """ - Return: + """Return: ------ - count: number of rounds - result: list of rounds diff --git a/fedn/cli/run_cmd.py b/fedn/cli/run_cmd.py index b9fe4528e..123f17320 100644 --- a/fedn/cli/run_cmd.py +++ b/fedn/cli/run_cmd.py @@ -17,9 +17,7 @@ def get_statestore_config_from_file(init): - """ - - :param init: + """:param init: :return: """ with open(init, "r") as file: @@ -43,9 +41,7 @@ def check_helper_config_file(config): @main.group("run") @click.pass_context def run_cmd(ctx): - """ - - :param ctx: + """:param ctx: """ pass @@ -125,9 +121,7 @@ def client_cmd( reconnect_after_missed_heartbeat, verbosity, ): - """ - - :param ctx: + """:param ctx: :param discoverhost: :param discoverport: :param token: @@ -201,9 +195,7 @@ def client_cmd( @click.option("-in", "--init", required=False, default=None, help="Path to configuration file to (re)init combiner.") @click.pass_context def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init): - """ - - :param ctx: + """:param ctx: :param discoverhost: :param discoverport: :param token: diff --git a/fedn/cli/session_cmd.py b/fedn/cli/session_cmd.py index 55597b5b3..65db98c69 100644 --- a/fedn/cli/session_cmd.py +++ b/fedn/cli/session_cmd.py @@ -8,9 +8,7 @@ @main.group("session") @click.pass_context def session_cmd(ctx): - """ - - :param ctx: + """:param ctx: """ pass @@ -23,8 +21,7 @@ def session_cmd(ctx): @session_cmd.command("list") @click.pass_context def list_sessions(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): - """ - Return: + """Return: 
------ - count: number of sessions - result: list of sessions diff --git a/fedn/cli/shared.py b/fedn/cli/shared.py index 2500d9e2b..d32f4ff43 100644 --- a/fedn/cli/shared.py +++ b/fedn/cli/shared.py @@ -65,8 +65,7 @@ def get_client_package_dir(path: str) -> str: # Print response from api (list of entities) def print_response(response, entity_name: str): - """ - Prints the api response to the cli. + """Prints the api response to the cli. :param response: type: array description: list of entities diff --git a/fedn/cli/status_cmd.py b/fedn/cli/status_cmd.py index a4f17e349..078acaf13 100644 --- a/fedn/cli/status_cmd.py +++ b/fedn/cli/status_cmd.py @@ -8,9 +8,7 @@ @main.group("status") @click.pass_context def status_cmd(ctx): - """ - - :param ctx: + """:param ctx: """ pass @@ -23,8 +21,7 @@ def status_cmd(ctx): @status_cmd.command("list") @click.pass_context def list_statuses(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): - """ - Return: + """Return: ------ - count: number of statuses - result: list of statuses diff --git a/fedn/cli/validation_cmd.py b/fedn/cli/validation_cmd.py index 055be0c65..4bf4e63fa 100644 --- a/fedn/cli/validation_cmd.py +++ b/fedn/cli/validation_cmd.py @@ -8,9 +8,7 @@ @main.group("validation") @click.pass_context def validation_cmd(ctx): - """ - - :param ctx: + """:param ctx: """ pass @@ -23,8 +21,7 @@ def validation_cmd(ctx): @validation_cmd.command("list") @click.pass_context def list_validations(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): - """ - Return: + """Return: ------ - count: number of validations - result: list of validations diff --git a/fedn/common/certificate/certificate.py b/fedn/common/certificate/certificate.py index 857a05e7c..3cb09016c 100644 --- a/fedn/common/certificate/certificate.py +++ b/fedn/common/certificate/certificate.py @@ -9,8 +9,7 @@ class Certificate: - """ - Utility to generate unsigned certificates. + """Utility to generate unsigned certificates. """ @@ -37,8 +36,7 @@ def __init__(self, cwd, name=None, key_name="key.pem", cert_name="cert.pem", cre def gen_keypair( self, ): - """ - Generate keypair. + """Generate keypair. """ key = crypto.PKey() @@ -65,9 +63,7 @@ def gen_keypair( certfile.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) def set_keypair_raw(self, certificate, privatekey): - """ - - :param certificate: + """:param certificate: :param privatekey: """ with open(self.key_path, "wb") as keyfile: @@ -77,9 +73,7 @@ def set_keypair_raw(self, certificate, privatekey): certfile.write(crypto.dump_certificate(crypto.FILETYPE_PEM, certificate)) def get_keypair_raw(self): - """ - - :return: + """:return: """ with open(self.key_path, "rb") as keyfile: key_buf = keyfile.read() @@ -88,9 +82,7 @@ def get_keypair_raw(self): return copy.deepcopy(cert_buf), copy.deepcopy(key_buf) def get_key(self): - """ - - :return: + """:return: """ with open(self.key_path, "rb") as keyfile: key_buf = keyfile.read() @@ -98,9 +90,7 @@ def get_key(self): return key def get_cert(self): - """ - - :return: + """:return: """ with open(self.cert_path, "rb") as certfile: cert_buf = certfile.read() diff --git a/fedn/common/certificate/certificatemanager.py b/fedn/common/certificate/certificatemanager.py index ce165d862..172d799ed 100644 --- a/fedn/common/certificate/certificatemanager.py +++ b/fedn/common/certificate/certificatemanager.py @@ -4,8 +4,7 @@ class CertificateManager: - """ - Utility to handle certificates for both Reducer and Combiner services. 
+ """Utility to handle certificates for both Reducer and Combiner services. """ @@ -16,8 +15,7 @@ def __init__(self, directory): self.load_all() def get_or_create(self, name): - """ - Look for an existing certificate, if not found, generate a self-signed certificate based on name. + """Look for an existing certificate, if not found, generate a self-signed certificate based on name. :param name: The name used when issuing the certificate. :return: A certificate @@ -33,8 +31,7 @@ def get_or_create(self, name): return cert def add(self, certificate): - """ - Add certificate to certificate list. + """Add certificate to certificate list. :param certificate: :return: Success status (True, False) @@ -46,8 +43,7 @@ def add(self, certificate): return False def load_all(self): - """ - Load all certificates and add to certificates list. + """Load all certificates and add to certificates list. """ for filename in sorted(os.listdir(self.directory)): @@ -59,9 +55,7 @@ def load_all(self): self.certificates.append(c) def find(self, name): - """ - - :param name: Name of certificate + """:param name: Name of certificate :return: certificate if successful, else None """ for cert in self.certificates: diff --git a/fedn/common/log_config.py b/fedn/common/log_config.py index b8aa1218b..0d3ddb96c 100644 --- a/fedn/common/log_config.py +++ b/fedn/common/log_config.py @@ -62,8 +62,7 @@ def emit(self, record): def set_log_level_from_string(level_str): - """ - Set the log level based on a string input. + """Set the log level based on a string input. """ # Mapping of string representation to logging constants level_mapping = { @@ -85,8 +84,7 @@ def set_log_level_from_string(level_str): def set_log_stream(log_file): - """ - Redirect the log stream to a specified file, if log_file is set. + """Redirect the log stream to a specified file, if log_file is set. """ if not log_file: return diff --git a/fedn/network/api/client.py b/fedn/network/api/client.py index 43678fbcf..cd0ca5a7a 100644 --- a/fedn/network/api/client.py +++ b/fedn/network/api/client.py @@ -497,7 +497,6 @@ def get_session(self, id: str): :return: Session. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f"sessions/{id}"), self.verify, headers=self.headers) _json = response.json() diff --git a/fedn/network/api/interface.py b/fedn/network/api/interface.py index 718cd8a18..663add77e 100644 --- a/fedn/network/api/interface.py +++ b/fedn/network/api/interface.py @@ -194,7 +194,6 @@ def set_compute_package(self, file, helper_type: str, name: str = None, descript :return: A json response with success or failure message. :rtype: :class:`flask.Response` """ - if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: return ( jsonify( @@ -307,7 +306,6 @@ def list_compute_packages(self, limit: str = None, skip: str = None, include_act :return: All compute packages as a json response. :rtype: :class:`flask.Response` """ - if limit is not None and skip is not None: limit = int(limit) skip = int(skip) @@ -397,7 +395,6 @@ def _create_checksum(self, name=None): :return: Success or failure boolean, message and the checksum. :rtype: bool, str, str """ - if name is None: name, message = self._get_compute_package_name() if name is None: @@ -418,7 +415,6 @@ def get_checksum(self, name): :return: The checksum as a json object. 
:rtype: :py:class:`flask.Response` """ - success, message, sum = self._create_checksum(name) if not success: return jsonify({"success": False, "message": message}), 404 @@ -816,7 +812,6 @@ def get_model_descendants(self, model_id: str, limit: str = None): :return: The model descendants for the given model as a json response. :rtype: :class:`flask.Response` """ - if model_id is None: return jsonify({"success": False, "message": "No model id provided."}) @@ -868,8 +863,7 @@ def get_all_rounds(self): "combiners": combiners, } payload[id] = info - else: - return jsonify(payload) + return jsonify(payload) def get_round(self, round_id): """Get a round. @@ -915,7 +909,6 @@ def get_plot_data(self, feature=None): :return: The plot data as json response. :rtype: :py:class:`flask.Response` """ - plot = Plot(self.control.statestore) try: @@ -942,7 +935,6 @@ def list_combiners_data(self, combiners): :return: The combiners data as json response. :rtype: :py:class:`flask.Response` """ - response = self.statestore.list_combiners_data(combiners) arr = [] diff --git a/fedn/network/api/network.py b/fedn/network/api/network.py index cb105f10a..045f8aa34 100644 --- a/fedn/network/api/network.py +++ b/fedn/network/api/network.py @@ -113,7 +113,6 @@ def add_client(self, client): :type client: dict :return: None """ - if self.get_client(client["name"]): return diff --git a/fedn/network/api/server.py b/fedn/network/api/server.py index 5f645e4e2..c196da762 100644 --- a/fedn/network/api/server.py +++ b/fedn/network/api/server.py @@ -105,7 +105,6 @@ def list_models(): Returns: _type_: json """ - session_id = request.args.get("session_id", None) limit = request.args.get("limit", None) skip = request.args.get("skip", None) @@ -161,7 +160,6 @@ def list_clients(): return: All clients as a json object. rtype: json """ - limit = request.args.get("limit", None) skip = request.args.get("skip", None) status = request.args.get("status", None) @@ -202,7 +200,6 @@ def list_combiners(): return: All combiners as a json object. rtype: json """ - limit = request.args.get("limit", None) skip = request.args.get("skip", None) @@ -389,7 +386,6 @@ def list_compute_packages(): return: The compute package as a json object. rtype: json """ - limit = request.args.get("limit", None) skip = request.args.get("skip", None) include_active = request.args.get("include_active", None) @@ -596,7 +592,6 @@ def add_client(): return: The response from control. rtype: json """ - json_data = request.get_json() remote_addr = request.remote_addr try: @@ -617,7 +612,6 @@ def list_combiners_data(): return: The response from control. rtype: json """ - json_data = request.get_json() # expects a list of combiner names (strings) in an array @@ -640,7 +634,6 @@ def get_plot_data(): """Get plot data from the statestore. rtype: json """ - try: feature = request.args.get("feature", None) response = api.get_plot_data(feature=feature) diff --git a/fedn/network/clients/client.py b/fedn/network/clients/client.py index c8a5afc4f..70fe005ff 100644 --- a/fedn/network/clients/client.py +++ b/fedn/network/clients/client.py @@ -114,7 +114,6 @@ def assign(self): :return: A configuration dictionary containing connection information for combiner. :rtype: dict """ - logger.info("Initiating assignment request.") while True: status, response = self.connector.assign() @@ -179,10 +178,9 @@ def connect(self, combiner_config): :param combiner_config: connection information for the combiner. :type combiner_config: dict """ - if self._connected: logger.info("Client is already attached. 
") - return None + return # TODO use the combiner_config['certificate'] for setting up secure comms' host = combiner_config["host"] @@ -257,7 +255,6 @@ def _initialize_helper(self, combiner_config): :type combiner_config: dict :return: """ - if "helper_type" in combiner_config.keys(): self.helper = get_helper(combiner_config["helper_type"]) @@ -268,7 +265,6 @@ def _subscribe_to_combiner(self, config): | the discovery service (controller) and settings governing e.g. | client-combiner assignment behavior. """ - # Start sending heartbeats to the combiner. threading.Thread(target=self._send_heartbeat, kwargs={"update_frequency": config["heartbeat_interval"]}, daemon=True).start() @@ -420,7 +416,6 @@ def _listen_to_task_stream(self): :return: None :rtype: None """ - r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER @@ -489,7 +484,6 @@ def _process_training_request(self, model_id: str, session_id: str = None): :return: The model id of the updated model, or None if the update failed. And a dict with metadata. :rtype: tuple """ - self.send_status("\t Starting processing of training request for model_id {}".format(model_id), sesssion_id=session_id) self.state = ClientState.training @@ -740,7 +734,6 @@ def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, :param request: The request message. :type request: fedn.Request """ - if not self._connected: logger.info("SendStatus: Client disconnected.") return diff --git a/fedn/network/clients/connect.py b/fedn/network/clients/connect.py index efeb3d1e9..09450c5ab 100644 --- a/fedn/network/clients/connect.py +++ b/fedn/network/clients/connect.py @@ -67,8 +67,7 @@ def __init__(self, host, port, token, name, remote_package, force_ssl=False, ver logger.info("Setting connection string to {}.".format(self.connect_string)) def assign(self): - """ - Connect client to FEDn network discovery service, ask for combiner assignment. + """Connect client to FEDn network discovery service, ask for combiner assignment. :return: Tuple with assingment status, combiner connection information if sucessful, else None. :rtype: tuple(:class:`fedn.network.clients.connect.Status`, str) @@ -127,8 +126,7 @@ def assign(self): return Status.Unassigned, None def refresh_token(self): - """ - Refresh client token. + """Refresh client token. :return: Tuple with assingment status, combiner connection information if sucessful, else None. :rtype: tuple(:class:`fedn.network.clients.connect.Status`, str) diff --git a/fedn/network/clients/package.py b/fedn/network/clients/package.py index 54f45b883..f99d12d49 100644 --- a/fedn/network/clients/package.py +++ b/fedn/network/clients/package.py @@ -153,7 +153,6 @@ def dispatcher(self, run_path): :return: Dispatcher object :rtype: :class:`fedn.utils.dispatcher.Dispatcher` """ - self.dispatch_config = _read_yaml_file(os.path.join(run_path, "fedn.yaml")) dispatcher = Dispatcher(self.dispatch_config, run_path) diff --git a/fedn/network/combiner/aggregators/aggregatorbase.py b/fedn/network/combiner/aggregators/aggregatorbase.py index e0053cb6e..0a9c33f43 100644 --- a/fedn/network/combiner/aggregators/aggregatorbase.py +++ b/fedn/network/combiner/aggregators/aggregatorbase.py @@ -86,8 +86,8 @@ def _validate_model_update(self, model_update): :return: True if the model update is valid, False otherwise. 
:rtype: bool """ - data = json.loads(model_update.meta)['training_metadata'] - if 'num_examples' not in data.keys(): + data = json.loads(model_update.meta)["training_metadata"] + if "num_examples" not in data.keys(): logger.error("AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name)) return False return True @@ -120,21 +120,21 @@ def load_model_update(self, model_update, helper): model = self.round_handler.load_model_update(helper, model_id) # Get relevant metadata metadata = json.loads(model_update.meta) - if 'config' in metadata.keys(): + if "config" in metadata.keys(): # Used in Python client - config = json.loads(metadata['config']) + config = json.loads(metadata["config"]) else: # Used in C++ client config = json.loads(model_update.config) - training_metadata = metadata['training_metadata'] - training_metadata['round_id'] = config['round_id'] + training_metadata = metadata["training_metadata"] + training_metadata["round_id"] = config["round_id"] return model, training_metadata def get_state(self): """ Get the state of the aggregator's queue, including the number of model updates.""" state = { - 'queue_len': self.model_updates.qsize() + "queue_len": self.model_updates.qsize() } return state diff --git a/fedn/network/combiner/aggregators/fedavg.py b/fedn/network/combiner/aggregators/fedavg.py index 19ce84803..9ed0adf3c 100644 --- a/fedn/network/combiner/aggregators/fedavg.py +++ b/fedn/network/combiner/aggregators/fedavg.py @@ -21,7 +21,6 @@ class Aggregator(AggregatorBase): def __init__(self, storage, server, modelservice, round_handler): """Constructor method""" - super().__init__(storage, server, modelservice, round_handler) self.name = "fedavg" @@ -41,10 +40,9 @@ def combine_models(self, helper=None, delete_models=True, parameters=None): :return: The global model and metadata :rtype: tuple """ - data = {} - data['time_model_load'] = 0.0 - data['time_model_aggregation'] = 0.0 + data["time_model_load"] = 0.0 + data["time_model_aggregation"] = 0.0 model = None nr_aggregated_models = 0 @@ -67,13 +65,13 @@ def combine_models(self, helper=None, delete_models=True, parameters=None): "AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata)) # Increment total number of examples - total_examples += metadata['num_examples'] + total_examples += metadata["num_examples"] if nr_aggregated_models == 0: model = model_next else: model = helper.increment_average( - model, model_next, metadata['num_examples'], total_examples) + model, model_next, metadata["num_examples"], total_examples) nr_aggregated_models += 1 # Delete model from storage @@ -87,7 +85,7 @@ def combine_models(self, helper=None, delete_models=True, parameters=None): "AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e)) self.model_updates.task_done() - data['nr_aggregated_models'] = nr_aggregated_models + data["nr_aggregated_models"] = nr_aggregated_models logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models)) return model, data diff --git a/fedn/network/combiner/aggregators/fedopt.py b/fedn/network/combiner/aggregators/fedopt.py index 305340f10..5041e097f 100644 --- a/fedn/network/combiner/aggregators/fedopt.py +++ b/fedn/network/combiner/aggregators/fedopt.py @@ -55,18 +55,17 @@ def combine_models(self, helper=None, delete_models=True, parameters=None): :return: The global model and metadata :rtype: tuple """ - data 
= {} - data['time_model_load'] = 0.0 - data['time_model_aggregation'] = 0.0 + data["time_model_load"] = 0.0 + data["time_model_aggregation"] = 0.0 # Define parameter schema parameter_schema = { - 'serveropt': str, - 'learning_rate': float, - 'beta1': float, - 'beta2': float, - 'tau': float, + "serveropt": str, + "learning_rate": float, + "beta1": float, + "beta2": float, + "tau": float, } try: @@ -77,11 +76,11 @@ def combine_models(self, helper=None, delete_models=True, parameters=None): # Default hyperparameters. Note that these may need fine tuning. default_parameters = { - 'serveropt': 'adam', - 'learning_rate': 1e-3, - 'beta1': 0.9, - 'beta2': 0.99, - 'tau': 1e-4, + "serveropt": "adam", + "learning_rate": 1e-3, + "beta1": 0.9, + "beta2": 0.99, + "tau": 1e-4, } # Validate parameters @@ -119,7 +118,7 @@ def combine_models(self, helper=None, delete_models=True, parameters=None): "AGGREGATOR({}): Processing model update {}".format(self.name, model_update.model_update_id)) # Increment total number of examples - total_examples += metadata['num_examples'] + total_examples += metadata["num_examples"] if nr_aggregated_models == 0: model_old = self.round_handler.load_model_update(helper, model_update.model_id) @@ -127,7 +126,7 @@ def combine_models(self, helper=None, delete_models=True, parameters=None): else: pseudo_gradient_next = helper.subtract(model_next, model_old) pseudo_gradient = helper.increment_average( - pseudo_gradient, pseudo_gradient_next, metadata['num_examples'], total_examples) + pseudo_gradient, pseudo_gradient_next, metadata["num_examples"], total_examples) nr_aggregated_models += 1 # Delete model from storage @@ -141,17 +140,17 @@ def combine_models(self, helper=None, delete_models=True, parameters=None): "AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e)) self.model_updates.task_done() - if parameters['serveropt'] == 'adam': + if parameters["serveropt"] == "adam": model = self.serveropt_adam(helper, pseudo_gradient, model_old, parameters) - elif parameters['serveropt'] == 'yogi': + elif parameters["serveropt"] == "yogi": model = self.serveropt_yogi(helper, pseudo_gradient, model_old, parameters) - elif parameters['serveropt'] == 'adagrad': + elif parameters["serveropt"] == "adagrad": model = self.serveropt_adagrad(helper, pseudo_gradient, model_old, parameters) else: logger.error("Unsupported server optimizer passed to FedOpt.") return None, data - data['nr_aggregated_models'] = nr_aggregated_models + data["nr_aggregated_models"] = nr_aggregated_models logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models)) return model, data @@ -170,10 +169,10 @@ def serveropt_adam(self, helper, pseudo_gradient, model_old, parameters): :return: new model weights. :rtype: as defined by helper. """ - beta1 = parameters['beta1'] - beta2 = parameters['beta2'] - learning_rate = parameters['learning_rate'] - tau = parameters['tau'] + beta1 = parameters["beta1"] + beta2 = parameters["beta2"] + learning_rate = parameters["learning_rate"] + tau = parameters["tau"] if not self.v: self.v = helper.ones(pseudo_gradient, math.pow(tau, 2)) @@ -206,11 +205,10 @@ def serveropt_yogi(self, helper, pseudo_gradient, model_old, parameters): :return: new model weights. :rtype: as defined by helper. 
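Example (editorial sketch; the keys and values mirror the default_parameters shown in combine_models above, and `aggregator` stands in for a configured FedOpt instance):

    parameters = {"serveropt": "yogi", "learning_rate": 1e-3, "beta1": 0.9, "beta2": 0.99, "tau": 1e-4}
    model, data = aggregator.combine_models(helper=helper, parameters=parameters)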
""" - - beta1 = parameters['beta1'] - beta2 = parameters['beta2'] - learning_rate = parameters['learning_rate'] - tau = parameters['tau'] + beta1 = parameters["beta1"] + beta2 = parameters["beta2"] + learning_rate = parameters["learning_rate"] + tau = parameters["tau"] if not self.v: self.v = helper.ones(pseudo_gradient, math.pow(tau, 2)) @@ -245,10 +243,9 @@ def serveropt_adagrad(self, helper, pseudo_gradient, model_old, parameters): :return: new model weights. :rtype: as defined by helper. """ - - beta1 = parameters['beta1'] - learning_rate = parameters['learning_rate'] - tau = parameters['tau'] + beta1 = parameters["beta1"] + learning_rate = parameters["learning_rate"] + tau = parameters["tau"] if not self.v: self.v = helper.ones(pseudo_gradient, math.pow(tau, 2)) diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index f19674b73..450b8b689 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -59,7 +59,6 @@ class Combiner(rpc.CombinerServicer, rpc.ReducerServicer, rpc.ConnectorServicer, def __init__(self, config): """Initialize Combiner server.""" - set_log_level_from_string(config.get("verbosity", "INFO")) set_log_stream(config.get("logfile", None)) @@ -327,11 +326,9 @@ def _list_active_clients(self, channel): if status != "online": self.clients[client]["status"] = "online" clients["update_active_clients"].append(client) - else: - # If client has changed status, update statestore - if status == "online": - self.clients[client]["status"] = "offline" - clients["update_offline_clients"].append(client) + elif status == "online": + self.clients[client]["status"] = "offline" + clients["update_offline_clients"].append(client) # Update statestore with client status if len(clients["update_active_clients"]) > 0: self.statestore.update_client_status(clients["update_active_clients"], "online") @@ -369,7 +366,6 @@ def _send_status(self, status): :param status: the status to report :type status: :class:`fedn.network.grpc.fedn_pb2.Status` """ - self.statestore.report_status(status) def _flush_model_update_queue(self): @@ -377,7 +373,6 @@ def _flush_model_update_queue(self): :return: True if successful, else False """ - q = self.round_handler.aggregator.model_updates try: with q.mutex: @@ -588,7 +583,6 @@ def TaskStream(self, response, context): :param context: the context :type context: :class:`grpc._server._Context` """ - client = response.sender metadata = context.invocation_metadata() if metadata: @@ -643,7 +637,6 @@ def register_model_validation(self, validation): :param validation: the model validation :type validation: :class:`fedn.network.grpc.fedn_pb2.ModelValidation` """ - self.statestore.report_validation(validation) def SendModelValidation(self, request, context): @@ -668,7 +661,6 @@ def SendModelValidation(self, request, context): def run(self): """Start the server.""" - logger.info("COMBINER: {} started, ready for gRPC requests.".format(self.id)) try: while True: diff --git a/fedn/network/combiner/connect.py b/fedn/network/combiner/connect.py index e144baa94..854c8e103 100644 --- a/fedn/network/combiner/connect.py +++ b/fedn/network/combiner/connect.py @@ -67,7 +67,6 @@ def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, :param verify: Verify the connection to the discovery service. 
:type verify: bool """ - self.host = host self.fqdn = fqdn self.port = port @@ -92,8 +91,7 @@ def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, logger.info("Setting connection string to {}".format(self.connect_string)) def announce(self): - """ - Announce combiner to FEDn network via discovery service (REST-API). + """Announce combiner to FEDn network via discovery service (REST-API). :return: Tuple with announcement Status, FEDn network configuration if sucessful, else None. :rtype: :class:`fedn.network.combiner.connect.Status`, str diff --git a/fedn/network/combiner/interfaces.py b/fedn/network/combiner/interfaces.py index f247c2bc1..bf10a00f1 100644 --- a/fedn/network/combiner/interfaces.py +++ b/fedn/network/combiner/interfaces.py @@ -113,7 +113,6 @@ def to_dict(self): :return: A dictionary with the combiner configuration. :rtype: dict """ - data = { "parent": self.parent, "name": self.name, @@ -168,7 +167,6 @@ def get_key(self): def flush_model_update_queue(self): """Reset the model update queue on the combiner.""" - channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ControlStub(channel) @@ -188,7 +186,6 @@ def set_aggregator(self, aggregator): :param aggregator: The name of the aggregator module. :type config: str """ - channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ControlStub(channel) @@ -240,7 +237,6 @@ def get_model(self, id, timeout=10): :return: A file-like object containing the model. :rtype: :class:`io.BytesIO`, None if the model is not available. """ - channel = Channel(self.address, self.port, self.certificate).get_channel() modelservice = rpc.ModelServiceStub(channel) diff --git a/fedn/network/combiner/modelservice.py b/fedn/network/combiner/modelservice.py index 0b50edbc7..b5e7bff73 100644 --- a/fedn/network/combiner/modelservice.py +++ b/fedn/network/combiner/modelservice.py @@ -112,7 +112,6 @@ def get_model(self, id): :return: A BytesIO object containing the model. :rtype: :class:`io.BytesIO`, None if model does not exist. """ - data = BytesIO() data.seek(0, 0) diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index 2a8436e01..4edc04b6e 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -34,7 +34,6 @@ class RoundHandler: def __init__(self, storage, server, modelservice): """Initialize the RoundHandler.""" - self.round_configs = queue.Queue() self.storage = storage self.server = server @@ -67,7 +66,6 @@ def load_model_update(self, helper, model_id): :param model_id: The ID of the model update, UUID in str format :type model_id: str """ - model_str = self.load_model_update_str(model_id) if model_str: try: @@ -119,7 +117,6 @@ def waitforit(self, config, buffer_size=100, polling_interval=0.1): :param polling_interval: The polling interval, defaults to 0.1 :type polling_interval: float, optional """ - time_window = float(config["round_timeout"]) tt = 0.0 @@ -140,7 +137,6 @@ def _training_round(self, config, clients): :return: an aggregated model and associated metadata :rtype: model, dict """ - logger.info("ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients)) meta = {} @@ -208,7 +204,6 @@ def stage_model(self, model_id, timeout_retry=3, retry=2): :param retry: Number of retries, defaults to 2 :type retry: int, optional """ - # If the model is already in memory at the server we do not need to do anything. 
if self.modelservice.temp_model_storage.exist(model_id): logger.info("Model already exists in memory, skipping model staging.") @@ -241,7 +236,6 @@ def _assign_round_clients(self, n, type="trainers"): :return: Set of clients :rtype: list """ - if type == "validators": clients = self.server.get_active_validators() elif type == "trainers": @@ -269,7 +263,6 @@ def _check_nr_round_clients(self, config): :return: True if the required number of clients are available, False otherwise. :rtype: bool """ - active = self.server.nr_active_trainers() if active >= int(config["clients_required"]): logger.info("Number of clients required ({0}) to start round met {1}.".format(config["clients_required"], active)) @@ -298,7 +291,6 @@ def execute_training_round(self, config): :return: metadata about the training round. :rtype: dict """ - logger.info("Processing training round, job_id {}".format(config["_job_id"])) data = {} diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index 99a59469c..a422383f0 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -75,7 +75,6 @@ class Control(ControlBase): def __init__(self, statestore): """Constructor method.""" - super().__init__(statestore) self.name = "DefaultControl" @@ -88,7 +87,6 @@ def session(self, config): :type config: dict """ - if self._state == ReducerState.instructing: logger.info("Controller already in INSTRUCTING state. A session is in progress.") return @@ -140,7 +138,6 @@ def round(self, session_config, round_id): : type round_id: str """ - self.create_round({"round_id": round_id, "status": "Pending"}) if len(self.network.get_combiners()) < 1: @@ -275,7 +272,6 @@ def reduce(self, combiners): : param combiners: dict of combiner names(key) and model IDs(value) to reduce : type combiners: dict """ - meta = {} meta["time_fetch_model"] = 0.0 meta["time_load_model"] = 0.0 @@ -322,7 +318,6 @@ def infer_instruct(self, config): : param config: configuration for the inference round """ - # Check/set instucting state if self.__state == ReducerState.instructing: logger.info("Already set in INSTRUCTING state") @@ -350,7 +345,6 @@ def inference_round(self, config): : param config: configuration for the inference round """ - # Init meta round_data = {} diff --git a/fedn/network/controller/controlbase.py b/fedn/network/controller/controlbase.py index d99bae40a..d667e01c4 100644 --- a/fedn/network/controller/controlbase.py +++ b/fedn/network/controller/controlbase.py @@ -113,17 +113,13 @@ def idle(self): return False def get_model_info(self): - """ - - :return: + """:return: """ return self.statestore.get_model_trail() # TODO: remove use statestore.get_events() instead def get_events(self): - """ - - :return: + """:return: """ return self.statestore.get_events() @@ -139,9 +135,7 @@ def get_latest_round(self): return round def get_compute_package_name(self): - """ - - :return: + """:return: """ definition = self.statestore.get_compute_package() if definition: @@ -159,9 +153,7 @@ def set_compute_package(self, filename, path): self.model_repository.set_compute_package(filename, path) def get_compute_package(self, compute_package=""): - """ - - :param compute_package: + """:param compute_package: :return: """ if compute_package == "": @@ -173,7 +165,6 @@ def get_compute_package(self, compute_package=""): def create_session(self, config, status="Initialized"): """Initialize a new session in backend db.""" - if "session_id" not in config.keys(): session_id = uuid.uuid4() config["session_id"] = 
str(session_id) @@ -196,7 +187,6 @@ def set_session_status(self, session_id, status): def create_round(self, round_data): """Initialize a new round in backend db.""" - self.statestore.create_round(round_data) def set_round_data(self, round_id, round_data): @@ -251,7 +241,6 @@ def commit(self, model_id, model=None, session_id=None): :param session_id: Unique identifier for the session :type session_id: str """ - helper = self.get_helper() if model is not None: logger.info("Saving model file temporarily to disk...") diff --git a/fedn/network/grpc/__init__.py b/fedn/network/grpc/__init__.py index ad5e023ab..19daa5e47 100644 --- a/fedn/network/grpc/__init__.py +++ b/fedn/network/grpc/__init__.py @@ -1 +1 @@ -__all__ = ['fedn_pb2', 'fedn_pb2_grpc'] +__all__ = ["fedn_pb2", "fedn_pb2_grpc"] diff --git a/fedn/network/grpc/server.py b/fedn/network/grpc/server.py index f953bf96a..4354a7aa5 100644 --- a/fedn/network/grpc/server.py +++ b/fedn/network/grpc/server.py @@ -4,8 +4,7 @@ from grpc_health.v1 import health, health_pb2_grpc import fedn.network.grpc.fedn_pb2_grpc as rpc -from fedn.common.log_config import (logger, set_log_level_from_string, - set_log_stream) +from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream from fedn.network.grpc.auth import JWTInterceptor @@ -14,8 +13,8 @@ class Server: def __init__(self, servicer, modelservicer, config): - set_log_level_from_string(config.get('verbosity', "INFO")) - set_log_stream(config.get('logfile', None)) + set_log_level_from_string(config.get("verbosity", "INFO")) + set_log_stream(config.get("logfile", None)) self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=350), interceptors=[JWTInterceptor()]) self.certificate = None @@ -34,15 +33,15 @@ def __init__(self, servicer, modelservicer, config): health_pb2_grpc.add_HealthServicer_to_server(self.health_servicer, self.server) - if config['secure']: + if config["secure"]: logger.info(f'Creating secure gRPCS server using certificate: {config["certificate"]}') server_credentials = grpc.ssl_server_credentials( - ((config['key'], config['certificate'],),)) + ((config["key"], config["certificate"],),)) self.server.add_secure_port( - '[::]:' + str(config['port']), server_credentials) + "[::]:" + str(config["port"]), server_credentials) else: logger.info("Creating gRPC server") - self.server.add_insecure_port('[::]:' + str(config['port'])) + self.server.add_insecure_port("[::]:" + str(config["port"])) def start(self): """ Start the gRPC server.""" diff --git a/fedn/network/loadbalancer/firstavailable.py b/fedn/network/loadbalancer/firstavailable.py index 13dc766b2..5de8be881 100644 --- a/fedn/network/loadbalancer/firstavailable.py +++ b/fedn/network/loadbalancer/firstavailable.py @@ -13,7 +13,6 @@ def __init__(self, network): def find_combiner(self): """Find the first available combiner.""" - for combiner in self.network.get_combiners(): if combiner.allowing_clients(): return combiner diff --git a/fedn/network/loadbalancer/leastpacked.py b/fedn/network/loadbalancer/leastpacked.py index a762701b0..786dd8de0 100644 --- a/fedn/network/loadbalancer/leastpacked.py +++ b/fedn/network/loadbalancer/leastpacked.py @@ -13,8 +13,7 @@ def __init__(self, network): super().__init__(network) def find_combiner(self): - """ - Find the combiner with the least number of attached clients. + """Find the combiner with the least number of attached clients. 
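Example (editorial sketch; assumes the enclosing class is LeastPacked and `network` exposes the registered combiners):

    balancer = LeastPacked(network)
    combiner = balancer.find_combiner()  # combiner with the fewest active clients, or None if none allow clients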
""" min_clients = None @@ -25,10 +24,7 @@ def find_combiner(self): if combiner.allowing_clients(): # Using default default Channel = 1, MODEL_UPDATE_REQUESTS nr_active_clients = len(combiner.list_active_clients()) - if not min_clients: - min_clients = nr_active_clients - selected_combiner = combiner - elif nr_active_clients < min_clients: + if not min_clients or nr_active_clients < min_clients: min_clients = nr_active_clients selected_combiner = combiner except CombinerUnavailableError: diff --git a/fedn/network/storage/models/__init__.py b/fedn/network/storage/models/__init__.py index fdfba1986..38135d697 100644 --- a/fedn/network/storage/models/__init__.py +++ b/fedn/network/storage/models/__init__.py @@ -1,4 +1,5 @@ """ The models package handles storing of model updates durign the federated training process. The functionality is used by the combiner service during aggregation of model updates from clients. By implementing the interface in the base class modelstorage.py, a developer may customize the -behaviour of the framework. """ +behaviour of the framework. +""" diff --git a/fedn/network/storage/models/memorymodelstorage.py b/fedn/network/storage/models/memorymodelstorage.py index 6a40a7ae0..54599fb28 100644 --- a/fedn/network/storage/models/memorymodelstorage.py +++ b/fedn/network/storage/models/memorymodelstorage.py @@ -31,9 +31,7 @@ def get(self, model_id): return obj def get_ptr(self, model_id): - """ - - :param model_id: + """:param model_id: :return: """ return self.models[model_id] diff --git a/fedn/network/storage/models/tempmodelstorage.py b/fedn/network/storage/models/tempmodelstorage.py index 214fac4d7..891f6ea07 100644 --- a/fedn/network/storage/models/tempmodelstorage.py +++ b/fedn/network/storage/models/tempmodelstorage.py @@ -41,9 +41,7 @@ def get(self, model_id): return obj def get_ptr(self, model_id): - """ - - :param model_id: + """:param model_id: :return: """ try: diff --git a/fedn/network/storage/s3/__init__.py b/fedn/network/storage/s3/__init__.py index 0befb7819..2e9f7d361 100644 --- a/fedn/network/storage/s3/__init__.py +++ b/fedn/network/storage/s3/__init__.py @@ -1,3 +1,4 @@ """ Module handling storage of objects in S3-compatible object storage. This functionality is used by the controller to store global models in the model trail in persistent storage. Currently implemented for MinIO, but a ' -developer can extend the framwork by implemeting the interface in base.py. """ +developer can extend the framwork by implemeting the interface in base.py. +""" diff --git a/fedn/network/storage/s3/miniorepository.py b/fedn/network/storage/s3/miniorepository.py index 9c86b8997..ff329856a 100644 --- a/fedn/network/storage/s3/miniorepository.py +++ b/fedn/network/storage/s3/miniorepository.py @@ -19,7 +19,6 @@ def __init__(self, config): :param config: Dictionary containing configuration for credentials and bucket names. :type config: dict """ - super().__init__() self.name = "MINIORepository" @@ -91,7 +90,6 @@ def delete_artifact(self, instance_name, bucket): :param bucket: Buckets to delete from :type bucket: str """ - try: self.client.remove_object(bucket, instance_name) except InvalidResponseError as err: diff --git a/fedn/network/storage/s3/repository.py b/fedn/network/storage/s3/repository.py index 18d36cdbb..c1704e5ca 100644 --- a/fedn/network/storage/s3/repository.py +++ b/fedn/network/storage/s3/repository.py @@ -73,7 +73,6 @@ def set_compute_package(self, name, compute_package, is_file=True): :type compute_pacakge: BytesIO or str file name. 
:param is_file: True if compute_package is a file name, else False """ - try: self.client.set_artifact(str(name), compute_package, bucket=self.context_bucket, is_file=is_file) except Exception: @@ -100,7 +99,6 @@ def delete_compute_package(self, compute_package): :param compute_package: The name of the compute_package :type compute_package: str """ - try: self.client.delete_artifact(compute_package, bucket=[self.context_bucket]) except Exception: diff --git a/fedn/network/storage/statestore/mongostatestore.py b/fedn/network/storage/statestore/mongostatestore.py index a53a6e4d5..6bf3be4ff 100644 --- a/fedn/network/storage/statestore/mongostatestore.py +++ b/fedn/network/storage/statestore/mongostatestore.py @@ -140,7 +140,6 @@ def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING :return: Dictionary of sessions in result (array of session objects) and count. """ - result = None if limit is not None and skip is not None: @@ -175,7 +174,6 @@ def set_latest_model(self, model_id, session_id=None): :type model_id: str :return: """ - committed_at = datetime.now() current_model = self.model.find_one({"key": "current_model"}) parent_model = None @@ -214,7 +212,6 @@ def get_initial_model(self): :return: The initial model id. None if no model is found. :rtype: str """ - result = self.model.find_one({"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)]) if result is None: return None @@ -252,7 +249,6 @@ def set_current_model(self, model_id: str): :type model_id: str :return: """ - try: committed_at = datetime.now() @@ -273,7 +269,6 @@ def get_latest_round(self): :return: The id of the most recent round. :rtype: ObjectId """ - return self.rounds.find_one(sort=[("_id", pymongo.DESCENDING)]) def get_round(self, id): @@ -284,7 +279,6 @@ def get_round(self, id): :return: round with id, reducer and combiners :rtype: ObjectId """ - return self.rounds.find_one({"round_id": str(id)}) def get_rounds(self): @@ -293,7 +287,6 @@ def get_rounds(self): :return: All rounds. :rtype: ObjectId """ - return self.rounds.find() def get_validations(self, **kwargs): @@ -304,7 +297,6 @@ def get_validations(self, **kwargs): :return: validations matching query :rtype: ObjectId """ - result = self.control.validations.find(kwargs) return result @@ -316,7 +308,6 @@ def set_active_compute_package(self, id: str): :return: True if successful. :rtype: bool """ - try: find = {"id": id} projection = {"_id": False, "key": False} @@ -344,7 +335,6 @@ def set_compute_package(self, file_name: str, storage_file_name: str, helper_typ :return: True if successful. :rtype: bool """ - obj = { "file_name": file_name, "storage_file_name": storage_file_name, @@ -396,7 +386,6 @@ def list_compute_packages(self, limit: int = None, skip: int = None, sort_key="c :return: Dictionary of compute packages in result and count. :rtype: dict """ - result = None count = None @@ -544,7 +533,6 @@ def get_model_descendants(self, model_id: str, limit: int): :return: List of model descendants. :rtype: list """ - model: object = self.model.find_one({"key": "models", "model": model_id}) current_model_id: str = model["model"] if model is not None else None result: list = [] @@ -684,7 +672,6 @@ def get_combiners(self, limit=None, skip=None, sort_key="updated_at", sort_order :return: Dictionary of combiners in result and count. 
:rtype: dict """ - result = None count = None @@ -713,7 +700,6 @@ def set_combiner(self, combiner_data): :type combiner_data: dict :return: """ - combiner_data["updated_at"] = str(datetime.now()) self.combiners.update_one({"name": combiner_data["name"]}, {"$set": combiner_data}, True) @@ -769,7 +755,6 @@ def list_clients(self, limit=None, skip=None, status=None, sort_key="last_seen", :type status: str :param sort_key: The key to sort by. """ - result = None count = None @@ -806,7 +791,6 @@ def list_combiners_data(self, combiners, sort_key="count", sort_order=pymongo.DE :return: list of combiner data. :rtype: list(ObjectId) """ - result = None try: diff --git a/fedn/network/storage/statestore/stores/client_store.py b/fedn/network/storage/statestore/stores/client_store.py index 05c03643a..5797fab7d 100644 --- a/fedn/network/storage/statestore/stores/client_store.py +++ b/fedn/network/storage/statestore/stores/client_store.py @@ -18,16 +18,16 @@ def __init__(self, id: str, name: str, combiner: str, combiner_preferred: str, i self.updated_at = updated_at self.last_seen = last_seen - def from_dict(data: dict) -> 'Client': + def from_dict(data: dict) -> "Client": return Client( - id=str(data['_id']), - name=data['name'] if 'name' in data else None, - combiner=data['combiner'] if 'combiner' in data else None, - combiner_preferred=data['combiner_preferred'] if 'combiner_preferred' in data else None, - ip=data['ip'] if 'ip' in data else None, - status=data['status'] if 'status' in data else None, - updated_at=data['updated_at'] if 'updated_at' in data else None, - last_seen=data['last_seen'] if 'last_seen' in data else None + id=str(data["_id"]), + name=data["name"] if "name" in data else None, + combiner=data["combiner"] if "combiner" in data else None, + combiner_preferred=data["combiner_preferred"] if "combiner_preferred" in data else None, + ip=data["ip"] if "ip" in data else None, + status=data["status"] if "status" in data else None, + updated_at=data["updated_at"] if "updated_at" in data else None, + last_seen=data["last_seen"] if "last_seen" in data else None ) @@ -74,7 +74,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI """ response = super().list(limit, skip, sort_key or "last_seen", sort_order, use_typing=use_typing, **kwargs) - result = [Client.from_dict(item) for item in response['result']] if use_typing else response['result'] + result = [Client.from_dict(item) for item in response["result"]] if use_typing else response["result"] return { "count": response["count"], diff --git a/fedn/network/storage/statestore/stores/combiner_store.py b/fedn/network/storage/statestore/stores/combiner_store.py index d47386e8a..02495b66f 100644 --- a/fedn/network/storage/statestore/stores/combiner_store.py +++ b/fedn/network/storage/statestore/stores/combiner_store.py @@ -38,20 +38,20 @@ def __init__( self.status = status self.updated_at = updated_at - def from_dict(data: dict) -> 'Combiner': + def from_dict(data: dict) -> "Combiner": return Combiner( - id=str(data['_id']), - name=data['name'] if 'name' in data else None, - address=data['address'] if 'address' in data else None, - certificate=data['certificate'] if 'certificate' in data else None, - config=data['config'] if 'config' in data else None, - fqdn=data['fqdn'] if 'fqdn' in data else None, - ip=data['ip'] if 'ip' in data else None, - key=data['key'] if 'key' in data else None, - parent=data['parent'] if 'parent' in data else None, - port=data['port'] if 'port' in data else None, - status=data['status'] if 
'status' in data else None, - updated_at=data['updated_at'] if 'updated_at' in data else None + id=str(data["_id"]), + name=data["name"] if "name" in data else None, + address=data["address"] if "address" in data else None, + certificate=data["certificate"] if "certificate" in data else None, + config=data["config"] if "config" in data else None, + fqdn=data["fqdn"] if "fqdn" in data else None, + ip=data["ip"] if "ip" in data else None, + key=data["key"] if "key" in data else None, + parent=data["parent"] if "parent" in data else None, + port=data["port"] if "port" in data else None, + status=data["status"] if "status" in data else None, + updated_at=data["updated_at"] if "updated_at" in data else None ) @@ -70,9 +70,9 @@ def get(self, id: str, use_typing: bool = False) -> Combiner: """ if ObjectId.is_valid(id): id_obj = ObjectId(id) - document = self.database[self.collection].find_one({'_id': id_obj}) + document = self.database[self.collection].find_one({"_id": id_obj}) else: - document = self.database[self.collection].find_one({'name': id}) + document = self.database[self.collection].find_one({"name": id}) if document is None: raise EntityNotFound(f"Entity with (id | name) {id} not found") @@ -107,7 +107,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI """ response = super().list(limit, skip, sort_key or "updated_at", sort_order, use_typing=use_typing, **kwargs) - result = [Combiner.from_dict(item) for item in response['result']] if use_typing else response['result'] + result = [Combiner.from_dict(item) for item in response["result"]] if use_typing else response["result"] return { "count": response["count"], diff --git a/fedn/network/storage/statestore/stores/model_store.py b/fedn/network/storage/statestore/stores/model_store.py index f72beefa7..172603405 100644 --- a/fedn/network/storage/statestore/stores/model_store.py +++ b/fedn/network/storage/statestore/stores/model_store.py @@ -19,14 +19,14 @@ def __init__(self, id: str, key: str, model: str, parent_model: str, session_id: self.session_id = session_id self.committed_at = committed_at - def from_dict(data: dict) -> 'Model': + def from_dict(data: dict) -> "Model": return Model( - id=str(data['_id']), - key=data['key'] if 'key' in data else None, - model=data['model'] if 'model' in data else None, - parent_model=data['parent_model'] if 'parent_model' in data else None, - session_id=data['session_id'] if 'session_id' in data else None, - committed_at=data['committed_at'] if 'committed_at' in data else None + id=str(data["_id"]), + key=data["key"] if "key" in data else None, + model=data["model"] if "model" in data else None, + parent_model=data["parent_model"] if "parent_model" in data else None, + session_id=data["session_id"] if "session_id" in data else None, + committed_at=data["committed_at"] if "committed_at" in data else None ) @@ -46,9 +46,9 @@ def get(self, id: str, use_typing: bool = False) -> Model: kwargs = {"key": "models"} if ObjectId.is_valid(id): id_obj = ObjectId(id) - kwargs['_id'] = id_obj + kwargs["_id"] = id_obj else: - kwargs['model'] = id + kwargs["model"] = id document = self.database[self.collection].find_one(kwargs) @@ -83,13 +83,13 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI example: {"key": "models"} return: A dictionary with the count and the result """ - kwargs['key'] = "models" + kwargs["key"] = "models" response = super().list(limit, skip, sort_key or "committed_at", sort_order, use_typing=use_typing, **kwargs) - result = 
[Model.from_dict(item) for item in response['result']] if use_typing else response['result'] + result = [Model.from_dict(item) for item in response["result"]] if use_typing else response["result"] return { - "count": response['count'], + "count": response["count"], "result": result } @@ -107,9 +107,9 @@ def list_descendants(self, id: str, limit: int, use_typing: bool = False) -> Lis kwargs = {"key": "models"} if ObjectId.is_valid(id): id_obj = ObjectId(id) - kwargs['_id'] = id_obj + kwargs["_id"] = id_obj else: - kwargs['model'] = id + kwargs["model"] = id model: object = self.database[self.collection].find_one(kwargs) @@ -150,9 +150,9 @@ def list_ancestors(self, id: str, limit: int, include_self: bool = False, revers kwargs = {"key": "models"} if ObjectId.is_valid(id): id_obj = ObjectId(id) - kwargs['_id'] = id_obj + kwargs["_id"] = id_obj else: - kwargs['model'] = id + kwargs["model"] = id model: object = self.database[self.collection].find_one(kwargs) @@ -191,5 +191,5 @@ def count(self, **kwargs) -> int: example: {"key": "models"} return: The count (int) """ - kwargs['key'] = "models" + kwargs["key"] = "models" return super().count(**kwargs) diff --git a/fedn/network/storage/statestore/stores/package_store.py b/fedn/network/storage/statestore/stores/package_store.py index 423b8716a..eb7154af2 100644 --- a/fedn/network/storage/statestore/stores/package_store.py +++ b/fedn/network/storage/statestore/stores/package_store.py @@ -32,21 +32,21 @@ def __init__( self.storage_file_name = storage_file_name self.active = active - def from_dict(data: dict, active_package: dict) -> 'Package': + def from_dict(data: dict, active_package: dict) -> "Package": active = False if active_package: if "id" in active_package and "id" in data: active = active_package["id"] == data["id"] return Package( - id=data['id'] if 'id' in data else None, - key=data['key'] if 'key' in data else None, - committed_at=data['committed_at'] if 'committed_at' in data else None, - description=data['description'] if 'description' in data else None, - file_name=data['file_name'] if 'file_name' in data else None, - helper=data['helper'] if 'helper' in data else None, - name=data['name'] if 'name' in data else None, - storage_file_name=data['storage_file_name'] if 'storage_file_name' in data else None, + id=data["id"] if "id" in data else None, + key=data["key"] if "key" in data else None, + committed_at=data["committed_at"] if "committed_at" in data else None, + description=data["description"] if "description" in data else None, + file_name=data["file_name"] if "file_name" in data else None, + helper=data["helper"] if "helper" in data else None, + name=data["name"] if "name" in data else None, + storage_file_name=data["storage_file_name"] if "storage_file_name" in data else None, active=active ) @@ -66,7 +66,7 @@ def get(self, id: str, use_typing: bool = False) -> Package: If True, the active property will be set based on the active package. 
return: The entity """ - document = self.database[self.collection].find_one({'id': id}) + document = self.database[self.collection].find_one({"id": id}) if document is None: raise EntityNotFound(f"Entity with id {id} not found") @@ -74,7 +74,7 @@ def get(self, id: str, use_typing: bool = False) -> Package: if not use_typing: return from_document(document) - response_active = self.database[self.collection].find_one({'key': 'active'}) + response_active = self.database[self.collection].find_one({"key": "active"}) return Package.from_dict(document, response_active) @@ -84,7 +84,7 @@ def get_active(self, use_typing: bool = False) -> Package: type: bool return: The entity """ - response = self.database[self.collection].find_one({'key': 'active'}) + response = self.database[self.collection].find_one({"key": "active"}) if response is None: raise EntityNotFound(f"Entity with id {id} not found") @@ -123,9 +123,9 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI response = super().list(limit, skip, sort_key or "committed_at", sort_order, use_typing=True, **kwargs) - response_active = self.database[self.collection].find_one({'key': 'active'}) + response_active = self.database[self.collection].find_one({"key": "active"}) - result = [Package.from_dict(item, response_active) for item in response['result']] + result = [Package.from_dict(item, response_active) for item in response["result"]] return { "count": response["count"], diff --git a/fedn/network/storage/statestore/stores/round_store.py b/fedn/network/storage/statestore/stores/round_store.py index 5afde0d7e..03af044c3 100644 --- a/fedn/network/storage/statestore/stores/round_store.py +++ b/fedn/network/storage/statestore/stores/round_store.py @@ -15,14 +15,14 @@ def __init__(self, id: str, round_id: str, status: str, round_config: dict, comb self.combiners = combiners self.round_data = round_data - def from_dict(data: dict) -> 'Round': + def from_dict(data: dict) -> "Round": return Round( - id=str(data['_id']), - round_id=data['round_id'] if 'round_id' in data else None, - status=data['status'] if 'status' in data else None, - round_config=data['round_config'] if 'round_config' in data else None, - combiners=data['combiners'] if 'combiners' in data else None, - round_data=data['round_data'] if 'round_data' in data else None + id=str(data["_id"]), + round_id=data["round_id"] if "round_id" in data else None, + status=data["status"] if "status" in data else None, + round_config=data["round_config"] if "round_config" in data else None, + combiners=data["combiners"] if "combiners" in data else None, + round_data=data["round_data"] if "round_data" in data else None ) @@ -70,7 +70,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI """ response = super().list(limit, skip, sort_key or "round_id", sort_order, use_typing=use_typing, **kwargs) - result = [Round.from_dict(item) for item in response['result']] if use_typing else response['result'] + result = [Round.from_dict(item) for item in response["result"]] if use_typing else response["result"] return { "count": response["count"], diff --git a/fedn/network/storage/statestore/stores/session_store.py b/fedn/network/storage/statestore/stores/session_store.py index c0a6b7da8..31c5e25b4 100644 --- a/fedn/network/storage/statestore/stores/session_store.py +++ b/fedn/network/storage/statestore/stores/session_store.py @@ -16,12 +16,12 @@ def __init__(self, id: str, session_id: str, status: str, session_config: dict = self.status = status 
self.session_config = session_config - def from_dict(data: dict) -> 'Session': + def from_dict(data: dict) -> "Session": return Session( - id=str(data['_id']), - session_id=data['session_id'] if 'session_id' in data else None, - status=data['status'] if 'status' in data else None, - session_config=data['session_config'] if 'session_config' in data else None + id=str(data["_id"]), + session_id=data["session_id"] if "session_id" in data else None, + status=data["status"] if "status" in data else None, + session_config=data["session_config"] if "session_config" in data else None ) @@ -40,9 +40,9 @@ def get(self, id: str, use_typing: bool = False) -> Session: """ if ObjectId.is_valid(id): id_obj = ObjectId(id) - document = self.database[self.collection].find_one({'_id': id_obj}) + document = self.database[self.collection].find_one({"_id": id_obj}) else: - document = self.database[self.collection].find_one({'session_id': id}) + document = self.database[self.collection].find_one({"session_id": id}) if document is None: raise EntityNotFound(f"Entity with (id | session_id) {id} not found") @@ -82,7 +82,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI """ response = super().list(limit, skip, sort_key or "session_id", sort_order, use_typing=use_typing, **kwargs) - result = [Session.from_dict(item) for item in response['result']] if use_typing else response['result'] + result = [Session.from_dict(item) for item in response["result"]] if use_typing else response["result"] return { "count": response["count"], diff --git a/fedn/network/storage/statestore/stores/shared.py b/fedn/network/storage/statestore/stores/shared.py index bf74296af..1ccce636e 100644 --- a/fedn/network/storage/statestore/stores/shared.py +++ b/fedn/network/storage/statestore/stores/shared.py @@ -1,7 +1,7 @@ def from_document(document: dict) -> dict: - document['id'] = str(document['_id']) - del document['_id'] + document["id"] = str(document["_id"]) + del document["_id"] return document diff --git a/fedn/network/storage/statestore/stores/status_store.py b/fedn/network/storage/statestore/stores/status_store.py index 4a62fc9bf..73fa8e588 100644 --- a/fedn/network/storage/statestore/stores/status_store.py +++ b/fedn/network/storage/statestore/stores/status_store.py @@ -31,18 +31,18 @@ def __init__( self.session_id = session_id self.sender = sender - def from_dict(data: dict) -> 'Status': + def from_dict(data: dict) -> "Status": return Status( - id=str(data['_id']), - status=data['status'] if 'status' in data else None, - timestamp=data['timestamp'] if 'timestamp' in data else None, - log_level=data['logLevel'] if 'logLevel' in data else None, - data=data['data'] if 'data' in data else None, - correlation_id=data['correlationId'] if 'correlationId' in data else None, - type=data['type'] if 'type' in data else None, - extra=data['extra'] if 'extra' in data else None, - session_id=data['sessionId'] if 'sessionId' in data else None, - sender=data['sender'] if 'sender' in data else None + id=str(data["_id"]), + status=data["status"] if "status" in data else None, + timestamp=data["timestamp"] if "timestamp" in data else None, + log_level=data["logLevel"] if "logLevel" in data else None, + data=data["data"] if "data" in data else None, + correlation_id=data["correlationId"] if "correlationId" in data else None, + type=data["type"] if "type" in data else None, + extra=data["extra"] if "extra" in data else None, + session_id=data["sessionId"] if "sessionId" in data else None, + sender=data["sender"] if 
"sender" in data else None ) @@ -91,6 +91,6 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI """ response = super().list(limit, skip, sort_key or "timestamp", sort_order, use_typing=use_typing, **kwargs) - result = [Status.from_dict(item) for item in response['result']] if use_typing else response['result'] + result = [Status.from_dict(item) for item in response["result"]] if use_typing else response["result"] - return {'count': response['count'], 'result': result} + return {"count": response["count"], "result": result} diff --git a/fedn/network/storage/statestore/stores/store.py b/fedn/network/storage/statestore/stores/store.py index 72a7de6e4..8334ac6b8 100644 --- a/fedn/network/storage/statestore/stores/store.py +++ b/fedn/network/storage/statestore/stores/store.py @@ -6,7 +6,7 @@ from .shared import EntityNotFound, from_document -T = TypeVar('T') +T = TypeVar("T") class Store(Generic[T]): @@ -23,7 +23,7 @@ def get(self, id: str, use_typing: bool = False) -> T: return: The entity """ id_obj = ObjectId(id) - document = self.database[self.collection].find_one({'_id': id_obj}) + document = self.database[self.collection].find_one({"_id": id_obj}) if document is None: raise EntityNotFound(f"Entity with id {id} not found") diff --git a/fedn/network/storage/statestore/stores/validation_store.py b/fedn/network/storage/statestore/stores/validation_store.py index 4e09072b1..a64d1a41e 100644 --- a/fedn/network/storage/statestore/stores/validation_store.py +++ b/fedn/network/storage/statestore/stores/validation_store.py @@ -29,17 +29,17 @@ def __init__( self.sender = sender self.receiver = receiver - def from_dict(data: dict) -> 'Validation': + def from_dict(data: dict) -> "Validation": return Validation( - id=str(data['_id']), - model_id=data['modelId'] if 'modelId' in data else None, - data=data['data'] if 'data' in data else None, - correlation_id=data['correlationId'] if 'correlationId' in data else None, - timestamp=data['timestamp'] if 'timestamp' in data else None, - session_id=data['sessionId'] if 'sessionId' in data else None, - meta=data['meta'] if 'meta' in data else None, - sender=data['sender'] if 'sender' in data else None, - receiver=data['receiver'] if 'receiver' in data else None + id=str(data["_id"]), + model_id=data["modelId"] if "modelId" in data else None, + data=data["data"] if "data" in data else None, + correlation_id=data["correlationId"] if "correlationId" in data else None, + timestamp=data["timestamp"] if "timestamp" in data else None, + session_id=data["sessionId"] if "sessionId" in data else None, + meta=data["meta"] if "meta" in data else None, + sender=data["sender"] if "sender" in data else None, + receiver=data["receiver"] if "receiver" in data else None ) @@ -89,8 +89,8 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI """ response = super().list(limit, skip, sort_key or "timestamp", sort_order, use_typing=use_typing, **kwargs) - result = [Validation.from_dict(item) for item in response['result']] if use_typing else response['result'] + result = [Validation.from_dict(item) for item in response["result"]] if use_typing else response["result"] return { - "count": response['count'], + "count": response["count"], "result": result } diff --git a/fedn/utils/dispatcher.py b/fedn/utils/dispatcher.py index 5d00021b1..d551b8053 100644 --- a/fedn/utils/dispatcher.py +++ b/fedn/utils/dispatcher.py @@ -1,5 +1,4 @@ -""" -Portions of this code are derived from the Apache 2.0 licensed project mlflow 
(https://mlflow.org/)., +"""Portions of this code are derived from the Apache 2.0 licensed project mlflow (https://mlflow.org/), with modifications made by Scaleout Systems AB. Copyright (c) 2018 Databricks, Inc. @@ -60,15 +59,13 @@ def _install_python(version, pyenv_root=None, capture_output=False): def _is_virtualenv_available(): - """ - Returns True if virtualenv is available, otherwise False. + """Returns True if virtualenv is available, otherwise False. """ return shutil.which("virtualenv") is not None def _validate_virtualenv_is_available(): - """ - Validates virtualenv is available. If not, throws an `Exception` with a brief instruction + """Validates virtualenv is available. If not, throws an `Exception` with a brief instruction on how to install virtualenv. """ if not _is_virtualenv_available(): @@ -85,8 +82,7 @@ def _get_virtualenv_extra_env_vars(env_root_dir=None): def _get_python_env(python_env_file): - """ - Parses a python environment file and returns a dictionary with the parsed content. + """Parses a python environment file and returns a dictionary with the parsed content. """ if os.path.exists(python_env_file): return _PythonEnv.from_yaml(python_env_file) diff --git a/fedn/utils/environment.py b/fedn/utils/environment.py index 754cb5312..03d93eae7 100644 --- a/fedn/utils/environment.py +++ b/fedn/utils/environment.py @@ -1,5 +1,4 @@ -""" -Portions of this code are derived from the Apache 2.0 licensed project mlflow (https://mlflow.org/)., +"""Portions of this code are derived from the Apache 2.0 licensed project mlflow (https://mlflow.org/), with modifications made by Scaleout Systems AB. Copyright (c) 2018 Databricks, Inc. @@ -28,8 +27,7 @@ class _PythonEnv: BUILD_PACKAGES = ("pip", "setuptools", "wheel") def __init__(self, name=None, python=None, build_dependencies=None, dependencies=None): - """ - Represents environment information for FEDn compute packages. + """Represents environment information for FEDn compute packages. Args: ---- diff --git a/fedn/utils/helpers/helperbase.py b/fedn/utils/helpers/helperbase.py index 3377d0336..109ab2a47 100644 --- a/fedn/utils/helpers/helperbase.py +++ b/fedn/utils/helpers/helperbase.py @@ -8,7 +8,6 @@ class HelperBase(ABC): def __init__(self): """Initialize helper.""" - self.name = self.__class__.__name__ @abstractmethod diff --git a/fedn/utils/helpers/plugins/numpyhelper.py b/fedn/utils/helpers/plugins/numpyhelper.py index ce6c29420..822ce929e 100644 --- a/fedn/utils/helpers/plugins/numpyhelper.py +++ b/fedn/utils/helpers/plugins/numpyhelper.py @@ -25,7 +25,6 @@ def increment_average(self, m1, m2, n, N): :return: Updated incremental weighted average. :rtype: list of numpy ndarray """ - return [np.add(x, n * (y - x) / N) for x, y in zip(m1, m2)] def add(self, m1, m2, a=1.0, b=1.0): @@ -38,7 +37,6 @@ def add(self, m1, m2, a=1.0, b=1.0): :return: Weighted sum of model weights. :rtype: list of ndarrays """ - return [x * a + y * b for x, y in zip(m1, m2)] def subtract(self, m1, m2, a=1.0, b=1.0): @@ -63,7 +61,6 @@ def divide(self, m1, m2): :return: m1/m2. 
:rtype: list of ndarrays """ - return [np.divide(x, y) for x, y in zip(m1, m2)] def multiply(self, m1, m2): @@ -76,7 +73,6 @@ def multiply(self, m1, m2): :return: m1.*m2 :rtype: list of ndarrays """ - return [np.multiply(x, y) for (x, y) in zip(m1, m2)] def sqrt(self, m1): @@ -89,7 +85,6 @@ def sqrt(self, m1): :return: sqrt(m1) :rtype: list of ndarrays """ - return [np.sqrt(x) for x in m1] def power(self, m1, a): @@ -102,7 +97,6 @@ def power(self, m1, a): :return: m1.^a :rtype: list of ndarrays """ - return [np.power(x, a) for x in m1] def norm(self, m): @@ -126,7 +120,6 @@ def sign(self, m): :return: sign(m) :rtype: list of ndarrays """ - return [np.sign(x) for x in m] def ones(self, m1, a): @@ -139,7 +132,6 @@ def ones(self, m1, a): :return: list of numpy arrays of the same shape as m1, filled with ones. :rtype: list of ndarrays """ - res = [] for x in m1: res.append(np.ones(np.shape(x)) * a) diff --git a/fedn/utils/plots.py b/fedn/utils/plots.py index 7901e2374..d04fffc4e 100644 --- a/fedn/utils/plots.py +++ b/fedn/utils/plots.py @@ -32,7 +32,6 @@ def __init__(self, statestore): # plot metrics from DB def _scalar_metrics(self, metrics): """Extract all scalar valued metrics from a MODEL_VALIDATION.""" - data = json.loads(metrics["data"]) data = json.loads(data["data"]) @@ -48,9 +47,7 @@ return valid_metrics def create_table_plot(self): - """ - - :return: + """:return: """ metrics = self.status.find_one({"type": "MODEL_VALIDATION"}) if metrics is None: @@ -111,9 +108,7 @@ return table def create_timeline_plot(self): - """ - - :return: + """:return: """ trace_data = [] x = [] @@ -184,9 +179,7 @@ return timeline def create_client_training_distribution(self): - """ - - :return: + """:return: """ training = [] for p in self.status.find({"type": "MODEL_UPDATE"}): @@ -202,9 +195,7 @@ return histogram def create_client_histogram_plot(self): - """ - - :return: + """:return: """ training = [] for p in self.status.find({"type": "MODEL_UPDATE"}): @@ -230,9 +221,7 @@ return histogram_plot def create_client_plot(self): - """ - - :return: + """:return: """ processing = [] upload = [] @@ -258,9 +247,7 @@ return client_plot def create_combiner_plot(self): - """ - - :return: + """:return: """ waiting = [] aggregation = [] @@ -292,18 +279,14 @@ return combiner_plot def fetch_valid_metrics(self): - """ - - :return: + """:return: """ metrics = self.status.find_one({"type": "MODEL_VALIDATION"}) valid_metrics = self._scalar_metrics(metrics) return valid_metrics def create_box_plot(self, metric): - """ - - :param metric: + """:param metric: :return: """ metrics = self.status.find_one({"type": "MODEL_VALIDATION"}) @@ -361,9 +344,7 @@ return box def create_round_plot(self): - """ - - :return: + """:return: """ trace_data = [] metrics = self.round_time.find_one({"key": "round_time"}) @@ -391,9 +372,7 @@ return round_t def create_cpu_plot(self): - """ - - :return: + """:return: """ metrics = self.psutil_usage.find_one({"key": "cpu_mem_usage"}) if metrics is None: diff --git a/fedn/utils/process.py b/fedn/utils/process.py index 1a30fca2c..c2574a760 100644 --- a/fedn/utils/process.py +++ b/fedn/utils/process.py @@ -1,5 +1,4 @@ -""" -Portions of this code are derived from the Apache 2.0 licensed project mlflow 
(https://mlflow.org/)., +"""Portions of this code are derived from the Apache 2.0 licensed project mlflow (https://mlflow.org/), with modifications made by Scaleout Systems AB. Copyright (c) 2018 Databricks, Inc. diff --git a/pyproject.toml b/pyproject.toml index 24233b9f0..7e3484a8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,21 +177,8 @@ lint.ignore = [ "PLW1508", # Invalid type for environment variable default; expected `str` or `None` "B007", # Loop control variable `v` not used within loop body "N806", # Variable `X_test` in function should be lowercase - - # solved with --fix - "Q000", # [*] Single quotes found but double quotes preferred - "D212", # [*] Multi-line docstring summary should start at the first line - "D213", # [*] Multi-line docstring summary should start at the second line - "D202", # [*] No blank lines allowed after function docstring (found 1) - "D209", # [*] Multi-line docstring closing quotes should be on a separate line - "D204", # [*] 1 blank line required after class docstring - "SIM114", # [*] Combine `if` branches using logical `or` operator - "D208", # [*] Docstring is over-indented - "I001", # [*] Import block is un-sorted or un-formatted "SIM103", # Return the condition directly - "PLR5501", # [*] Use `elif` instead of `else` then `if`, to reduce indentation - "RET501", # [*] Do not explicitly `return None` in function if it is the only possible return value - "PLW0120", # [*] `else` clause on loop without a `break` statement; remove the `else` and dedent its contents + "I001", # [*] Import block is un-sorted or un-formatted # unsafe? "S104", # Possible binding to all interfaces From c6bc269332808294146568827bbaae55005e5783 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Frankem=C3=B6lle?= <48800769+FrankJonasmoelle@users.noreply.github.com> Date: Mon, 13 May 2024 14:41:12 +0200 Subject: [PATCH 5/5] Feature/SK-831 | Self-supervised Learning Example (#603) --- examples/FedSimSiam/.dockerignore | 4 + examples/FedSimSiam/.gitignore | 6 + examples/FedSimSiam/README.rst | 125 +++++++++++++++ examples/FedSimSiam/client/data.py | 150 ++++++++++++++++++ examples/FedSimSiam/client/fedn.yaml | 10 ++ examples/FedSimSiam/client/model.py | 144 +++++++++++++++++ examples/FedSimSiam/client/monitoring.py | 62 ++++++++ examples/FedSimSiam/client/python_env.yaml | 9 ++ examples/FedSimSiam/client/train.py | 129 +++++++++++++++ examples/FedSimSiam/client/utils.py | 78 +++++++++ examples/FedSimSiam/client/validate.py | 63 ++++++++ .../FedSimSiam/docker-compose.override.yaml | 35 ++++ 12 files changed, 815 insertions(+) create mode 100644 examples/FedSimSiam/.dockerignore create mode 100644 examples/FedSimSiam/.gitignore create mode 100644 examples/FedSimSiam/README.rst create mode 100644 examples/FedSimSiam/client/data.py create mode 100644 examples/FedSimSiam/client/fedn.yaml create mode 100644 examples/FedSimSiam/client/model.py create mode 100644 examples/FedSimSiam/client/monitoring.py create mode 100644 examples/FedSimSiam/client/python_env.yaml create mode 100644 examples/FedSimSiam/client/train.py create mode 100644 examples/FedSimSiam/client/utils.py create mode 100644 examples/FedSimSiam/client/validate.py create mode 100644 examples/FedSimSiam/docker-compose.override.yaml diff --git a/examples/FedSimSiam/.dockerignore b/examples/FedSimSiam/.dockerignore new file mode 100644 index 000000000..8ba9024ad --- /dev/null +++ b/examples/FedSimSiam/.dockerignore @@ -0,0 +1,4 @@ +data +seed.npz +*.tgz +*.tar.gz \ No newline at end of file diff --git 
a/examples/FedSimSiam/.gitignore b/examples/FedSimSiam/.gitignore new file mode 100644 index 000000000..047341d71 --- /dev/null +++ b/examples/FedSimSiam/.gitignore @@ -0,0 +1,6 @@ +data +*.npz +*.tgz +*.tar.gz +.fedsimsiam +client.yaml \ No newline at end of file diff --git a/examples/FedSimSiam/README.rst b/examples/FedSimSiam/README.rst new file mode 100644 index 000000000..54434c6dc --- /dev/null +++ b/examples/FedSimSiam/README.rst @@ -0,0 +1,125 @@ +FEDn Project: FedSimSiam on CIFAR-10 +------------------------------------ + +This is an example FEDn Project that runs the federated self-supervised learning algorithm FedSimSiam on +the CIFAR-10 dataset. This is a standard example often used for benchmarking. To be able to run this example, you +need access to a GPU. + + **Note: We recommend all new users to start by following the Quickstart Tutorial: https://fedn.readthedocs.io/en/stable/quickstart.html** +
+Prerequisites +------------- +
+- `Python 3.8, 3.9, 3.10 or 3.11 `__ +- `A FEDn Studio account `__ +- Change the dependencies in the 'client/python_env.yaml' file to match your CUDA version. +
+Creating the compute package and seed model +------------------------------------------- +
+Install fedn: +
+.. code-block:: +
+   pip install fedn
+
+Clone this repository, then navigate to this directory: +
+.. code-block:: +
+   git clone https://github.com/scaleoutsystems/fedn.git
+   cd fedn/examples/FedSimSiam
+
+Create the compute package: +
+.. code-block:: +
+   fedn package create --path client
+
+This should create a file 'package.tgz' in the project folder. +
+Next, generate a seed model (the first model in a global model trail): +
+.. code-block:: +
+   fedn run build --path client
+
+This will create a seed model called 'seed.npz' in the root of the project. This step will take a few minutes, depending on hardware and internet connection (it builds a virtualenv). +
+Using FEDn Studio +----------------- +
+Follow the instructions to register for FEDn Studio and start a project (https://fedn.readthedocs.io/en/stable/studio.html). +
+In your Studio project: +
+- Go to the 'Sessions' menu, click on 'New session', and upload the compute package (package.tgz) and seed model (seed.npz). +- In the 'Clients' menu, click on 'Connect client' and download the client configuration file (client.yaml) +- Save the client configuration file to the FedSimSiam example directory (fedn/examples/FedSimSiam) +
+To connect a client, run the following command in your terminal: +
+.. code-block:: +
+   fedn client start -in client.yaml --secure=True --force-ssl
+
+
+Running the example +------------------- +
+After everything is set up, go to 'Sessions' and click on 'New Session'. Click on 'Start run' and the example will execute. You can follow the training progress under 'Events' and 'Models'. The monitoring is done using a kNN classifier that is fitted on the feature embeddings of the training images that are obtained by +FedSimSiam's encoder, and evaluated on the feature embeddings of the test images. This process is repeated after each training round. +
+This is a common method to track FedSimSiam's training progress, as FedSimSiam aims to minimize the distance between the embeddings of similar images. +A high accuracy implies that the feature embeddings for images within the same class are indeed close to each other in the +embedding space, i.e., FedSimSiam has learned useful feature embeddings. 
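+
+For illustration, the monitoring step boils down to the minimal sketch below, a
+simplified excerpt of what 'client/validate.py' does with the helpers in
+'client/monitoring.py' and 'client/data.py' (the two file paths are placeholders
+for the paths FEDn passes to the validate entry point):
+
+.. code-block:: python
+
+   from data import load_knn_monitoring_dataset
+   from model import load_parameters
+   from monitoring import knn_monitor
+
+   # Build a feature bank from the embedded memory (training) set, then
+   # classify the embedded test set by weighted kNN lookup; rising accuracy
+   # over rounds means the encoder maps same-class images close together.
+   model = load_parameters("model_update.npz")
+   memory_loader, test_loader = load_knn_monitoring_dataset("data/clients/1/cifar10.pt")
+   knn_accuracy = knn_monitor(model.encoder, memory_loader, test_loader, k=25)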
+ + +Running FEDn in local development mode +--------------------------------------- + +Follow the steps above to install FEDn, generate 'package.tgz' and 'seed.npz'. + +Start a pseudo-distributed FEDn network using docker-compose: + +.. code-block:: + + docker compose \ + -f ../../docker-compose.yaml \ + -f docker-compose.override.yaml \ + up + +This starts up local services for MongoDB, Minio, the API Server, one Combiner and two clients. +You can verify the deployment using these URLs: + +- API Server: http://localhost:8092/get_controller_status +- Minio: http://localhost:9000 +- Mongo Express: http://localhost:8081 + +Upload the package and seed model to the FEDn controller using the APIClient: + +.. code-block:: + + from fedn import APIClient + client = APIClient(host="localhost", port=8092) + client.set_active_package("package.tgz", helper="numpyhelper") + client.set_active_model("seed.npz") + + +You can now start a training session with 100 rounds using the API client: + +.. code-block:: + + client.start_session(rounds=100) + +Clean up +-------- + +You can clean up by running + +.. code-block:: + + docker-compose \ + -f ../../docker-compose.yaml \ + -f docker-compose.override.yaml \ + down -v diff --git a/examples/FedSimSiam/client/data.py b/examples/FedSimSiam/client/data.py new file mode 100644 index 000000000..95b10e7db --- /dev/null +++ b/examples/FedSimSiam/client/data.py @@ -0,0 +1,150 @@ +import os +from math import floor + +import numpy as np +import torch +import torchvision +from torchvision import transforms + +dir_path = os.path.dirname(os.path.realpath(__file__)) +abs_path = os.path.abspath(dir_path) + + +def get_data(out_dir="data"): + # Make dir if necessary + if not os.path.exists(out_dir): + os.mkdir(out_dir) + + # Only download if not already downloaded + if not os.path.exists(f"{out_dir}/train"): + torchvision.datasets.CIFAR10( + root=f"{out_dir}/train", train=True, download=True) + + if not os.path.exists(f"{out_dir}/test"): + torchvision.datasets.CIFAR10( + root=f"{out_dir}/test", train=False, download=True) + + +def load_data(data_path, is_train=True): + """ Load data from disk. + + :param data_path: Path to data file. + :type data_path: str + :param is_train: Whether to load training or test data. + :type is_train: bool + :return: Tuple of data and labels. 
+ :rtype: tuple + """ + if data_path is None: + data_path = os.environ.get( + "FEDN_DATA_PATH", abs_path+"/data/clients/1/cifar10.pt") + + data = torch.load(data_path) + + if is_train: + X = data["x_train"] + y = data["y_train"] + else: + X = data["x_test"] + y = data["y_test"] + + return X, y + + +def create_knn_monitoring_dataset(out_dir="data"): + """ Creates dataset that is used to monitor the training progress via knn accuracies """ + if not os.path.exists(out_dir): + os.mkdir(out_dir) + + n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2)) + + # Make dir + if not os.path.exists(f"{out_dir}/clients"): + os.mkdir(f"{out_dir}/clients") + + normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], + std=[0.247, 0.243, 0.261]) + + memoryset = torchvision.datasets.CIFAR10(root="./data", train=True, + download=True, transform=transforms.Compose([transforms.ToTensor(), normalize])) + testset = torchvision.datasets.CIFAR10(root="./data", train=False, + download=True, transform=transforms.Compose([transforms.ToTensor(), normalize])) + + # save monitoring datasets to all clients + for i in range(n_splits): + subdir = f"{out_dir}/clients/{str(i+1)}" + if not os.path.exists(subdir): + os.mkdir(subdir) + torch.save(memoryset, f"{subdir}/knn_memoryset.pt") + torch.save(testset, f"{subdir}/knn_testset.pt") + + +def load_knn_monitoring_dataset(data_path, batch_size=16): + """ Loads the KNN monitoring dataset.""" + if data_path is None: + data_path = os.environ.get( + "FEDN_DATA_PATH", abs_path+"/data/clients/1/cifar10.pt") + + data_directory = os.path.dirname(data_path) + memory_path = os.path.join(data_directory, "knn_memoryset.pt") + testset_path = os.path.join(data_directory, "knn_testset.pt") + + memoryset = torch.load(memory_path) + testset = torch.load(testset_path) + + memoryset_loader = torch.utils.data.DataLoader( + memoryset, batch_size=batch_size, shuffle=False) + testset_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, + shuffle=False) + return memoryset_loader, testset_loader + + +def splitset(dataset, parts): + n = dataset.shape[0] + local_n = floor(n/parts) + result = [] + for i in range(parts): + result.append(dataset[i*local_n: (i+1)*local_n]) + return result + + +def split(out_dir="data"): + + n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2)) + + # Make dir + if not os.path.exists(f"{out_dir}/clients"): + os.mkdir(f"{out_dir}/clients") + + train_data = torchvision.datasets.CIFAR10( + root=f"{out_dir}/train", train=True) + test_data = torchvision.datasets.CIFAR10( + root=f"{out_dir}/test", train=False) + + data = { + "x_train": splitset(train_data.data, n_splits), + "y_train": splitset(np.array(train_data.targets), n_splits), + "x_test": splitset(test_data.data, n_splits), + "y_test": splitset(np.array(test_data.targets), n_splits), + } + + # Make splits + for i in range(n_splits): + subdir = f"{out_dir}/clients/{str(i+1)}" + if not os.path.exists(subdir): + os.mkdir(subdir) + torch.save({ + "x_train": data["x_train"][i], + "y_train": data["y_train"][i], + "x_test": data["x_test"][i], + "y_test": data["y_test"][i], + }, + f"{subdir}/cifar10.pt") + + +if __name__ == "__main__": + # Prepare data if not already done + if not os.path.exists(abs_path+"/data/clients/1"): + get_data() + split() + create_knn_monitoring_dataset() diff --git a/examples/FedSimSiam/client/fedn.yaml b/examples/FedSimSiam/client/fedn.yaml new file mode 100644 index 000000000..b05504102 --- /dev/null +++ b/examples/FedSimSiam/client/fedn.yaml @@ -0,0 +1,10 @@ +python_env: 
python_env.yaml +entry_points: + build: + command: python model.py + startup: + command: python data.py + train: + command: python train.py + validate: + command: python validate.py \ No newline at end of file diff --git a/examples/FedSimSiam/client/model.py b/examples/FedSimSiam/client/model.py new file mode 100644 index 000000000..d50d15c2e --- /dev/null +++ b/examples/FedSimSiam/client/model.py @@ -0,0 +1,144 @@ +import collections + +import torch +import torch.nn.functional as f +from torch import nn +from torchvision.models import resnet18 + +from fedn.utils.helpers.helpers import get_helper + +HELPER_MODULE = "numpyhelper" +helper = get_helper(HELPER_MODULE) + + +def D(p, z, version="simplified"):  # negative cosine similarity + if version == "original": + z = z.detach()  # stop gradient + p = f.normalize(p, dim=1)  # l2-normalize + z = f.normalize(z, dim=1)  # l2-normalize + return -(p*z).sum(dim=1).mean() + + elif version == "simplified":  # same thing, much faster + return - f.cosine_similarity(p, z.detach(), dim=-1).mean() + else: + raise Exception + + +class ProjectionMLP(nn.Module): + """Projection MLP f""" + + def __init__(self, in_features, h1_features, h2_features, out_features): + super(ProjectionMLP, self).__init__() + self.l1 = nn.Sequential( + nn.Linear(in_features, h1_features), + nn.BatchNorm1d(h1_features), + nn.ReLU(inplace=True) + ) + self.l2 = nn.Sequential( + nn.Linear(h1_features, out_features), + nn.BatchNorm1d(out_features) + ) + + def forward(self, x): + x = self.l1(x) + x = self.l2(x) + return x + + +class PredictionMLP(nn.Module): + """Prediction MLP h""" + + def __init__(self, in_features, hidden_features, out_features): + super(PredictionMLP, self).__init__() + self.l1 = nn.Sequential( + nn.Linear(in_features, hidden_features), + nn.BatchNorm1d(hidden_features), + nn.ReLU(inplace=True) + ) + self.l2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + x = self.l1(x) + x = self.l2(x) + return x + + +class SimSiam(nn.Module): + def __init__(self): + super(SimSiam, self).__init__() + backbone = resnet18(pretrained=False) + backbone.output_dim = backbone.fc.in_features + backbone.fc = torch.nn.Identity() + + self.backbone = backbone + + self.projector = ProjectionMLP(backbone.output_dim, 2048, 2048, 2048) + self.encoder = nn.Sequential( + self.backbone, + self.projector + ) + self.predictor = PredictionMLP(2048, 512, 2048) + + def forward(self, x1, x2): + f, h = self.encoder, self.predictor + z1, z2 = f(x1), f(x2) + p1, p2 = h(z1), h(z2) + L = D(p1, z2) / 2 + D(p2, z1) / 2 + return {"loss": L} + + +def compile_model(): + """ Compile the pytorch model. + + :return: The compiled model. + :rtype: torch.nn.Module + """ + model = SimSiam() + + return model + + +def save_parameters(model, out_path): + """ Save model parameters to file. + + :param model: The model to serialize. + :type model: torch.nn.Module + :param out_path: The path to save to. + :type out_path: str + """ + parameters_np = [val.cpu().numpy() + for _, val in model.state_dict().items()] + helper.save(parameters_np, out_path) + + +def load_parameters(model_path): + """ Load model parameters from file and populate model. + + :param model_path: The path to load from. + :type model_path: str + :return: The loaded model. 
+ :rtype: torch.nn.Module + """ + model = compile_model() + parameters_np = helper.load(model_path) + + params_dict = zip(model.state_dict().keys(), parameters_np) + state_dict = collections.OrderedDict( + {key: torch.tensor(x) for key, x in params_dict}) + model.load_state_dict(state_dict, strict=True) + return model + + +def init_seed(out_path="seed.npz"): + """ Initialize seed model and save it to file. + + :param out_path: The path to save the seed model to. + :type out_path: str + """ + # Init and save + model = compile_model() + save_parameters(model, out_path) + + +if __name__ == "__main__": + init_seed("../seed.npz") diff --git a/examples/FedSimSiam/client/monitoring.py b/examples/FedSimSiam/client/monitoring.py new file mode 100644 index 000000000..245e7f308 --- /dev/null +++ b/examples/FedSimSiam/client/monitoring.py @@ -0,0 +1,62 @@ +""" knn monitor as in InstDisc https://arxiv.org/abs/1805.01978. +This implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR +""" +import torch +import torch.nn.functional as f + + +def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + net = net.to(device) + net.eval() + classes = len(memory_data_loader.dataset.classes) + total_top1, total_num, feature_bank = 0.0, 0, [] + with torch.no_grad(): + # generate feature bank + for data, target in memory_data_loader: + # feature = net(data.cuda(non_blocking=True)) + feature = net(data.to(device)) + feature = f.normalize(feature, dim=1) + feature_bank.append(feature) + # [D, N] + feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() + # [N] + feature_labels = torch.tensor( + memory_data_loader.dataset.targets, device=feature_bank.device) + # loop test data to predict the label by weighted knn search + for data, target in test_data_loader: + data, target = data.to(device), target.to(device) + feature = net(data) + feature = f.normalize(feature, dim=1) + + pred_labels = knn_predict( + feature, feature_bank, feature_labels, classes, k, t) + + total_num += data.size(0) + total_top1 += (pred_labels[:, 0] == target).float().sum().item() + return total_top1 / total_num + + +def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): + # compute cos similarity between each feature vector and feature bank ---> [B, N] + sim_matrix = torch.mm(feature, feature_bank) + # [B, K] + sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) + # [B, K] + sim_labels = torch.gather(feature_labels.expand( + feature.size(0), -1), dim=-1, index=sim_indices) + sim_weight = (sim_weight / knn_t).exp() + + # counts for each class + one_hot_label = torch.zeros(feature.size( + 0) * knn_k, classes, device=sim_labels.device) + # [B*K, C] + one_hot_label = one_hot_label.scatter( + dim=-1, index=sim_labels.view(-1, 1), value=1.0) + # weighted score ---> [B, C] + pred_scores = torch.sum(one_hot_label.view(feature.size( + 0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1) + + pred_labels = pred_scores.argsort(dim=-1, descending=True) + return pred_labels diff --git a/examples/FedSimSiam/client/python_env.yaml b/examples/FedSimSiam/client/python_env.yaml new file mode 100644 index 000000000..49b1ad2ec --- /dev/null +++ b/examples/FedSimSiam/client/python_env.yaml @@ -0,0 +1,9 @@ +name: fedsimsiam +build_dependencies: + - pip + - setuptools + - wheel==0.37.1 +dependencies: + - torch==2.2.0 + - torchvision==0.17.0 + - fedn==0.9.0 \ No 
newline at end of file diff --git a/examples/FedSimSiam/client/train.py b/examples/FedSimSiam/client/train.py new file mode 100644 index 000000000..0e7c565f6 --- /dev/null +++ b/examples/FedSimSiam/client/train.py @@ -0,0 +1,129 @@ +import os +import sys + +import numpy as np +import torch +from data import load_data +from model import load_parameters, save_parameters +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +from utils import init_lrscheduler + +from fedn.utils.helpers.helpers import save_metadata + +dir_path = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(os.path.abspath(dir_path)) + + +class SimSiamDataset(Dataset): + def __init__(self, x, y, is_train=True): + self.x = x + self.y = y + self.is_train = is_train + + def __getitem__(self, idx): + x = self.x[idx] + x = Image.fromarray(x.astype(np.uint8)) + + y = self.y[idx] + + normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], + std=[0.247, 0.243, 0.261]) + augmentation = [ + transforms.RandomResizedCrop(32, scale=(0.2, 1.)), + transforms.RandomApply([ + transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) + ], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize + ] + + if self.is_train: + transform = transforms.Compose(augmentation) + + x1 = transform(x) + x2 = transform(x) + return [x1, x2], y + + else: + transform = transforms.Compose([transforms.ToTensor(), normalize]) + + x = transform(x) + return x, y + + def __len__(self): + return len(self.x) + + +def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01): + """ Complete a model update. + + Load model parameters from in_model_path (managed by the FEDn client), + perform a model update, and write updated parameters + to out_model_path (picked up by the FEDn client). + + :param in_model_path: The path to the input model. + :type in_model_path: str + :param out_model_path: The path to save the output model to. + :type out_model_path: str + :param data_path: The path to the data file. + :type data_path: str + :param batch_size: The batch size to use. + :type batch_size: int + :param epochs: The number of epochs to train. + :type epochs: int + :param lr: The learning rate to use. 
+ :type lr: float + """ + # Load data + x_train, y_train = load_data(data_path) + + # Load parameters and initialize model + model = load_parameters(in_model_path) + + trainset = SimSiamDataset(x_train, y_train, is_train=True) + trainloader = DataLoader( + trainset, batch_size=batch_size, shuffle=True) + + device = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + model.train() + + optimizer, lr_scheduler = init_lrscheduler( + model, 500, trainloader) + + for epoch in range(epochs): + for idx, data in enumerate(trainloader): + images = data[0] + optimizer.zero_grad() + data_dict = model.forward(images[0].to( + device, non_blocking=True), images[1].to(device, non_blocking=True)) + loss = data_dict["loss"].mean() + print(loss) + loss.backward() + optimizer.step() + lr_scheduler.step() + + # Metadata needed for aggregation server side + metadata = { + # num_examples are mandatory + "num_examples": len(x_train), + "batch_size": batch_size, + "epochs": epochs, + "lr": lr + } + + # Save JSON metadata file (mandatory) + save_metadata(metadata, out_model_path) + + # Save model update (mandatory) + save_parameters(model, out_model_path) + + +if __name__ == "__main__": + train(sys.argv[1], sys.argv[2]) diff --git a/examples/FedSimSiam/client/utils.py b/examples/FedSimSiam/client/utils.py new file mode 100644 index 000000000..b10e0f06d --- /dev/null +++ b/examples/FedSimSiam/client/utils.py @@ -0,0 +1,78 @@ +import numpy as np +import torch + + +class LrScheduler(object): + def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False): + self.base_lr = base_lr + self.constant_predictor_lr = constant_predictor_lr + warmup_iter = iter_per_epoch * warmup_epochs + warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter) + decay_iter = iter_per_epoch * (num_epochs - warmup_epochs) + cosine_lr_schedule = final_lr+0.5 * \ + (base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter)) + + self.lr_schedule = np.concatenate( + (warmup_lr_schedule, cosine_lr_schedule)) + self.optimizer = optimizer + self.iter = 0 + self.current_lr = 0 + + def step(self): + for param_group in self.optimizer.param_groups: + + if self.constant_predictor_lr and param_group["name"] == "predictor": + param_group["lr"] = self.base_lr + else: + lr = param_group["lr"] = self.lr_schedule[self.iter] + + self.iter += 1 + self.current_lr = lr + return lr + + def get_lr(self): + return self.current_lr + + +def get_optimizer(name, model, lr, momentum, weight_decay): + + predictor_prefix = ("module.predictor", "predictor") + parameters = [{ + "name": "base", + "params": [param for name, param in model.named_parameters() if not name.startswith(predictor_prefix)], + "lr": lr + }, { + "name": "predictor", + "params": [param for name, param in model.named_parameters() if name.startswith(predictor_prefix)], + "lr": lr + }] + + if name == "sgd": + optimizer = torch.optim.SGD( + parameters, lr=lr, momentum=momentum, weight_decay=weight_decay) + + return optimizer + + +def init_lrscheduler(model, total_epochs, dataloader): + warmup_epochs = 10 + warmup_lr = 0 + base_lr = 0.03 + final_lr = 0 + momentum = 0.9 + weight_decay = 0.0005 + batch_size = 64 + + optimizer = get_optimizer( + "sgd", model, + lr=base_lr*batch_size/256, + momentum=momentum, + weight_decay=weight_decay) + + lr_scheduler = LrScheduler( + optimizer, warmup_epochs, warmup_lr*batch_size/256, + total_epochs, base_lr*batch_size/256, 
final_lr*batch_size/256, + len(dataloader), + constant_predictor_lr=True + ) + return optimizer, lr_scheduler diff --git a/examples/FedSimSiam/client/validate.py b/examples/FedSimSiam/client/validate.py new file mode 100644 index 000000000..5e6c5ac53 --- /dev/null +++ b/examples/FedSimSiam/client/validate.py @@ -0,0 +1,63 @@ +import os +import sys + +import numpy as np +import torch +from data import load_knn_monitoring_dataset +from model import load_parameters +from monitoring import knn_monitor +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +from fedn.utils.helpers.helpers import save_metrics + +dir_path = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(os.path.abspath(dir_path)) + + +class Cifar10(Dataset): + def __init__(self, x, y): + self.x = x + self.y = y + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], # Approx. CIFAR-10 means + std=[0.247, 0.243, 0.261]) # Approx. CIFAR-10 std deviations + ]) + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + x = self.x[idx] + x = Image.fromarray(x.astype(np.uint8)) + x = self.transform(x) + y = self.y[idx] + return x, y + + +def validate(in_model_path, out_json_path, data_path=None): + + memory_loader, test_loader = load_knn_monitoring_dataset(data_path) + + model = load_parameters(in_model_path) + device = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + + knn_accuracy = knn_monitor(model.encoder, memory_loader, test_loader, device, k=min( + 25, len(memory_loader.dataset))) + + print("knn accuracy: ", knn_accuracy) + + # JSON schema + report = { + "knn_accuracy": knn_accuracy, + } + + # Save JSON + save_metrics(report, out_json_path) + + +if __name__ == "__main__": + validate(sys.argv[1], sys.argv[2]) diff --git a/examples/FedSimSiam/docker-compose.override.yaml b/examples/FedSimSiam/docker-compose.override.yaml new file mode 100644 index 000000000..524e39d1d --- /dev/null +++ b/examples/FedSimSiam/docker-compose.override.yaml @@ -0,0 +1,35 @@ +# Compose schema version +version: '3.4' + +# Overriding requirements + +x-env: &defaults + GET_HOSTS_FROM: dns + FEDN_PACKAGE_EXTRACT_DIR: package + FEDN_NUM_DATA_SPLITS: 2 + +services: + + client1: + extends: + file: ${HOST_REPO_DIR:-.}/docker-compose.yaml + service: client + environment: + <<: *defaults + FEDN_DATA_PATH: /app/package/client/data/clients/1/cifar10.pt + deploy: + replicas: 1 + volumes: + - ${HOST_REPO_DIR:-.}/fedn:/app/fedn + + client2: + extends: + file: ${HOST_REPO_DIR:-.}/docker-compose.yaml + service: client + environment: + <<: *defaults + FEDN_DATA_PATH: /app/package/client/data/clients/2/cifar10.pt + deploy: + replicas: 1 + volumes: + - ${HOST_REPO_DIR:-.}/fedn:/app/fedn