diff --git a/examples/async-clients/client/entrypoint.py b/examples/async-clients/client/entrypoint.py
index 4ddddd956..220b5299b 100644
--- a/examples/async-clients/client/entrypoint.py
+++ b/examples/async-clients/client/entrypoint.py
@@ -8,7 +8,7 @@
 from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics

-HELPER_MODULE = 'numpyhelper'
+HELPER_MODULE = "numpyhelper"

 ARRAY_SIZE = 10000

@@ -22,7 +22,7 @@ def compile_model(max_iter=1):


 def save_parameters(model, out_path):
-    """ Save model to disk.
+    """Save model to disk.

     :param model: The model to save.
     :type model: torch.nn.Module
@@ -36,7 +36,7 @@ def save_parameters(model, out_path):


 def load_parameters(model_path):
-    """ Load model from disk.
+    """Load model from disk.

     param model_path: The path to load from.
     :type model_path: str
@@ -49,8 +49,8 @@ def load_parameters(model_path):
     return parameters


-def init_seed(out_path='seed.npz'):
-    """ Initialize seed model.
+def init_seed(out_path="seed.npz"):
+    """Initialize seed model.

     :param out_path: The path to save the seed model to.
     :type out_path: str
@@ -61,7 +61,7 @@


 def make_data(n_min=50, n_max=100):
-    """ Generate / simulate a random number n data points.
+    """Generate / simulate a random number n data points.

     n will fall in the interval (n_min, n_max)

@@ -78,14 +78,12 @@ def make_data(n_min=50, n_max=100):


 def train(in_model_path, out_model_path):
-    """ Train model.
-
-    """
+    """Train model."""
     # Load model
     parameters = load_parameters(in_model_path)
     model = compile_model()
-    n = len(parameters)//2
+    n = len(parameters) // 2
     model.coefs_ = parameters[:n]
     model.intercepts_ = parameters[n:]

@@ -97,7 +95,7 @@ def train(in_model_path, out_model_path):

     # Metadata needed for aggregation server side
     metadata = {
-        'num_examples': len(X_train),
+        "num_examples": len(X_train),
     }

     # Save JSON metadata file
@@ -108,7 +106,7 @@ def train(in_model_path, out_model_path):


 def validate(in_model_path, out_json_path):
-    """ Validate model.
+    """Validate model.

     :param in_model_path: The path to the input model.
     :type in_model_path: str
@@ -119,7 +117,7 @@ def validate(in_model_path, out_json_path):
     """
     parameters = load_parameters(in_model_path)
     model = compile_model()
-    n = len(parameters)//2
+    n = len(parameters) // 2
     model.coefs_ = parameters[:n]
     model.intercepts_ = parameters[n:]

@@ -134,9 +132,5 @@ def validate(in_model_path, out_json_path):
     save_metrics(report, out_json_path)


-if __name__ == '__main__':
-    fire.Fire({
-        'init_seed': init_seed,
-        'train': train,
-        'validate': validate
-    })
+if __name__ == "__main__":
+    fire.Fire({"init_seed": init_seed, "train": train, "validate": validate})
diff --git a/examples/async-clients/init_fedn.py b/examples/async-clients/init_fedn.py
index 2aa298602..8677c472d 100644
--- a/examples/async-clients/init_fedn.py
+++ b/examples/async-clients/init_fedn.py
@@ -1,8 +1,8 @@
 from fedn import APIClient

-DISCOVER_HOST = '127.0.0.1'
+DISCOVER_HOST = "127.0.0.1"
 DISCOVER_PORT = 8092

 client = APIClient(DISCOVER_HOST, DISCOVER_PORT)
-client.set_active_package('package.tgz', 'numpyhelper')
-client.set_active_model('seed.npz')
+client.set_active_package("package.tgz", "numpyhelper")
+client.set_active_model("seed.npz")
diff --git a/examples/async-clients/run_clients.py b/examples/async-clients/run_clients.py
index 5b56c53c9..82da30ad9 100644
--- a/examples/async-clients/run_clients.py
+++ b/examples/async-clients/run_clients.py
@@ -26,24 +26,39 @@

 # Use with a local deployment
 settings = {
-    'DISCOVER_HOST': '127.0.0.1',
-    'DISCOVER_PORT': 8092,
-    'TOKEN': None,
-    'N_CLIENTS': 10,
-    'N_CYCLES': 100,
-    'CLIENTS_MAX_DELAY': 10,
-    'CLIENTS_ONLINE_FOR_SECONDS': 120
+    "DISCOVER_HOST": "127.0.0.1",
+    "DISCOVER_PORT": 8092,
+    "TOKEN": None,
+    "N_CLIENTS": 10,
+    "N_CYCLES": 100,
+    "CLIENTS_MAX_DELAY": 10,
+    "CLIENTS_ONLINE_FOR_SECONDS": 120,
 }

-client_config = {'discover_host': settings['DISCOVER_HOST'], 'discover_port': settings['DISCOVER_PORT'], 'token': settings['TOKEN'], 'name': 'testclient',
-                 'client_id': 1, 'remote_compute_context': True, 'force_ssl': False, 'dry_run': False, 'secure': False,
-                 'preshared_cert': False, 'verify': False, 'preferred_combiner': False,
-                 'validator': True, 'trainer': True, 'init': None, 'logfile': 'test.log', 'heartbeat_interval': 2,
-                 'reconnect_after_missed_heartbeat': 30}
+client_config = {
+    "discover_host": settings["DISCOVER_HOST"],
+    "discover_port": settings["DISCOVER_PORT"],
+    "token": settings["TOKEN"],
+    "name": "testclient",
+    "client_id": 1,
+    "remote_compute_context": True,
+    "force_ssl": False,
+    "dry_run": False,
+    "secure": False,
+    "preshared_cert": False,
+    "verify": False,
+    "preferred_combiner": False,
+    "validator": True,
+    "trainer": True,
+    "init": None,
+    "logfile": "test.log",
+    "heartbeat_interval": 2,
+    "reconnect_after_missed_heartbeat": 30,
+}


-def run_client(online_for=120, name='client'):
-    """ Simulates a client that starts and stops
+def run_client(online_for=120, name="client"):
+    """Simulates a client that starts and stops
     at random intervals.
     The client will start after a radom time 'mean_delay',
@@ -55,23 +70,28 @@ def run_client(online_for=120, name='client'):
     """
     conf = copy.deepcopy(client_config)
-    conf['name'] = name
+    conf["name"] = name

-    for i in range(settings['N_CYCLES']):
+    for i in range(settings["N_CYCLES"]):
         # Sample a delay until the client starts
-        t_start = np.random.randint(0, settings['CLIENTS_MAX_DELAY'])
+        t_start = np.random.randint(0, settings["CLIENTS_MAX_DELAY"])
         time.sleep(t_start)
         fl_client = Client(conf)
         time.sleep(online_for)
         fl_client.disconnect()


-if __name__ == '__main__':
-
+if __name__ == "__main__":
     # We start N_CLIENTS independent client processes
     processes = []
-    for i in range(settings['N_CLIENTS']):
-        p = Process(target=run_client, args=(settings['CLIENTS_ONLINE_FOR_SECONDS'], 'client{}'.format(i),))
+    for i in range(settings["N_CLIENTS"]):
+        p = Process(
+            target=run_client,
+            args=(
+                settings["CLIENTS_ONLINE_FOR_SECONDS"],
+                "client{}".format(i),
+            ),
+        )
         processes.append(p)
         p.start()
diff --git a/examples/async-clients/run_experiment.py b/examples/async-clients/run_experiment.py
index d8d12dca2..f6d5e1f6d 100644
--- a/examples/async-clients/run_experiment.py
+++ b/examples/async-clients/run_experiment.py
@@ -3,16 +3,14 @@

 from fedn import APIClient

-DISCOVER_HOST = '127.0.0.1'
+DISCOVER_HOST = "127.0.0.1"
 DISCOVER_PORT = 8092
 client = APIClient(DISCOVER_HOST, DISCOVER_PORT)

-if __name__ == '__main__':
-
+if __name__ == "__main__":
     # Run six sessions, each with 100 rounds.
     num_sessions = 6
     for s in range(num_sessions):
-
         session_config = {
             "helper": "numpyhelper",
             "id": str(uuid.uuid4()),
@@ -23,12 +21,12 @@
         }

         session = client.start_session(**session_config)
-        if session['success'] is False:
-            print(session['message'])
+        if session["success"] is False:
+            print(session["message"])
             exit(0)

         print("Started session: {}".format(session))

         # Wait for session to finish
-        while not client.session_is_finished(session_config['id']):
+        while not client.session_is_finished(session_config["id"]):
             time.sleep(2)
diff --git a/examples/flower-client/client/entrypoint.py b/examples/flower-client/client/entrypoint.py
index e80e7d4b5..1a9a8b8cf 100755
--- a/examples/flower-client/client/entrypoint.py
+++ b/examples/flower-client/client/entrypoint.py
@@ -56,9 +56,7 @@ def train(in_model_path, out_model_path):
     parameters_np = helper.load(in_model_path)

     # Train on flower client
-    params, num_examples = flwr_adapter.train(
-        parameters=parameters_np, partition_id=_get_node_id(), config={}
-    )
+    params, num_examples = flwr_adapter.train(parameters=parameters_np, partition_id=_get_node_id(), config={})

     # Metadata needed for aggregation server side
     metadata = {
diff --git a/examples/flower-client/client/flwr_client.py b/examples/flower-client/client/flwr_client.py
index 297df3ca7..843e9d31d 100644
--- a/examples/flower-client/client/flwr_client.py
+++ b/examples/flower-client/client/flwr_client.py
@@ -3,8 +3,7 @@
 """

 from flwr.client import ClientApp, NumPyClient
-from flwr_task import (DEVICE, Net, get_weights, load_data, set_weights, test,
-                       train)
+from flwr_task import DEVICE, Net, get_weights, load_data, set_weights, test, train


 # Define FlowerClient and client_fn
@@ -12,9 +11,7 @@ class FlowerClient(NumPyClient):
     def __init__(self, cid) -> None:
         super().__init__()
         self.net = Net().to(DEVICE)
-        self.trainloader, self.testloader = load_data(
-            partition_id=int(cid), num_clients=10
-        )
+        self.trainloader, self.testloader = load_data(partition_id=int(cid), num_clients=10)

     def get_parameters(self, config):
         return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
diff --git a/examples/flower-client/init_fedn.py b/examples/flower-client/init_fedn.py
index 23078fcd9..81864293b 100644
--- a/examples/flower-client/init_fedn.py
+++ b/examples/flower-client/init_fedn.py
@@ -1,8 +1,8 @@
 from fedn import APIClient

-DISCOVER_HOST = '127.0.0.1'
+DISCOVER_HOST = "127.0.0.1"
 DISCOVER_PORT = 8092

 client = APIClient(DISCOVER_HOST, DISCOVER_PORT)
-client.set_package('package.tgz', 'numpyhelper')
-client.set_initial_model('seed.npz')
+client.set_package("package.tgz", "numpyhelper")
+client.set_initial_model("seed.npz")
diff --git a/examples/mnist-keras/client/entrypoint.py b/examples/mnist-keras/client/entrypoint.py
index 15c0693f2..5420a78bb 100755
--- a/examples/mnist-keras/client/entrypoint.py
+++ b/examples/mnist-keras/client/entrypoint.py
@@ -7,7 +7,7 @@

 from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics

-HELPER_MODULE = 'numpyhelper'
+HELPER_MODULE = "numpyhelper"
 helper = get_helper(HELPER_MODULE)

 NUM_CLASSES = 10
@@ -17,13 +17,13 @@
 def _get_data_path():
-    data_path = os.environ.get('FEDN_DATA_PATH', abs_path + '/data/clients/1/mnist.npz')
+    data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.npz")

     return data_path


 def compile_model(img_rows=28, img_cols=28):
-    """ Compile the TF model.
+    """Compile the TF model.

     param: img_rows: The number of rows in the image
     type: img_rows: int
@@ -38,13 +38,11 @@ def compile_model(img_rows=28, img_cols=28):
     # Define model
     model = tf.keras.models.Sequential()
     model.add(tf.keras.layers.Flatten(input_shape=input_shape))
-    model.add(tf.keras.layers.Dense(64, activation='relu'))
+    model.add(tf.keras.layers.Dense(64, activation="relu"))
     model.add(tf.keras.layers.Dropout(0.5))
-    model.add(tf.keras.layers.Dense(32, activation='relu'))
-    model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))
-    model.compile(loss=tf.keras.losses.categorical_crossentropy,
-                  optimizer=tf.keras.optimizers.Adam(),
-                  metrics=['accuracy'])
+    model.add(tf.keras.layers.Dense(32, activation="relu"))
+    model.add(tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"))
+    model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=tf.keras.optimizers.Adam(), metrics=["accuracy"])

     return model

@@ -57,14 +55,14 @@ def load_data(data_path, is_train=True):
     data = np.load(data_path)

     if is_train:
-        X = data['x_train']
-        y = data['y_train']
+        X = data["x_train"]
+        y = data["y_train"]
     else:
-        X = data['x_test']
-        y = data['y_test']
+        X = data["x_test"]
+        y = data["y_test"]

     # Normalize
-    X = X.astype('float32')
+    X = X.astype("float32")
     X = np.expand_dims(X, -1)
     X = X / 255
     y = tf.keras.utils.to_categorical(y, NUM_CLASSES)
@@ -72,8 +70,8 @@
     return X, y


-def init_seed(out_path='../seed.npz'):
-    """ Initialize seed model and save it to file.
+def init_seed(out_path="../seed.npz"):
+    """Initialize seed model and save it to file.

     :param out_path: The path to save the seed model to.
     :type out_path: str
@@ -83,7 +81,7 @@ def init_seed(out_path='../seed.npz'):


 def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1):
-    """ Complete a model update.
+    """Complete a model update.
     Load model paramters from in_model_path (managed by the FEDn client),
     perform a model update, and write updated paramters
@@ -114,9 +112,9 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1

     # Metadata needed for aggregation server side
     metadata = {
         # num_examples are mandatory
-        'num_examples': len(x_train),
-        'batch_size': batch_size,
-        'epochs': epochs,
+        "num_examples": len(x_train),
+        "batch_size": batch_size,
+        "epochs": epochs,
     }

     # Save JSON metadata file (mandatory)
@@ -128,7 +126,7 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1


 def validate(in_model_path, out_json_path, data_path=None):
-    """ Validate model.
+    """Validate model.

     :param in_model_path: The path to the input model.
     :type in_model_path: str
@@ -182,14 +180,16 @@ def predict(in_model_path, out_json_path, data_path=None):

     # Save JSON
     with open(out_json_path, "w") as fh:
-        fh.write(json.dumps({'predictions': y_pred.tolist()}))
-
-
-if __name__ == '__main__':
-    fire.Fire({
-        'init_seed': init_seed,
-        'train': train,
-        'validate': validate,
-        'predict': predict,
-        '_get_data_path': _get_data_path,  # for testing
-    })
+        fh.write(json.dumps({"predictions": y_pred.tolist()}))
+
+
+if __name__ == "__main__":
+    fire.Fire(
+        {
+            "init_seed": init_seed,
+            "train": train,
+            "validate": validate,
+            "predict": predict,
+            "_get_data_path": _get_data_path,  # for testing
+        }
+    )
diff --git a/examples/mnist-keras/client/get_data.py b/examples/mnist-keras/client/get_data.py
index 28a12bd20..ed123a4a3 100755
--- a/examples/mnist-keras/client/get_data.py
+++ b/examples/mnist-keras/client/get_data.py
@@ -7,14 +7,14 @@

 def splitset(dataset, parts):
     n = dataset.shape[0]
-    local_n = floor(n/parts)
+    local_n = floor(n / parts)
     result = []
     for i in range(parts):
-        result.append(dataset[i*local_n: (i+1)*local_n])
+        result.append(dataset[i * local_n : (i + 1) * local_n])
     return np.array(result)


-def split(dataset='data/mnist.npz', outdir='data', n_splits=2):
+def split(dataset="data/mnist.npz", outdir="data", n_splits=2):
     # Load and convert to dict
     package = np.load(dataset)
     data = {}
@@ -22,32 +22,27 @@ def split(dataset='data/mnist.npz', outdir='data', n_splits=2):
         data[key] = splitset(val, n_splits)

     # Make dir if necessary
-    if not os.path.exists(f'{outdir}/clients'):
-        os.mkdir(f'{outdir}/clients')
+    if not os.path.exists(f"{outdir}/clients"):
+        os.mkdir(f"{outdir}/clients")

     # Make splits
     for i in range(n_splits):
-        subdir = f'{outdir}/clients/{str(i+1)}'
+        subdir = f"{outdir}/clients/{str(i+1)}"
         if not os.path.exists(subdir):
             os.mkdir(subdir)
-        np.savez(f'{subdir}/mnist.npz',
-                 x_train=data['x_train'][i],
-                 y_train=data['y_train'][i],
-                 x_test=data['x_test'][i],
-                 y_test=data['y_test'][i])
+        np.savez(f"{subdir}/mnist.npz", x_train=data["x_train"][i], y_train=data["y_train"][i], x_test=data["x_test"][i], y_test=data["y_test"][i])


-def get_data(out_dir='data'):
+def get_data(out_dir="data"):
     # Make dir if necessary
     if not os.path.exists(out_dir):
         os.mkdir(out_dir)

     # Download data
     (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
-    np.savez(f'{out_dir}/mnist.npz', x_train=x_train,
-             y_train=y_train, x_test=x_test, y_test=y_test)
+    np.savez(f"{out_dir}/mnist.npz", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)


-if __name__ == '__main__':
+if __name__ == "__main__":
     get_data()
     split()
diff --git a/examples/mnist-pytorch/client/data.py b/examples/mnist-pytorch/client/data.py
index d67274548..b921f3132 100644
--- a/examples/mnist-pytorch/client/data.py
+++ b/examples/mnist-pytorch/client/data.py
@@ -8,22 +8,20 @@
 abs_path = os.path.abspath(dir_path)


-def get_data(out_dir='data'):
+def get_data(out_dir="data"):
     # Make dir if necessary
     if not os.path.exists(out_dir):
         os.mkdir(out_dir)

     # Only download if not already downloaded
-    if not os.path.exists(f'{out_dir}/train'):
-        torchvision.datasets.MNIST(
-            root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True, download=True)
-    if not os.path.exists(f'{out_dir}/test'):
-        torchvision.datasets.MNIST(
-            root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False, download=True)
+    if not os.path.exists(f"{out_dir}/train"):
+        torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True, download=True)
+    if not os.path.exists(f"{out_dir}/test"):
+        torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False, download=True)


 def load_data(data_path, is_train=True):
-    """ Load data from disk.
+    """Load data from disk.

     :param data_path: Path to data file.
     :type data_path: str
@@ -33,16 +31,16 @@ def load_data(data_path, is_train=True):
     :rtype: tuple
     """
     if data_path is None:
-        data_path = os.environ.get("FEDN_DATA_PATH", abs_path+'/data/clients/1/mnist.pt')
+        data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt")

     data = torch.load(data_path)

     if is_train:
-        X = data['x_train']
-        y = data['y_train']
+        X = data["x_train"]
+        y = data["y_train"]
     else:
-        X = data['x_test']
-        y = data['y_test']
+        X = data["x_test"]
+        y = data["y_test"]

     # Normalize
     X = X / 255

@@ -52,49 +50,48 @@ def load_data(data_path, is_train=True):

 def splitset(dataset, parts):
     n = dataset.shape[0]
-    local_n = floor(n/parts)
+    local_n = floor(n / parts)
     result = []
     for i in range(parts):
-        result.append(dataset[i*local_n: (i+1)*local_n])
+        result.append(dataset[i * local_n : (i + 1) * local_n])
     return result


-def split(out_dir='data'):
-
+def split(out_dir="data"):
     n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))

     # Make dir
-    if not os.path.exists(f'{out_dir}/clients'):
-        os.mkdir(f'{out_dir}/clients')
+    if not os.path.exists(f"{out_dir}/clients"):
+        os.mkdir(f"{out_dir}/clients")

     # Load and convert to dict
-    train_data = torchvision.datasets.MNIST(
-        root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True)
-    test_data = torchvision.datasets.MNIST(
-        root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False)
+    train_data = torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True)
+    test_data = torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False)
     data = {
-        'x_train': splitset(train_data.data, n_splits),
-        'y_train': splitset(train_data.targets, n_splits),
-        'x_test': splitset(test_data.data, n_splits),
-        'y_test': splitset(test_data.targets, n_splits),
+        "x_train": splitset(train_data.data, n_splits),
+        "y_train": splitset(train_data.targets, n_splits),
+        "x_test": splitset(test_data.data, n_splits),
+        "y_test": splitset(test_data.targets, n_splits),
     }

     # Make splits
     for i in range(n_splits):
-        subdir = f'{out_dir}/clients/{str(i+1)}'
+        subdir = f"{out_dir}/clients/{str(i+1)}"
         if not os.path.exists(subdir):
             os.mkdir(subdir)
-        torch.save({
-            'x_train': data['x_train'][i],
-            'y_train': data['y_train'][i],
-            'x_test': data['x_test'][i],
-            'y_test': data['y_test'][i],
-        },
-            f'{subdir}/mnist.pt')
-
-
-if __name__ == '__main__':
+        torch.save(
+            {
+                "x_train": data["x_train"][i],
+                "y_train": data["y_train"][i],
+                "x_test": data["x_test"][i],
+                "y_test": data["y_test"][i],
+            },
+            f"{subdir}/mnist.pt",
+        )
+
+
+if __name__ == "__main__":
     # Prepare data if not already done
-    if not os.path.exists(abs_path+'/data/clients/1'):
+    if not os.path.exists(abs_path + "/data/clients/1"):
         get_data()
         split()
diff --git a/examples/mnist-pytorch/client/model.py b/examples/mnist-pytorch/client/model.py
index cb5b7afc2..6ad344770 100644
--- a/examples/mnist-pytorch/client/model.py
+++ b/examples/mnist-pytorch/client/model.py
@@ -4,16 +4,17 @@

 from fedn.utils.helpers.helpers import get_helper

-HELPER_MODULE = 'numpyhelper'
+HELPER_MODULE = "numpyhelper"
 helper = get_helper(HELPER_MODULE)


 def compile_model():
-    """ Compile the pytorch model.
+    """Compile the pytorch model.

     :return: The compiled model.
     :rtype: torch.nn.Module
     """
+
     class Net(torch.nn.Module):
         def __init__(self):
             super(Net, self).__init__()
@@ -32,7 +33,7 @@ def forward(self, x):


 def save_parameters(model, out_path):
-    """ Save model paramters to file.
+    """Save model paramters to file.

     :param model: The model to serialize.
     :type model: torch.nn.Module
@@ -44,7 +45,7 @@ def save_parameters(model, out_path):


 def load_parameters(model_path):
-    """ Load model parameters from file and populate model.
+    """Load model parameters from file and populate model.

     param model_path: The path to load from.
     :type model_path: str
@@ -60,8 +61,8 @@ def load_parameters(model_path):
     return model


-def init_seed(out_path='seed.npz'):
-    """ Initialize seed model and save it to file.
+def init_seed(out_path="seed.npz"):
+    """Initialize seed model and save it to file.

     :param out_path: The path to save the seed model to.
     :type out_path: str
@@ -72,4 +73,4 @@ def init_seed(out_path='seed.npz'):


 if __name__ == "__main__":
-    init_seed('../seed.npz')
+    init_seed("../seed.npz")
diff --git a/examples/mnist-pytorch/client/train.py b/examples/mnist-pytorch/client/train.py
index fdf2480ae..9ac9cce61 100644
--- a/examples/mnist-pytorch/client/train.py
+++ b/examples/mnist-pytorch/client/train.py
@@ -3,9 +3,9 @@
 import sys

 import torch
-from data import load_data
 from model import load_parameters, save_parameters

+from data import load_data
 from fedn.utils.helpers.helpers import save_metadata

 dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -13,7 +13,7 @@


 def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01):
-    """ Complete a model update.
+    """Complete a model update.
     Load model paramters from in_model_path (managed by the FEDn client),
     perform a model update, and write updated paramters
@@ -45,8 +45,8 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
     for e in range(epochs):  # epoch loop
         for b in range(n_batches):  # batch loop
             # Retrieve current batch
-            batch_x = x_train[b * batch_size:(b + 1) * batch_size]
-            batch_y = y_train[b * batch_size:(b + 1) * batch_size]
+            batch_x = x_train[b * batch_size : (b + 1) * batch_size]
+            batch_y = y_train[b * batch_size : (b + 1) * batch_size]
             # Train on batch
             optimizer.zero_grad()
             outputs = model(batch_x)
@@ -55,16 +55,15 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
             optimizer.step()
             # Log
             if b % 100 == 0:
-                print(
-                    f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}")
+                print(f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}")

     # Metadata needed for aggregation server side
     metadata = {
         # num_examples are mandatory
-        'num_examples': len(x_train),
-        'batch_size': batch_size,
-        'epochs': epochs,
-        'lr': lr
+        "num_examples": len(x_train),
+        "batch_size": batch_size,
+        "epochs": epochs,
+        "lr": lr,
     }

     # Save JSON metadata file (mandatory)
diff --git a/examples/mnist-pytorch/client/validate.py b/examples/mnist-pytorch/client/validate.py
index 0a5592368..09328181f 100644
--- a/examples/mnist-pytorch/client/validate.py
+++ b/examples/mnist-pytorch/client/validate.py
@@ -2,9 +2,9 @@
 import sys

 import torch
-from data import load_data
 from model import load_parameters

+from data import load_data
 from fedn.utils.helpers.helpers import save_metrics

 dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -12,7 +12,7 @@


 def validate(in_model_path, out_json_path, data_path=None):
-    """ Validate model.
+    """Validate model.

     :param in_model_path: The path to the input model.
     :type in_model_path: str
@@ -34,12 +34,10 @@ def validate(in_model_path, out_json_path, data_path=None):
     with torch.no_grad():
         train_out = model(x_train)
         training_loss = criterion(train_out, y_train)
-        training_accuracy = torch.sum(torch.argmax(
-            train_out, dim=1) == y_train) / len(train_out)
+        training_accuracy = torch.sum(torch.argmax(train_out, dim=1) == y_train) / len(train_out)

         test_out = model(x_test)
         test_loss = criterion(test_out, y_test)
-        test_accuracy = torch.sum(torch.argmax(
-            test_out, dim=1) == y_test) / len(test_out)
+        test_accuracy = torch.sum(torch.argmax(test_out, dim=1) == y_test) / len(test_out)

     # JSON schema
     report = {
diff --git a/fedn/cli/client_cmd.py b/fedn/cli/client_cmd.py
index f9916985c..e72f29569 100644
--- a/fedn/cli/client_cmd.py
+++ b/fedn/cli/client_cmd.py
@@ -7,8 +7,7 @@
 from fedn.network.clients.client import Client

 from .main import main
-from .shared import (CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token,
-                     print_response)
+from .shared import CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token, print_response


 def validate_client_config(config):
@@ -18,16 +17,15 @@ def validate_client_config(config):
     """

     try:
-        if config['discover_host'] is None or \
-                config['discover_host'] == '':
+        if config["discover_host"] is None or config["discover_host"] == "":
             raise InvalidClientConfig("Missing required configuration: discover_host")
-        if 'discover_port' not in config.keys():
-            config['discover_port'] = None
+        if "discover_port" not in config.keys():
+            config["discover_port"] = None
     except Exception:
         raise InvalidClientConfig("Could not load config from file. Check config")


-@main.group('client')
+@main.group("client")
 @click.pass_context
 def client_cmd(ctx):
     """
@@ -37,12 +35,12 @@ def client_cmd(ctx):
     pass


-@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)')
-@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)')
-@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)')
-@click.option('-t', '--token', required=False, help='Authentication token')
-@click.option('--n_max', required=False, help='Number of items to list')
-@client_cmd.command('list')
+@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)")
+@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)")
+@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)")
+@click.option("-t", "--token", required=False, help="Authentication token")
+@click.option("--n_max", required=False, help="Number of items to list")
+@client_cmd.command("list")
 @click.pass_context
 def list_clients(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None):
     """
@@ -52,53 +50,70 @@ def list_clients(ctx, protocol: str, host: str, port: str, token: str = None, n_
     - result: list of clients

     """
-    url = get_api_url(protocol=protocol, host=host, port=port, endpoint='clients')
+    url = get_api_url(protocol=protocol, host=host, port=port, endpoint="clients")
     headers = {}

     if n_max:
-        headers['X-Limit'] = n_max
+        headers["X-Limit"] = n_max

     _token = get_token(token)

     if _token:
-        headers['Authorization'] = _token
+        headers["Authorization"] = _token

-    click.echo(f'\nListing clients: {url}\n')
-    click.echo(f'Headers: {headers}')
+    click.echo(f"\nListing clients: {url}\n")
+    click.echo(f"Headers: {headers}")

     try:
         response = requests.get(url, headers=headers)
-        print_response(response, 'clients')
+        print_response(response, "clients")
     except requests.exceptions.ConnectionError:
-        click.echo(f'Error: Could not connect to {url}')
-
-
-@client_cmd.command('start')
-@click.option('-d', '--discoverhost', required=False, help='Hostname for discovery services(reducer).')
-@click.option('-p', '--discoverport', required=False, help='Port for discovery services (reducer).')
-@click.option('--token', required=False, help='Set token provided by reducer if enabled')
-@click.option('-n', '--name', required=False, default="client" + str(uuid.uuid4())[:8])
-@click.option('-i', '--client_id', required=False)
-@click.option('--local-package', is_flag=True, help='Enable local compute package')
-@click.option('--force-ssl', is_flag=True, help='Force SSL/TLS for REST service')
-@click.option('-u', '--dry-run', required=False, default=False)
-@click.option('-s', '--secure', required=False, default=False)
-@click.option('-pc', '--preshared-cert', required=False, default=False)
-@click.option('-v', '--verify', is_flag=True, help='Verify SSL/TLS for REST service')
-@click.option('-c', '--preferred-combiner', required=False, default=False)
-@click.option('-va', '--validator', required=False, default=True)
-@click.option('-tr', '--trainer', required=False, default=True)
-@click.option('-in', '--init', required=False, default=None,
-              help='Set to a filename to (re)init client from file state.')
-@click.option('-l', '--logfile', required=False, default=None,
-              help='Set logfile for client log to file.')
-@click.option('--heartbeat-interval', required=False, default=2)
-@click.option('--reconnect-after-missed-heartbeat', required=False, default=30)
-@click.option('--verbosity', required=False, default='INFO', type=click.Choice(['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], case_sensitive=False))
+        click.echo(f"Error: Could not connect to {url}")
+
+
+@client_cmd.command("start")
+@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services(reducer).")
+@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).")
+@click.option("--token", required=False, help="Set token provided by reducer if enabled")
+@click.option("-n", "--name", required=False, default="client" + str(uuid.uuid4())[:8])
+@click.option("-i", "--client_id", required=False)
+@click.option("--local-package", is_flag=True, help="Enable local compute package")
+@click.option("--force-ssl", is_flag=True, help="Force SSL/TLS for REST service")
+@click.option("-u", "--dry-run", required=False, default=False)
+@click.option("-s", "--secure", required=False, default=False)
+@click.option("-pc", "--preshared-cert", required=False, default=False)
+@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST service")
+@click.option("-c", "--preferred-combiner", required=False, default=False)
+@click.option("-va", "--validator", required=False, default=True)
+@click.option("-tr", "--trainer", required=False, default=True)
+@click.option("-in", "--init", required=False, default=None, help="Set to a filename to (re)init client from file state.")
+@click.option("-l", "--logfile", required=False, default=None, help="Set logfile for client log to file.")
+@click.option("--heartbeat-interval", required=False, default=2)
+@click.option("--reconnect-after-missed-heartbeat", required=False, default=30)
+@click.option("--verbosity", required=False, default="INFO", type=click.Choice(["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], case_sensitive=False))
 @click.pass_context
-def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_package, force_ssl, dry_run, secure, preshared_cert,
-               verify, preferred_combiner, validator, trainer, init, logfile, heartbeat_interval, reconnect_after_missed_heartbeat,
-               verbosity):
+def client_cmd(
+    ctx,
+    discoverhost,
+    discoverport,
+    token,
+    name,
+    client_id,
+    local_package,
+    force_ssl,
+    dry_run,
+    secure,
+    preshared_cert,
+    verify,
+    preferred_combiner,
+    validator,
+    trainer,
+    init,
+    logfile,
+    heartbeat_interval,
+    reconnect_after_missed_heartbeat,
+    verbosity,
+):
     """

     :param ctx:
@@ -121,21 +136,36 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa
     :return:
     """
     remote = False if local_package else True
-    config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'name': name,
-              'client_id': client_id, 'remote_compute_context': remote, 'force_ssl': force_ssl, 'dry_run': dry_run, 'secure': secure,
-              'preshared_cert': preshared_cert, 'verify': verify, 'preferred_combiner': preferred_combiner,
-              'validator': validator, 'trainer': trainer, 'logfile': logfile, 'heartbeat_interval': heartbeat_interval,
-              'reconnect_after_missed_heartbeat': reconnect_after_missed_heartbeat, 'verbosity': verbosity}
+    config = {
+        "discover_host": discoverhost,
+        "discover_port": discoverport,
+        "token": token,
+        "name": name,
+        "client_id": client_id,
+        "remote_compute_context": remote,
+        "force_ssl": force_ssl,
+        "dry_run": dry_run,
+        "secure": secure,
+        "preshared_cert": preshared_cert,
+        "verify": verify,
+        "preferred_combiner": preferred_combiner,
+        "validator": validator,
+        "trainer": trainer,
+        "logfile": logfile,
+        "heartbeat_interval": heartbeat_interval,
+        "reconnect_after_missed_heartbeat": reconnect_after_missed_heartbeat,
+        "verbosity": verbosity,
+    }

     if init:
         apply_config(init, config)
-        click.echo(f'\nClient configuration loaded from file: {init}')
-        click.echo('Values set in file override defaults and command line arguments...\n')
+        click.echo(f"\nClient configuration loaded from file: {init}")
+        click.echo("Values set in file override defaults and command line arguments...\n")

     try:
         validate_client_config(config)
     except InvalidClientConfig as e:
-        click.echo(f'Error: {e}')
+        click.echo(f"Error: {e}")
         return

     client = Client(config)
diff --git a/fedn/cli/combiner_cmd.py b/fedn/cli/combiner_cmd.py
index 758dda718..2b4447437 100644
--- a/fedn/cli/combiner_cmd.py
+++ b/fedn/cli/combiner_cmd.py
@@ -6,11 +6,10 @@
 from fedn.network.combiner.combiner import Combiner

 from .main import main
-from .shared import (CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token,
-                     print_response)
+from .shared import CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token, print_response


-@main.group('combiner')
+@main.group("combiner")
 @click.pass_context
 def combiner_cmd(ctx):
     """
@@ -20,19 +19,18 @@ def combiner_cmd(ctx):
     pass


-@combiner_cmd.command('start')
-@click.option('-d', '--discoverhost', required=False, help='Hostname for discovery services (reducer).')
-@click.option('-p', '--discoverport', required=False, help='Port for discovery services (reducer).')
-@click.option('-t', '--token', required=False, help='Set token provided by reducer if enabled')
-@click.option('-n', '--name', required=False, default="combiner" + str(uuid.uuid4())[:8], help='Set name for combiner.')
-@click.option('-h', '--host', required=False, default="combiner", help='Set hostname.')
-@click.option('-i', '--port', required=False, default=12080, help='Set port.')
-@click.option('-f', '--fqdn', required=False, default=None, help='Set fully qualified domain name')
-@click.option('-s', '--secure', is_flag=True, help='Enable SSL/TLS encrypted gRPC channels.')
-@click.option('-v', '--verify', is_flag=True, help='Verify SSL/TLS for REST discovery service (reducer)')
-@click.option('-c', '--max_clients', required=False, default=30, help='The maximal number of client connections allowed.')
-@click.option('-in', '--init', required=False, default=None,
-              help='Path to configuration file to (re)init combiner.')
+@combiner_cmd.command("start")
+@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services (reducer).")
+@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).")
+@click.option("-t", "--token", required=False, help="Set token provided by reducer if enabled")
+@click.option("-n", "--name", required=False, default="combiner" + str(uuid.uuid4())[:8], help="Set name for combiner.")
+@click.option("-h", "--host", required=False, default="combiner", help="Set hostname.")
+@click.option("-i", "--port", required=False, default=12080, help="Set port.")
+@click.option("-f", "--fqdn", required=False, default=None, help="Set fully qualified domain name")
+@click.option("-s", "--secure", is_flag=True, help="Enable SSL/TLS encrypted gRPC channels.")
+@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST discovery service (reducer)")
+@click.option("-c", "--max_clients", required=False, default=30, help="The maximal number of client connections allowed.")
+@click.option("-in", "--init", required=False, default=None, help="Path to configuration file to (re)init combiner.")
 @click.pass_context
 def start_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init):
     """
@@ -48,24 +46,34 @@ def start_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, se
     :param max_clients:
     :param init:
     """
-    config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'host': host,
-              'port': port, 'fqdn': fqdn, 'name': name, 'secure': secure, 'verify': verify, 'max_clients': max_clients}
+    config = {
+        "discover_host": discoverhost,
+        "discover_port": discoverport,
+        "token": token,
+        "host": host,
+        "port": port,
+        "fqdn": fqdn,
+        "name": name,
+        "secure": secure,
+        "verify": verify,
+        "max_clients": max_clients,
+    }

     if init:
         apply_config(init, config)
-        click.echo(f'\nCombiner configuration loaded from file: {init}')
-        click.echo('Values set in file override defaults and command line arguments...\n')
+        click.echo(f"\nCombiner configuration loaded from file: {init}")
+        click.echo("Values set in file override defaults and command line arguments...\n")

     combiner = Combiner(config)
     combiner.run()


-@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)')
-@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)')
-@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)')
-@click.option('-t', '--token', required=False, help='Authentication token')
-@click.option('--n_max', required=False, help='Number of items to list')
-@combiner_cmd.command('list')
+@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)")
+@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)")
+@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)")
+@click.option("-t", "--token", required=False, help="Authentication token")
+@click.option("--n_max", required=False, help="Number of items to list")
+@combiner_cmd.command("list")
 @click.pass_context
 def list_combiners(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None):
     """
@@ -75,22 +83,22 @@ def list_combiners(ctx, protocol: str, host: str, port: str, token: str = None,
     - result: list of combiners

     """
-    url = get_api_url(protocol=protocol, host=host, port=port, endpoint='combiners')
+    url = get_api_url(protocol=protocol, host=host, port=port, endpoint="combiners")
     headers = {}

     if n_max:
-        headers['X-Limit'] = n_max
+        headers["X-Limit"] = n_max

     _token = get_token(token)

     if _token:
-        headers['Authorization'] = _token
+        headers["Authorization"] = _token

-    click.echo(f'\nListing combiners: {url}\n')
-    click.echo(f'Headers: {headers}')
+    click.echo(f"\nListing combiners: {url}\n")
+    click.echo(f"Headers: {headers}")

     try:
         response = requests.get(url, headers=headers)
-        print_response(response, 'combiners')
+        print_response(response, "combiners")
     except requests.exceptions.ConnectionError:
-        click.echo(f'Error: Could not connect to {url}')
+        click.echo(f"Error: Could not connect to {url}")
diff --git a/fedn/cli/config_cmd.py b/fedn/cli/config_cmd.py
index 856882e62..d5286997f 100644
--- a/fedn/cli/config_cmd.py
+++ b/fedn/cli/config_cmd.py
@@ -5,50 +5,32 @@
 from .main import main

 envs = [
-    {
-        "name": "FEDN_CONTROLLER_PROTOCOL",
-        "description": "The protocol to use for communication with the controller."
-    },
-    {
-        "name": "FEDN_CONTROLLER_HOST",
-        "description": "The host to use for communication with the controller."
-    },
-    {
-        "name": "FEDN_CONTROLLER_PORT",
-        "description": "The port to use for communication with the controller."
-    },
-    {
-        "name": "FEDN_AUTH_TOKEN",
-        "description": "The authentication token to use for communication with the controller and combiner."
-    },
-    {
-        "name": "FEDN_AUTH_SCHEME",
-        "description": "The authentication scheme to use for communication with the controller and combiner."
-    },
+    {"name": "FEDN_CONTROLLER_PROTOCOL", "description": "The protocol to use for communication with the controller."},
+    {"name": "FEDN_CONTROLLER_HOST", "description": "The host to use for communication with the controller."},
+    {"name": "FEDN_CONTROLLER_PORT", "description": "The port to use for communication with the controller."},
+    {"name": "FEDN_AUTH_TOKEN", "description": "The authentication token to use for communication with the controller and combiner."},
+    {"name": "FEDN_AUTH_SCHEME", "description": "The authentication scheme to use for communication with the controller and combiner."},
     {
         "name": "FEDN_CONTROLLER_URL",
-        "description": "The URL of the controller. Overrides FEDN_CONTROLLER_PROTOCOL, FEDN_CONTROLLER_HOST and FEDN_CONTROLLER_PORT."
+        "description": "The URL of the controller. Overrides FEDN_CONTROLLER_PROTOCOL, FEDN_CONTROLLER_HOST and FEDN_CONTROLLER_PORT.",
     },
-    {
-        "name": "FEDN_PACKAGE_EXTRACT_DIR",
-        "description": "The directory to extract packages to."
-    }
+    {"name": "FEDN_PACKAGE_EXTRACT_DIR", "description": "The directory to extract packages to."},
 ]


-@main.group('config', invoke_without_command=True)
+@main.group("config", invoke_without_command=True)
 @click.pass_context
 def config_cmd(ctx):
     """
     - Configuration commands for the FEDn CLI.
     """
     if ctx.invoked_subcommand is None:
-        click.echo('\n--- FEDn Cli Configuration ---\n')
-        click.echo('Current configuration:\n')
+        click.echo("\n--- FEDn Cli Configuration ---\n")
+        click.echo("Current configuration:\n")

         for env in envs:
-            name = env['name']
+            name = env["name"]
             value = os.environ.get(name)
             click.echo(f'{name}: {value or "Not set"}')
             click.echo(f'{env["description"]}\n')

-        click.echo('\n')
+        click.echo("\n")
diff --git a/fedn/cli/main.py b/fedn/cli/main.py
index d004c9605..52276c418 100644
--- a/fedn/cli/main.py
+++ b/fedn/cli/main.py
@@ -2,7 +2,7 @@

 CONTEXT_SETTINGS = dict(
     # Support -h as a shortcut for --help
-    help_option_names=['-h', '--help'],
+    help_option_names=["-h", "--help"],
 )
diff --git a/fedn/cli/model_cmd.py b/fedn/cli/model_cmd.py
index ddccd9e2d..e44793a9f 100644
--- a/fedn/cli/model_cmd.py
+++ b/fedn/cli/model_cmd.py
@@ -1,4 +1,3 @@
-
 import click
 import requests

@@ -6,7 +5,7 @@
 from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response


-@main.group('model')
+@main.group("model")
 @click.pass_context
 def model_cmd(ctx):
     """
@@ -16,12 +15,12 @@ def model_cmd(ctx):
     pass


-@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)')
-@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)')
-@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)')
-@click.option('-t', '--token', required=False, help='Authentication token')
-@click.option('--n_max', required=False, help='Number of items to list')
-@model_cmd.command('list')
+@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)")
+@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)")
+@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)")
+@click.option("-t", "--token", required=False, help="Authentication token")
+@click.option("--n_max", required=False, help="Number of items to list")
+@model_cmd.command("list")
 @click.pass_context
 def list_models(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None):
     """
@@ -31,22 +30,22 @@ def list_models(ctx, protocol: str, host: str, port: str, token: str = None, n_m
     - result: list of models

     """
-    url = get_api_url(protocol=protocol, host=host, port=port, endpoint='models')
+    url = get_api_url(protocol=protocol, host=host, port=port, endpoint="models")
     headers = {}

     if n_max:
-        headers['X-Limit'] = n_max
+        headers["X-Limit"] = n_max

     _token = get_token(token)

     if _token:
-        headers['Authorization'] = _token
+        headers["Authorization"] = _token

-    click.echo(f'\nListing models: {url}\n')
-    click.echo(f'Headers: {headers}')
+    click.echo(f"\nListing models: {url}\n")
+    click.echo(f"Headers: {headers}")

     try:
         response = requests.get(url, headers=headers)
-        print_response(response, 'models')
+        print_response(response, "models")
     except requests.exceptions.ConnectionError:
-        click.echo(f'Error: Could not connect to {url}')
+        click.echo(f"Error: Could not connect to {url}")
diff --git a/fedn/cli/package_cmd.py b/fedn/cli/package_cmd.py
index a19ed2f9e..6d503d414 100644
--- a/fedn/cli/package_cmd.py
+++ b/fedn/cli/package_cmd.py
@@ -10,7 +10,7 @@
 from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response

-@main.group('package')
+@main.group("package")
 @click.pass_context
 def package_cmd(ctx):
     """
@@ -20,12 +20,12 @@ def package_cmd(ctx):
     pass


-@package_cmd.command('create')
-@click.option('-p', '--path', required=True, help='Path to package directory containing fedn.yaml')
-@click.option('-n', '--name', required=False, default='package.tgz', help='Name of package tarball')
+@package_cmd.command("create")
+@click.option("-p", "--path", required=True, help="Path to package directory containing fedn.yaml")
+@click.option("-n", "--name", required=False, default="package.tgz", help="Name of package tarball")
 @click.pass_context
 def create_cmd(ctx, path, name):
-    """ Create compute package.
+    """Create compute package.

     Make a tar.gz archive of folder given by --path

@@ -33,7 +33,7 @@ def create_cmd(ctx, path, name):
     :param path:
     """
     path = os.path.abspath(path)
-    yaml_file = os.path.join(path, 'fedn.yaml')
+    yaml_file = os.path.join(path, "fedn.yaml")
     if not os.path.exists(yaml_file):
         logger.error(f"Could not find fedn.yaml in {path}")
         exit(-1)
@@ -43,12 +43,12 @@ def create_cmd(ctx, path, name):

     logger.info(f"Created package {name}")


-@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)')
-@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)')
-@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)')
-@click.option('-t', '--token', required=False, help='Authentication token')
-@click.option('--n_max', required=False, help='Number of items to list')
-@package_cmd.command('list')
+@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)")
+@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)")
+@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)")
+@click.option("-t", "--token", required=False, help="Authentication token")
+@click.option("--n_max", required=False, help="Number of items to list")
+@package_cmd.command("list")
 @click.pass_context
 def list_packages(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None):
     """
@@ -58,22 +58,22 @@ def list_packages(ctx, protocol: str, host: str, port: str, token: str = None, n
     - result: list of packages

     """
-    url = get_api_url(protocol=protocol, host=host, port=port, endpoint='packages')
+    url = get_api_url(protocol=protocol, host=host, port=port, endpoint="packages")
     headers = {}

     if n_max:
-        headers['X-Limit'] = n_max
+        headers["X-Limit"] = n_max

     _token = get_token(token)

     if _token:
-        headers['Authorization'] = _token
+        headers["Authorization"] = _token

-    click.echo(f'\nListing packages: {url}\n')
-    click.echo(f'Headers: {headers}')
+    click.echo(f"\nListing packages: {url}\n")
+    click.echo(f"Headers: {headers}")

     try:
         response = requests.get(url, headers=headers)
-        print_response(response, 'packages')
+        print_response(response, "packages")
     except requests.exceptions.ConnectionError:
-        click.echo(f'Error: Could not connect to {url}')
+        click.echo(f"Error: Could not connect to {url}")
diff --git a/fedn/cli/round_cmd.py b/fedn/cli/round_cmd.py
index 31f4accc4..ca23cafe7 100644
--- a/fedn/cli/round_cmd.py
+++ b/fedn/cli/round_cmd.py
@@ -1,4 +1,3 @@
-
 import click
 import requests

@@ -6,7 +5,7 @@
 from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response


-@main.group('round')
+@main.group("round")
 @click.pass_context
 def round_cmd(ctx):
     """
@@ -16,12 +15,12 @@ def round_cmd(ctx):
     pass


-@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)')
-@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)')
-@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)')
-@click.option('-t', '--token', required=False, help='Authentication token')
-@click.option('--n_max', required=False, help='Number of items to list')
-@round_cmd.command('list')
+@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)")
+@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)")
+@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)")
+@click.option("-t", "--token", required=False, help="Authentication token")
+@click.option("--n_max", required=False, help="Number of items to list")
+@round_cmd.command("list")
 @click.pass_context
 def list_rounds(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None):
     """
@@ -31,22 +30,22 @@ def list_rounds(ctx, protocol: str, host: str, port: str, token: str = None, n_m
     - result: list of rounds

     """
-    url = get_api_url(protocol=protocol, host=host, port=port, endpoint='rounds')
+    url = get_api_url(protocol=protocol, host=host, port=port, endpoint="rounds")
     headers = {}

     if n_max:
-        headers['X-Limit'] = n_max
+        headers["X-Limit"] = n_max

     _token = get_token(token)

     if _token:
-        headers['Authorization'] = _token
+        headers["Authorization"] = _token

-    click.echo(f'\nListing rounds: {url}\n')
-    click.echo(f'Headers: {headers}')
+    click.echo(f"\nListing rounds: {url}\n")
+    click.echo(f"Headers: {headers}")

     try:
         response = requests.get(url, headers=headers)
-        print_response(response, 'rounds')
+        print_response(response, "rounds")
     except requests.exceptions.ConnectionError:
-        click.echo(f'Error: Could not connect to {url}')
+        click.echo(f"Error: Could not connect to {url}")
diff --git a/fedn/cli/run_cmd.py b/fedn/cli/run_cmd.py
index 87a54f7f1..b9fe4528e 100644
--- a/fedn/cli/run_cmd.py
+++ b/fedn/cli/run_cmd.py
@@ -22,7 +22,7 @@ def get_statestore_config_from_file(init):
     :param init:
     :return:
     """
-    with open(init, 'r') as file:
+    with open(init, "r") as file:
         try:
             settings = dict(yaml.safe_load(file))
             return settings
@@ -31,7 +31,7 @@ def get_statestore_config_from_file(init):


 def check_helper_config_file(config):
-    control = config['control']
+    control = config["control"]
     try:
         helper = control["helper"]
     except KeyError:
@@ -40,7 +40,7 @@ def check_helper_config_file(config):
     return helper


-@main.group('run')
+@main.group("run")
 @click.pass_context
 def run_cmd(ctx):
     """
@@ -50,25 +50,25 @@ def run_cmd(ctx):
     pass


-@run_cmd.command('build')
-@click.option('-p', '--path', required=True, help='Path to package directory containing fedn.yaml')
+@run_cmd.command("build")
+@click.option("-p", "--path", required=True, help="Path to package directory containing fedn.yaml")
 @click.pass_context
 def build_cmd(ctx, path):
-    """ Execute 'build' entrypoint in fedn.yaml.
+    """Execute 'build' entrypoint in fedn.yaml.

     :param ctx:
     :param path: Path to folder containing fedn.yaml
     :type path: str
     """
     path = os.path.abspath(path)
-    yaml_file = os.path.join(path, 'fedn.yaml')
+    yaml_file = os.path.join(path, "fedn.yaml")
     if not os.path.exists(yaml_file):
         logger.error(f"Could not find fedn.yaml in {path}")
         exit(-1)

     config = _read_yaml_file(yaml_file)
     # Check that build is defined in fedn.yaml under entry_points
-    if 'build' not in config['entry_points']:
+    if "build" not in config["entry_points"]:
         logger.error("No build command defined in fedn.yaml")
         exit(-1)

@@ -82,32 +82,49 @@ def build_cmd(ctx, path):
         shutil.rmtree(dispatcher.python_env_path)


-@run_cmd.command('client')
-@click.option('-d', '--discoverhost', required=False, help='Hostname for discovery services(reducer).')
-@click.option('-p', '--discoverport', required=False, help='Port for discovery services (reducer).')
-@click.option('--token', required=False, help='Set token provided by reducer if enabled')
-@click.option('-n', '--name', required=False, default="client" + str(uuid.uuid4())[:8])
-@click.option('-i', '--client_id', required=False)
-@click.option('--local-package', is_flag=True, help='Enable local compute package')
-@click.option('--force-ssl', is_flag=True, help='Force SSL/TLS for REST service')
-@click.option('-u', '--dry-run', required=False, default=False)
-@click.option('-s', '--secure', required=False, default=False)
-@click.option('-pc', '--preshared-cert', required=False, default=False)
-@click.option('-v', '--verify', is_flag=True, help='Verify SSL/TLS for REST service')
-@click.option('-c', '--preferred-combiner', required=False, default=False)
-@click.option('-va', '--validator', required=False, default=True)
-@click.option('-tr', '--trainer', required=False, default=True)
-@click.option('-in', '--init', required=False, default=None,
-              help='Set to a filename to (re)init client from file state.')
-@click.option('-l', '--logfile', required=False, default=None,
-              help='Set logfile for client log to file.')
-@click.option('--heartbeat-interval', required=False, default=2)
-@click.option('--reconnect-after-missed-heartbeat', required=False, default=30)
-@click.option('--verbosity', required=False, default='INFO', type=click.Choice(['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], case_sensitive=False))
+@run_cmd.command("client")
+@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services(reducer).")
+@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).")
+@click.option("--token", required=False, help="Set token provided by reducer if enabled")
+@click.option("-n", "--name", required=False, default="client" + str(uuid.uuid4())[:8])
+@click.option("-i", "--client_id", required=False)
+@click.option("--local-package", is_flag=True, help="Enable local compute package")
+@click.option("--force-ssl", is_flag=True, help="Force SSL/TLS for REST service")
+@click.option("-u", "--dry-run", required=False, default=False)
+@click.option("-s", "--secure", required=False, default=False)
+@click.option("-pc", "--preshared-cert", required=False, default=False)
+@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST service")
+@click.option("-c", "--preferred-combiner", required=False, default=False)
+@click.option("-va", "--validator", required=False, default=True)
+@click.option("-tr", "--trainer", required=False, default=True)
+@click.option("-in", "--init", required=False, default=None, help="Set to a filename to (re)init client from file state.")
+@click.option("-l", "--logfile", required=False, default=None, help="Set logfile for client log to file.")
+@click.option("--heartbeat-interval", required=False, default=2)
+@click.option("--reconnect-after-missed-heartbeat", required=False, default=30)
+@click.option("--verbosity", required=False, default="INFO", type=click.Choice(["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], case_sensitive=False))
 @click.pass_context
-def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_package, force_ssl, dry_run, secure, preshared_cert,
-               verify, preferred_combiner, validator, trainer, init, logfile, heartbeat_interval, reconnect_after_missed_heartbeat,
-               verbosity):
+def client_cmd(
+    ctx,
+    discoverhost,
+    discoverport,
+    token,
+    name,
+    client_id,
+    local_package,
+    force_ssl,
+    dry_run,
+    secure,
+    preshared_cert,
+    verify,
+    preferred_combiner,
+    validator,
+    trainer,
+    init,
+    logfile,
+    heartbeat_interval,
+    reconnect_after_missed_heartbeat,
+    verbosity,
+):
     """

     :param ctx:
@@ -130,49 +147,58 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa
     :return:
     """
     remote = False if local_package else True
-    config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'name': name,
-              'client_id': client_id, 'remote_compute_context': remote, 'force_ssl': force_ssl, 'dry_run': dry_run, 'secure': secure,
-              'preshared_cert': preshared_cert, 'verify': verify, 'preferred_combiner': preferred_combiner,
-              'validator': validator, 'trainer': trainer, 'logfile': logfile, 'heartbeat_interval': heartbeat_interval,
-              'reconnect_after_missed_heartbeat': reconnect_after_missed_heartbeat, 'verbosity': verbosity}
+    config = {
+        "discover_host": discoverhost,
+        "discover_port": discoverport,
+        "token": token,
+        "name": name,
+        "client_id": client_id,
+        "remote_compute_context": remote,
+        "force_ssl": force_ssl,
+        "dry_run": dry_run,
+        "secure": secure,
+        "preshared_cert": preshared_cert,
+        "verify": verify,
+        "preferred_combiner": preferred_combiner,
+        "validator": validator,
+        "trainer": trainer,
+        "logfile": logfile,
+        "heartbeat_interval": heartbeat_interval,
+        "reconnect_after_missed_heartbeat": reconnect_after_missed_heartbeat,
+        "verbosity": verbosity,
+    }

     click.echo(
-        click.style(
-            '\n*** fedn run client is deprecated and will be removed. Please use fedn client start instead. ***\n',
-            blink=True,
-            bold=True,
-            fg='red'
-        )
+        click.style("\n*** fedn run client is deprecated and will be removed. Please use fedn client start instead. ***\n", blink=True, bold=True, fg="red")
     )

     if init:
         apply_config(init, config)
-        click.echo(f'\nClient configuration loaded from file: {init}')
-        click.echo('Values set in file override defaults and command line arguments...\n')
+        click.echo(f"\nClient configuration loaded from file: {init}")
+        click.echo("Values set in file override defaults and command line arguments...\n")

     try:
         validate_client_config(config)
     except InvalidClientConfig as e:
-        click.echo(f'Error: {e}')
+        click.echo(f"Error: {e}")
         return

     client = Client(config)
     client.run()


-@run_cmd.command('combiner')
-@click.option('-d', '--discoverhost', required=False, help='Hostname for discovery services (reducer).')
-@click.option('-p', '--discoverport', required=False, help='Port for discovery services (reducer).')
-@click.option('-t', '--token', required=False, help='Set token provided by reducer if enabled')
-@click.option('-n', '--name', required=False, default="combiner" + str(uuid.uuid4())[:8], help='Set name for combiner.')
-@click.option('-h', '--host', required=False, default="combiner", help='Set hostname.')
-@click.option('-i', '--port', required=False, default=12080, help='Set port.')
-@click.option('-f', '--fqdn', required=False, default=None, help='Set fully qualified domain name')
-@click.option('-s', '--secure', is_flag=True, help='Enable SSL/TLS encrypted gRPC channels.')
-@click.option('-v', '--verify', is_flag=True, help='Verify SSL/TLS for REST discovery service (reducer)')
-@click.option('-c', '--max_clients', required=False, default=30, help='The maximal number of client connections allowed.')
-@click.option('-in', '--init', required=False, default=None,
-              help='Path to configuration file to (re)init combiner.')
+@run_cmd.command("combiner")
+@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services (reducer).")
+@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).")
+@click.option("-t", "--token", required=False, help="Set token provided by reducer if enabled")
+@click.option("-n", "--name", required=False, default="combiner" + str(uuid.uuid4())[:8], help="Set name for combiner.")
+@click.option("-h", "--host", required=False, default="combiner", help="Set hostname.")
+@click.option("-i", "--port", required=False, default=12080, help="Set port.")
+@click.option("-f", "--fqdn", required=False, default=None, help="Set fully qualified domain name")
+@click.option("-s", "--secure", is_flag=True, help="Enable SSL/TLS encrypted gRPC channels.")
+@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST discovery service (reducer)")
+@click.option("-c", "--max_clients", required=False, default=30, help="The maximal number of client connections allowed.")
+@click.option("-in", "--init", required=False, default=None, help="Path to configuration file to (re)init combiner.")
 @click.pass_context
 def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init):
     """
@@ -188,22 +214,27 @@ def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn,
     :param max_clients:
     :param init:
     """
-    config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'host': host,
-              'port': port, 'fqdn': fqdn, 'name': name, 'secure': secure, 'verify': verify, 'max_clients': max_clients}
+    config = {
+        "discover_host": discoverhost,
+        "discover_port": discoverport,
+        "token": token,
+        "host": host,
+        "port": port,
+        "fqdn": fqdn,
+        "name": name,
+        "secure": secure,
+        "verify": verify,
+        "max_clients": max_clients,
+    }

     click.echo(
-        click.style(
-            '\n*** fedn run combiner is deprecated and will be removed. Please use fedn combiner start instead. ***\n',
-            blink=True,
-            bold=True,
-            fg='red'
-        )
+        click.style("\n*** fedn run combiner is deprecated and will be removed. Please use fedn combiner start instead. ***\n", blink=True, bold=True, fg="red")
     )

     if init:
         apply_config(init, config)
-        click.echo(f'\nCombiner configuration loaded from file: {init}')
-        click.echo('Values set in file override defaults and command line arguments...\n')
+        click.echo(f"\nCombiner configuration loaded from file: {init}")
+        click.echo("Values set in file override defaults and command line arguments...\n")

     combiner = Combiner(config)
     combiner.run()
diff --git a/fedn/cli/session_cmd.py b/fedn/cli/session_cmd.py
index 37eb3a8a6..55597b5b3 100644
--- a/fedn/cli/session_cmd.py
+++ b/fedn/cli/session_cmd.py
@@ -1,4 +1,3 @@
-
 import click
 import requests

@@ -6,7 +5,7 @@
 from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response


-@main.group('session')
+@main.group("session")
 @click.pass_context
 def session_cmd(ctx):
     """
@@ -16,12 +15,12 @@ def session_cmd(ctx):
     pass


-@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)')
-@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)')
-@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)')
-@click.option('-t', '--token', required=False, help='Authentication token')
-@click.option('--n_max', required=False, help='Number of items to list')
-@session_cmd.command('list')
+@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)")
+@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)")
+@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)")
+@click.option("-t", "--token", required=False, help="Authentication token")
+@click.option("--n_max", required=False, help="Number of items to list")
+@session_cmd.command("list")
 @click.pass_context
 def list_sessions(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None):
     """
@@ -31,22 +30,22 @@ def list_sessions(ctx, protocol: str, host: str, port: str, token: str = None, n
     - result: list of sessions

     """
-    url = get_api_url(protocol=protocol, host=host, port=port, endpoint='sessions')
+    url = get_api_url(protocol=protocol, host=host, port=port, endpoint="sessions")
     headers = {}

     if n_max:
-        headers['X-Limit'] = n_max
+        headers["X-Limit"] = n_max

     _token = get_token(token)

     if _token:
-        headers['Authorization'] = _token
+        headers["Authorization"] = _token

-    click.echo(f'\nListing sessions: {url}\n')
-    click.echo(f'Headers: {headers}')
+    click.echo(f"\nListing sessions: {url}\n")
+    click.echo(f"Headers: {headers}")

     try:
         response = requests.get(url, headers=headers)
-        print_response(response, 'sessions')
+        print_response(response, "sessions")
     except requests.exceptions.ConnectionError:
-        click.echo(f'Error: Could not connect to {url}')
+        click.echo(f"Error: Could not connect to {url}")
diff --git a/fedn/cli/shared.py b/fedn/cli/shared.py
index 81e5ebd07..2500d9e2b 100644
--- a/fedn/cli/shared.py
+++ b/fedn/cli/shared.py
@@ -5,28 +5,16
@@ from fedn.common.log_config import logger -CONTROLLER_DEFAULTS = { - 'protocol': 'http', - 'host': 'localhost', - 'port': 8092, - 'debug': False -} +CONTROLLER_DEFAULTS = {"protocol": "http", "host": "localhost", "port": 8092, "debug": False} -COMBINER_DEFAULTS = { - 'discover_host': 'localhost', - 'discover_port': 8092, - 'host': 'localhost', - 'port': 12080, - "name": "combiner", - "max_clients": 30 -} +COMBINER_DEFAULTS = {"discover_host": "localhost", "discover_port": 8092, "host": "localhost", "port": 12080, "name": "combiner", "max_clients": 30} CLIENT_DEFAULTS = { - 'discover_host': 'localhost', - 'discover_port': 8092, + "discover_host": "localhost", + "discover_port": 8092, } -API_VERSION = 'v1' +API_VERSION = "v1" def apply_config(path: str, config: dict): @@ -36,11 +24,11 @@ def apply_config(path: str, config: dict): :param config: Client config (dict). """ - with open(path, 'r') as file: + with open(path, "r") as file: try: settings = dict(yaml.safe_load(file)) except Exception: - logger.error('Failed to read config from settings file, exiting.') + logger.error("Failed to read config from settings file, exiting.") return for key, val in settings.items(): @@ -48,16 +36,16 @@ def apply_config(path: str, config: dict): def get_api_url(protocol: str, host: str, port: str, endpoint: str) -> str: - _url = os.environ.get('FEDN_CONTROLLER_URL') + _url = os.environ.get("FEDN_CONTROLLER_URL") if _url: - return f'{_url}/api/{API_VERSION}/{endpoint}/' + return f"{_url}/api/{API_VERSION}/{endpoint}/" - _protocol = protocol or os.environ.get('FEDN_CONTROLLER_PROTOCOL') or CONTROLLER_DEFAULTS['protocol'] - _host = host or os.environ.get('FEDN_CONTROLLER_HOST') or CONTROLLER_DEFAULTS['host'] - _port = port or os.environ.get('FEDN_CONTROLLER_PORT') or CONTROLLER_DEFAULTS['port'] + _protocol = protocol or os.environ.get("FEDN_CONTROLLER_PROTOCOL") or CONTROLLER_DEFAULTS["protocol"] + _host = host or os.environ.get("FEDN_CONTROLLER_HOST") or CONTROLLER_DEFAULTS["host"] + _port = port or os.environ.get("FEDN_CONTROLLER_PORT") or CONTROLLER_DEFAULTS["port"] - return f'{_protocol}://{_host}:{_port}/api/{API_VERSION}/{endpoint}/' + return f"{_protocol}://{_host}:{_port}/api/{API_VERSION}/{endpoint}/" def get_token(token: str) -> str: @@ -72,7 +60,7 @@ def get_token(token: str) -> str: def get_client_package_dir(path: str) -> str: - return path or os.environ.get('FEDN_PACKAGE_DIR', None) + return path or os.environ.get("FEDN_PACKAGE_DIR", None) # Print response from api (list of entities) @@ -90,15 +78,15 @@ def print_response(response, entity_name: str): if response.status_code == 200: json_data = response.json() count, result = json_data.values() - click.echo(f'Found {count} {entity_name}') - click.echo('\n---------------------------------\n') + click.echo(f"Found {count} {entity_name}") + click.echo("\n---------------------------------\n") for obj in result: - click.echo('{') + click.echo("{") for k, v in obj.items(): - click.echo(f'\t{k}: {v}') - click.echo('}') + click.echo(f"\t{k}: {v}") + click.echo("}") elif response.status_code == 500: json_data = response.json() click.echo(f'Error: {json_data["message"]}') else: - click.echo(f'Error: {response.status_code}') + click.echo(f"Error: {response.status_code}") diff --git a/fedn/cli/status_cmd.py b/fedn/cli/status_cmd.py index 457fc9c00..a4f17e349 100644 --- a/fedn/cli/status_cmd.py +++ b/fedn/cli/status_cmd.py @@ -5,7 +5,7 @@ from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response -@main.group('status') 
+@main.group("status") @click.pass_context def status_cmd(ctx): """ @@ -15,12 +15,12 @@ def status_cmd(ctx): pass -@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)') -@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)') -@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)') -@click.option('-t', '--token', required=False, help='Authentication token') -@click.option('--n_max', required=False, help='Number of items to list') -@status_cmd.command('list') +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@status_cmd.command("list") @click.pass_context def list_statuses(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): """ @@ -30,22 +30,22 @@ def list_statuses(ctx, protocol: str, host: str, port: str, token: str = None, n - result: list of statuses """ - url = get_api_url(protocol=protocol, host=host, port=port, endpoint='statuses') + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="statuses") headers = {} if n_max: - headers['X-Limit'] = n_max + headers["X-Limit"] = n_max _token = get_token(token) if _token: - headers['Authorization'] = _token + headers["Authorization"] = _token - click.echo(f'\nListing statuses: {url}\n') - click.echo(f'Headers: {headers}') + click.echo(f"\nListing statuses: {url}\n") + click.echo(f"Headers: {headers}") try: response = requests.get(url, headers=headers) - print_response(response, 'statuses') + print_response(response, "statuses") except requests.exceptions.ConnectionError: - click.echo(f'Error: Could not connect to {url}') + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/validation_cmd.py b/fedn/cli/validation_cmd.py index 3707f9bb8..055be0c65 100644 --- a/fedn/cli/validation_cmd.py +++ b/fedn/cli/validation_cmd.py @@ -5,7 +5,7 @@ from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response -@main.group('validation') +@main.group("validation") @click.pass_context def validation_cmd(ctx): """ @@ -15,12 +15,12 @@ def validation_cmd(ctx): pass -@click.option('-p', '--protocol', required=False, default=CONTROLLER_DEFAULTS['protocol'], help='Communication protocol of controller (api)') -@click.option('-H', '--host', required=False, default=CONTROLLER_DEFAULTS['host'], help='Hostname of controller (api)') -@click.option('-P', '--port', required=False, default=CONTROLLER_DEFAULTS['port'], help='Port of controller (api)') -@click.option('-t', '--token', required=False, help='Authentication token') -@click.option('--n_max', required=False, help='Number of items to list') -@validation_cmd.command('list') +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", 
"--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@validation_cmd.command("list") @click.pass_context def list_validations(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): """ @@ -30,22 +30,22 @@ def list_validations(ctx, protocol: str, host: str, port: str, token: str = None - result: list of validations """ - url = get_api_url(protocol=protocol, host=host, port=port, endpoint='validations') + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="validations") headers = {} if n_max: - headers['X-Limit'] = n_max + headers["X-Limit"] = n_max _token = get_token(token) if _token: - headers['Authorization'] = _token + headers["Authorization"] = _token - click.echo(f'\nListing validations: {url}\n') - click.echo(f'Headers: {headers}') + click.echo(f"\nListing validations: {url}\n") + click.echo(f"Headers: {headers}") try: response = requests.get(url, headers=headers) - print_response(response, 'validations') + print_response(response, "validations") except requests.exceptions.ConnectionError: - click.echo(f'Error: Could not connect to {url}') + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/common/certificate/certificate.py b/fedn/common/certificate/certificate.py index a2c059748..857a05e7c 100644 --- a/fedn/common/certificate/certificate.py +++ b/fedn/common/certificate/certificate.py @@ -13,19 +13,18 @@ class Certificate: Utility to generate unsigned certificates. """ + CERT_NAME = "cert.pem" KEY_NAME = "key.pem" BITS = 2048 def __init__(self, cwd, name=None, key_name="key.pem", cert_name="cert.pem", create_dirs=True): - try: os.makedirs(cwd) except OSError: logger.info("Directory exists, will store all cert and keys here.") else: - logger.info( - "Successfully created the directory to store cert and keys in {}".format(cwd)) + logger.info("Successfully created the directory to store cert and keys in {}".format(cwd)) self.key_path = os.path.join(cwd, key_name) self.cert_path = os.path.join(cwd, cert_name) @@ -35,7 +34,9 @@ def __init__(self, cwd, name=None, key_name="key.pem", cert_name="cert.pem", cre else: self.name = str(uuid.uuid4()) - def gen_keypair(self, ): + def gen_keypair( + self, + ): """ Generate keypair. 
@@ -70,21 +71,19 @@ def set_keypair_raw(self, certificate, privatekey): :param privatekey: """ with open(self.key_path, "wb") as keyfile: - keyfile.write(crypto.dump_privatekey( - crypto.FILETYPE_PEM, privatekey)) + keyfile.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, privatekey)) with open(self.cert_path, "wb") as certfile: - certfile.write(crypto.dump_certificate( - crypto.FILETYPE_PEM, certificate)) + certfile.write(crypto.dump_certificate(crypto.FILETYPE_PEM, certificate)) def get_keypair_raw(self): """ :return: """ - with open(self.key_path, 'rb') as keyfile: + with open(self.key_path, "rb") as keyfile: key_buf = keyfile.read() - with open(self.cert_path, 'rb') as certfile: + with open(self.cert_path, "rb") as certfile: cert_buf = certfile.read() return copy.deepcopy(cert_buf), copy.deepcopy(key_buf) @@ -93,7 +92,7 @@ def get_key(self): :return: """ - with open(self.key_path, 'rb') as keyfile: + with open(self.key_path, "rb") as keyfile: key_buf = keyfile.read() key = crypto.load_privatekey(crypto.FILETYPE_PEM, key_buf) return key @@ -103,7 +102,7 @@ def get_cert(self): :return: """ - with open(self.cert_path, 'rb') as certfile: + with open(self.cert_path, "rb") as certfile: cert_buf = certfile.read() cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_buf) return cert diff --git a/fedn/common/certificate/certificatemanager.py b/fedn/common/certificate/certificatemanager.py index 3d34fa1ad..ce165d862 100644 --- a/fedn/common/certificate/certificatemanager.py +++ b/fedn/common/certificate/certificatemanager.py @@ -10,7 +10,6 @@ class CertificateManager: """ def __init__(self, directory): - self.directory = directory self.certificates = [] self.allowed = dict() @@ -28,8 +27,7 @@ def get_or_create(self, name): if search: return search else: - cert = Certificate(self.directory, name=name, - cert_name=name + '-cert.pem', key_name=name + '-key.pem') + cert = Certificate(self.directory, name=name, cert_name=name + "-cert.pem", key_name=name + "-key.pem") cert.gen_keypair() self.certificates.append(cert) return cert @@ -53,12 +51,11 @@ def load_all(self): """ for filename in sorted(os.listdir(self.directory)): - if filename.endswith('cert.pem'): - name = filename.split('-')[0] - key_name = name + '-key.pem' + if filename.endswith("cert.pem"): + name = filename.split("-")[0] + key_name = name + "-key.pem" - c = Certificate(self.directory, name=name, - cert_name=filename, key_name=key_name) + c = Certificate(self.directory, name=name, cert_name=filename, key_name=key_name) self.certificates.append(c) def find(self, name): diff --git a/fedn/common/config.py b/fedn/common/config.py index 71f4b1698..4864ce1ef 100644 --- a/fedn/common/config.py +++ b/fedn/common/config.py @@ -24,12 +24,8 @@ def get_environment_config(): global STATESTORE_CONFIG global MODELSTORAGE_CONFIG - STATESTORE_CONFIG = os.environ.get( - "STATESTORE_CONFIG", "/workspaces/fedn/config/settings-reducer.yaml.template" - ) - MODELSTORAGE_CONFIG = os.environ.get( - "MODELSTORAGE_CONFIG", "/workspaces/fedn/config/settings-reducer.yaml.template" - ) + STATESTORE_CONFIG = os.environ.get("STATESTORE_CONFIG", "/workspaces/fedn/config/settings-reducer.yaml.template") + MODELSTORAGE_CONFIG = os.environ.get("MODELSTORAGE_CONFIG", "/workspaces/fedn/config/settings-reducer.yaml.template") def get_statestore_config(file=None): diff --git a/fedn/common/log_config.py b/fedn/common/log_config.py index 0e61a6a83..b8aa1218b 100644 --- a/fedn/common/log_config.py +++ b/fedn/common/log_config.py @@ -5,13 +5,7 @@ import requests import 
urllib3 -log_levels = { - "DEBUG": logging.DEBUG, - "INFO": logging.INFO, - "WARNING": logging.WARNING, - "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL -} +log_levels = {"DEBUG": logging.DEBUG, "INFO": logging.INFO, "WARNING": logging.WARNING, "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL} urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @@ -21,12 +15,12 @@ logger = logging.getLogger("fedn") logger.addHandler(handler) logger.setLevel(logging.DEBUG) -formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') +formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") handler.setFormatter(formatter) class StudioHTTPHandler(logging.handlers.HTTPHandler): - def __init__(self, host, url, method='POST', token=None): + def __init__(self, host, url, method="POST", token=None): super().__init__(host, url, method) self.token = token @@ -34,41 +28,35 @@ def emit(self, record): log_entry = self.mapLogRecord(record) log_entry = { - "msg": log_entry['msg'], - "levelname": log_entry['levelname'], + "msg": log_entry["msg"], + "levelname": log_entry["levelname"], "project": os.environ.get("PROJECT_ID"), - "appinstance": os.environ.get("APP_ID") - + "appinstance": os.environ.get("APP_ID"), } # Setup headers headers = { - 'Content-type': 'application/json', + "Content-type": "application/json", } if self.token: - remote_token_protocol = os.environ.get('FEDN_REMOTE_LOG_TOKEN_PROTOCOL', "Token") - headers['Authorization'] = f"{remote_token_protocol} {self.token}" - if self.method.lower() == 'post': - requests.post(self.host+self.url, json=log_entry, headers=headers) + remote_token_protocol = os.environ.get("FEDN_REMOTE_LOG_TOKEN_PROTOCOL", "Token") + headers["Authorization"] = f"{remote_token_protocol} {self.token}" + if self.method.lower() == "post": + requests.post(self.host + self.url, json=log_entry, headers=headers) else: # No other methods implemented. return # Remote logging can only be configured via environment variables for now. 
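The comment above is the operative configuration note: the remote handler is wired up solely from the environment variables read just below, at module import time. A hypothetical setup (all values are placeholders; note that emit() posts to host + url concatenated directly, so the server value should include its scheme):

import os

os.environ["FEDN_REMOTE_LOG_SERVER"] = "https://logs.example.com"  # hypothetical collector
os.environ["FEDN_REMOTE_LOG_PATH"] = "/api/logs"                   # hypothetical path
os.environ["FEDN_REMOTE_LOG_LEVEL"] = "INFO"                       # any key of log_levels above
os.environ["FEDN_REMOTE_LOG_TOKEN"] = "abc123"                     # optional; sent as "Token abc123"
os.environ["FEDN_REMOTE_LOG_TOKEN_PROTOCOL"] = "Token"             # scheme prefix, "Token" by default

# Import after setting the variables: the handler is attached on import.
from fedn.common.log_config import logger

logger.info("This record is also POSTed to the remote collector.")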
-REMOTE_LOG_SERVER = os.environ.get('FEDN_REMOTE_LOG_SERVER', False) -REMOTE_LOG_PATH = os.environ.get('FEDN_REMOTE_LOG_PATH', False) -REMOTE_LOG_LEVEL = os.environ.get('FEDN_REMOTE_LOG_LEVEL', 'INFO') +REMOTE_LOG_SERVER = os.environ.get("FEDN_REMOTE_LOG_SERVER", False) +REMOTE_LOG_PATH = os.environ.get("FEDN_REMOTE_LOG_PATH", False) +REMOTE_LOG_LEVEL = os.environ.get("FEDN_REMOTE_LOG_LEVEL", "INFO") if REMOTE_LOG_SERVER: rloglevel = log_levels.get(REMOTE_LOG_LEVEL, logging.INFO) - remote_token = os.environ.get('FEDN_REMOTE_LOG_TOKEN', None) - - http_handler = StudioHTTPHandler( - host=REMOTE_LOG_SERVER, - url=REMOTE_LOG_PATH, - method='POST', - token=remote_token - ) + remote_token = os.environ.get("FEDN_REMOTE_LOG_TOKEN", None) + + http_handler = StudioHTTPHandler(host=REMOTE_LOG_SERVER, url=REMOTE_LOG_PATH, method="POST", token=remote_token) http_handler.setLevel(rloglevel) logger.addHandler(http_handler) @@ -79,11 +67,11 @@ def set_log_level_from_string(level_str): """ # Mapping of string representation to logging constants level_mapping = { - 'CRITICAL': logging.CRITICAL, - 'ERROR': logging.ERROR, - 'WARNING': logging.WARNING, - 'INFO': logging.INFO, - 'DEBUG': logging.DEBUG, + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, } # Get the logging level from the mapping diff --git a/fedn/network/api/auth.py b/fedn/network/api/auth.py index bf43c2f69..e3240e743 100644 --- a/fedn/network/api/auth.py +++ b/fedn/network/api/auth.py @@ -3,16 +3,20 @@ import jwt from flask import jsonify, request -from fedn.common.config import (FEDN_AUTH_SCHEME, - FEDN_AUTH_WHITELIST_URL_PREFIX, - FEDN_JWT_ALGORITHM, FEDN_JWT_CUSTOM_CLAIM_KEY, - FEDN_JWT_CUSTOM_CLAIM_VALUE, SECRET_KEY) +from fedn.common.config import ( + FEDN_AUTH_SCHEME, + FEDN_AUTH_WHITELIST_URL_PREFIX, + FEDN_JWT_ALGORITHM, + FEDN_JWT_CUSTOM_CLAIM_KEY, + FEDN_JWT_CUSTOM_CLAIM_VALUE, + SECRET_KEY, +) def check_role_claims(payload, role): - if 'role' not in payload: + if "role" not in payload: return False - if payload['role'] != role: + if payload["role"] != role: return False return True @@ -41,30 +45,29 @@ def actual_decorator(func): def decorated(*args, **kwargs): if if_whitelisted_url_prefix(request.path): return func(*args, **kwargs) - token = request.headers.get('Authorization') + token = request.headers.get("Authorization") if not token: - return jsonify({'message': 'Missing token'}), 401 + return jsonify({"message": "Missing token"}), 401 # Get token from the header Bearer if token.startswith(FEDN_AUTH_SCHEME): - token = token.split(' ')[1] + token = token.split(" ")[1] else: - return jsonify({'message': - f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}' - }), 401 + return jsonify({"message": f"Invalid token scheme, expected {FEDN_AUTH_SCHEME}"}), 401 try: payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) if not check_role_claims(payload, role): - return jsonify({'message': 'Invalid token'}), 401 + return jsonify({"message": "Invalid token"}), 401 if not check_custom_claims(payload): - return jsonify({'message': 'Invalid token'}), 401 + return jsonify({"message": "Invalid token"}), 401 except jwt.ExpiredSignatureError: - return jsonify({'message': 'Token expired'}), 401 + return jsonify({"message": "Token expired"}), 401 except jwt.InvalidTokenError: - return jsonify({'message': 'Invalid token'}), 401 + return jsonify({"message": "Invalid token"}), 401 return func(*args, **kwargs) return decorated + return 
actual_decorator diff --git a/fedn/network/api/client.py b/fedn/network/api/client.py index 60b4f3cd9..43678fbcf 100644 --- a/fedn/network/api/client.py +++ b/fedn/network/api/client.py @@ -2,11 +2,11 @@ import requests -__all__ = ['APIClient'] +__all__ = ["APIClient"] class APIClient: - """ An API client for interacting with the statestore and controller. + """An API client for interacting with the statestore and controller. :param host: The host of the api server. :type host: str @@ -37,34 +37,34 @@ def __init__(self, host, port=None, secure=False, verify=False, token=None, auth def _get_url(self, endpoint): if self.secure: - protocol = 'https' + protocol = "https" else: - protocol = 'http' + protocol = "http" if self.port: - return f'{protocol}://{self.host}:{self.port}/{endpoint}' - return f'{protocol}://{self.host}/{endpoint}' + return f"{protocol}://{self.host}:{self.port}/{endpoint}" + return f"{protocol}://{self.host}/{endpoint}" def _get_url_api_v1(self, endpoint): - return self._get_url(f'api/v1/{endpoint}') + return self._get_url(f"api/v1/{endpoint}") # --- Clients --- # def get_client(self, id: str): - """ Get a client from the statestore. + """Get a client from the statestore. :param id: The client id to get. :type id: str :return: Client. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'clients/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"clients/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_clients(self, n_max: int = None): - """ Get clients from the statestore. + """Get clients from the statestore. :param n_max: The maximum number of clients to get (If none all will be fetched). :type n_max: int @@ -74,28 +74,28 @@ def get_clients(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('clients/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("clients/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_clients_count(self): - """ Get the number of clients in the statestore. + """Get the number of clients in the statestore. :return: The number of clients. :rtype: dict """ - response = requests.get(self._get_url_api_v1('clients/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("clients/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_client_config(self, checksum=True): - """ Get client config from controller. Optionally include the checksum. + """Get client config from controller. Optionally include the checksum. The config is used for clients to connect to the controller and ask for combiner assignment. :param checksum: Whether to include the checksum of the package. @@ -103,18 +103,16 @@ def get_client_config(self, checksum=True): :return: The client configuration. 
:rtype: dict """ - _params = { - 'checksum': "true" if checksum else "false" - } + _params = {"checksum": "true" if checksum else "false"} - response = requests.get(self._get_url('get_client_config'), params=_params, verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("get_client_config"), params=_params, verify=self.verify, headers=self.headers) _json = response.json() return _json def get_active_clients(self, combiner_id: str = None, n_max: int = None): - """ Get active clients from the statestore. + """Get active clients from the statestore. :param combiner_id: The combiner id to get active clients for. :type combiner_id: str @@ -126,14 +124,14 @@ def get_active_clients(self, combiner_id: str = None, n_max: int = None): _params = {"status": "online"} if combiner_id: - _params['combiner'] = combiner_id + _params["combiner"] = combiner_id _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('clients/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("clients/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() @@ -142,21 +140,21 @@ def get_active_clients(self, combiner_id: str = None, n_max: int = None): # --- Combiners --- # def get_combiner(self, id: str): - """ Get a combiner from the statestore. + """Get a combiner from the statestore. :param id: The combiner id to get. :type id: str :return: Combiner. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'combiners/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"combiners/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_combiners(self, n_max: int = None): - """ Get combiners in the network. + """Get combiners in the network. :param n_max: The maximum number of combiners to get (If none all will be fetched). :type n_max: int @@ -166,21 +164,21 @@ def get_combiners(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('combiners/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("combiners/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_combiners_count(self): - """ Get the number of combiners in the statestore. + """Get the number of combiners in the statestore. :return: The number of combiners. :rtype: dict """ - response = requests.get(self._get_url_api_v1('combiners/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("combiners/count"), verify=self.verify, headers=self.headers) _json = response.json() @@ -189,12 +187,12 @@ def get_combiners_count(self): # --- Controllers --- # def get_controller_status(self): - """ Get the status of the controller. + """Get the status of the controller. :return: The status of the controller. :rtype: dict """ - response = requests.get(self._get_url('get_controller_status'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("get_controller_status"), verify=self.verify, headers=self.headers) _json = response.json() @@ -203,21 +201,21 @@ def get_controller_status(self): # --- Models --- # def get_model(self, id: str): - """ Get a model from the statestore. + """Get a model from the statestore. 
:param id: The id (or model property) of the model to get. :type id: str :return: Model. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'models/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"models/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_models(self, session_id: str = None, n_max: int = None): - """ Get models from the statestore. + """Get models from the statestore. :param session_id: The session id to get models for. (optional) :type session_id: str @@ -229,41 +227,41 @@ def get_models(self, session_id: str = None, n_max: int = None): _params = {} if session_id: - _params['session_id'] = session_id + _params["session_id"] = session_id _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('models/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("models/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json def get_models_count(self): - """ Get the number of models in the statestore. + """Get the number of models in the statestore. :return: The number of models. :rtype: dict """ - response = requests.get(self._get_url_api_v1('models/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("models/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_active_model(self): - """ Get the latest model from the statestore. + """Get the latest model from the statestore. :return: The latest model. :rtype: dict """ _headers = self.headers.copy() - _headers['X-Limit'] = "1" + _headers["X-Limit"] = "1" - response = requests.get(self._get_url_api_v1('models/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("models/"), verify=self.verify, headers=_headers) _json = response.json() if "result" in _json and len(_json["result"]) > 0: @@ -272,7 +270,7 @@ def get_active_model(self): return _json def get_model_trail(self, id: str = None, include_self: bool = True, reverse: bool = True, n_max: int = None): - """ Get the model trail. + """Get the model trail. :param id: The id (or model property) of the model to start the trail from. (optional) :type id: str @@ -291,18 +289,18 @@ def get_model_trail(self, id: str = None, include_self: bool = True, reverse: bo _headers = self.headers.copy() _count: int = n_max if n_max else self.get_models_count() - _headers['X-Limit'] = str(_count) - _headers['X-Reverse'] = "true" if reverse else "false" + _headers["X-Limit"] = str(_count) + _headers["X-Reverse"] = "true" if reverse else "false" _include_self_str: str = "true" if include_self else "false" - response = requests.get(self._get_url_api_v1(f'models/{id}/ancestors?include_self={_include_self_str}'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1(f"models/{id}/ancestors?include_self={_include_self_str}"), verify=self.verify, headers=_headers) _json = response.json() return _json def download_model(self, id: str, path: str): - """ Download the model with id id. + """Download the model with id id. :param id: The id (or model property) of the model to download. :type id: str @@ -311,46 +309,45 @@ def download_model(self, id: str, path: str): :return: Message with success or failure. 
:rtype: dict """ - response = requests.get(self._get_url_api_v1(f'models/{id}/download'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"models/{id}/download"), verify=self.verify, headers=self.headers) if response.status_code == 200: - - with open(path, 'wb') as file: + with open(path, "wb") as file: file.write(response.content) - return {'success': True, 'message': 'Model downloaded successfully.'} + return {"success": True, "message": "Model downloaded successfully."} else: - return {'success': False, 'message': 'Failed to download model.'} + return {"success": False, "message": "Failed to download model."} def set_active_model(self, path): - """ Set the initial model in the statestore and upload to model repository. + """Set the initial model in the statestore and upload to model repository. :param path: The file path of the initial model to set. :type path: str :return: A dict with success or failure message. :rtype: dict """ - with open(path, 'rb') as file: - response = requests.post(self._get_url('set_initial_model'), files={'file': file}, verify=self.verify, headers=self.headers) + with open(path, "rb") as file: + response = requests.post(self._get_url("set_initial_model"), files={"file": file}, verify=self.verify, headers=self.headers) return response.json() # --- Packages --- # def get_package(self, id: str): - """ Get a compute package from the statestore. + """Get a compute package from the statestore. :param id: The id of the compute package to get. :type id: str :return: Package. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'packages/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"packages/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_packages(self, n_max: int = None): - """ Get compute packages from the statestore. + """Get compute packages from the statestore. :param n_max: The maximum number of packages to get (If none all will be fetched). :type n_max: int @@ -360,68 +357,68 @@ def get_packages(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('packages/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("packages/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_packages_count(self): - """ Get the number of compute packages in the statestore. + """Get the number of compute packages in the statestore. :return: The number of packages. :rtype: dict """ - response = requests.get(self._get_url_api_v1('packages/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("packages/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_active_package(self): - """ Get the (active) compute package from the statestore. + """Get the (active) compute package from the statestore. :return: Package. :rtype: dict """ - response = requests.get(self._get_url_api_v1('packages/active'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("packages/active"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_package_checksum(self): - """ Get the checksum of the compute package. + """Get the checksum of the compute package. :return: The checksum. 
:rtype: dict """ - response = requests.get(self._get_url('get_package_checksum'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("get_package_checksum"), verify=self.verify, headers=self.headers) _json = response.json() return _json def download_package(self, path: str): - """ Download the compute package. + """Download the compute package. :param path: The path to download the compute package to. :type path: str :return: Message with success or failure. :rtype: dict """ - response = requests.get(self._get_url('download_package'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("download_package"), verify=self.verify, headers=self.headers) if response.status_code == 200: - with open(path, 'wb') as file: + with open(path, "wb") as file: file.write(response.content) - return {'success': True, 'message': 'Package downloaded successfully.'} + return {"success": True, "message": "Package downloaded successfully."} else: - return {'success': False, 'message': 'Failed to download package.'} + return {"success": False, "message": "Failed to download package."} def set_active_package(self, path: str, helper: str, name: str = None, description: str = None): - """ Set the compute package in the statestore. + """Set the compute package in the statestore. :param path: The file path of the compute package to set. :type path: str @@ -430,9 +427,14 @@ def set_active_package(self, path: str, helper: str, name: str = None, descripti :return: A dict with success or failure message. :rtype: dict """ - with open(path, 'rb') as file: - response = requests.post(self._get_url('set_package'), files={'file': file}, data={ - 'helper': helper, 'name': name, 'description': description}, verify=self.verify, headers=self.headers) + with open(path, "rb") as file: + response = requests.post( + self._get_url("set_package"), + files={"file": file}, + data={"helper": helper, "name": name, "description": description}, + verify=self.verify, + headers=self.headers, + ) _json = response.json() @@ -441,21 +443,21 @@ def set_active_package(self, path: str, helper: str, name: str = None, descripti # --- Rounds --- # def get_round(self, id: str): - """ Get a round from the statestore. + """Get a round from the statestore. :param round_id: The round id to get. :type round_id: str :return: Round (config and metrics). :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'rounds/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"rounds/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_rounds(self, n_max: int = None): - """ Get all rounds from the statestore. + """Get all rounds from the statestore. :param n_max: The maximum number of rounds to get (If none all will be fetched). :type n_max: int @@ -465,21 +467,21 @@ def get_rounds(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('rounds/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("rounds/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_rounds_count(self): - """ Get the number of rounds in the statestore. + """Get the number of rounds in the statestore. :return: The number of rounds. 
:rtype: dict """ - response = requests.get(self._get_url_api_v1('rounds/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("rounds/count"), verify=self.verify, headers=self.headers) _json = response.json() @@ -488,7 +490,7 @@ def get_rounds_count(self): # --- Sessions --- # def get_session(self, id: str): - """ Get a session from the statestore. + """Get a session from the statestore. :param id: The session id to get. :type id: str @@ -496,14 +498,14 @@ def get_session(self, id: str): :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'sessions/{id}'), self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"sessions/{id}"), self.verify, headers=self.headers) _json = response.json() return _json def get_sessions(self, n_max: int = None): - """ Get sessions from the statestore. + """Get sessions from the statestore. :param n_max: The maximum number of sessions to get (If none all will be fetched). :type n_max: int @@ -513,28 +515,28 @@ def get_sessions(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('sessions/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("sessions/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_sessions_count(self): - """ Get the number of sessions in the statestore. + """Get the number of sessions in the statestore. :return: The number of sessions. :rtype: dict """ - response = requests.get(self._get_url_api_v1('sessions/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("sessions/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_session_status(self, id: str): - """ Get the status of a session. + """Get the status of a session. :param id: The id of the session to get. :type id: str @@ -549,7 +551,7 @@ def get_session_status(self, id: str): return "Could not retrieve session status." def session_is_finished(self, id: str): - """ Check if a session with id has finished. + """Check if a session with id has finished. :param id: The id of the session to get. :type id: str @@ -560,18 +562,21 @@ def session_is_finished(self, id: str): return status and status.lower() == "finished" def start_session( - self, - id: str = None, - aggregator: str = 'fedavg', - aggregator_kwargs: dict = None, - model_id: str = None, - round_timeout: int = 180, - rounds: int = 5, - round_buffer_size: int = -1, - delete_models: bool = True, - validate: bool = True, helper: str = 'numpyhelper', min_clients: int = 1, requested_clients: int = 8 + self, + id: str = None, + aggregator: str = "fedavg", + aggregator_kwargs: dict = None, + model_id: str = None, + round_timeout: int = 180, + rounds: int = 5, + round_buffer_size: int = -1, + delete_models: bool = True, + validate: bool = True, + helper: str = "numpyhelper", + min_clients: int = 1, + requested_clients: int = 8, ): - """ Start a new session. + """Start a new session. :param id: The session id to start. :type id: str @@ -598,20 +603,25 @@ def start_session( :return: A dict with success or failure message and session config. 
:rtype: dict """ - response = requests.post(self._get_url('start_session'), json={ - 'session_id': id, - 'aggregator': aggregator, - 'aggregator_kwargs': aggregator_kwargs, - 'model_id': model_id, - 'round_timeout': round_timeout, - 'rounds': rounds, - 'round_buffer_size': round_buffer_size, - 'delete_models': delete_models, - 'validate': validate, - 'helper': helper, - 'min_clients': min_clients, - 'requested_clients': requested_clients - }, verify=self.verify, headers=self.headers) + response = requests.post( + self._get_url("start_session"), + json={ + "session_id": id, + "aggregator": aggregator, + "aggregator_kwargs": aggregator_kwargs, + "model_id": model_id, + "round_timeout": round_timeout, + "rounds": rounds, + "round_buffer_size": round_buffer_size, + "delete_models": delete_models, + "validate": validate, + "helper": helper, + "min_clients": min_clients, + "requested_clients": requested_clients, + }, + verify=self.verify, + headers=self.headers, + ) _json = response.json() @@ -620,21 +630,21 @@ def start_session( # --- Statuses --- # def get_status(self, id: str): - """ Get a status object (event) from the statestore. + """Get a status object (event) from the statestore. :param id: The id of the status to get. :type id: str :return: Status. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'statuses/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"statuses/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_statuses(self, session_id: str = None, event_type: str = None, sender_name: str = None, sender_role: str = None, n_max: int = None): - """ Get statuses from the statestore. Filter by input parameters + """Get statuses from the statestore. Filter by input parameters :param session_id: The session id to get statuses for. :type session_id: str @@ -665,21 +675,21 @@ def get_statuses(self, session_id: str = None, event_type: str = None, sender_na _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('statuses/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("statuses/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json def get_statuses_count(self): - """ Get the number of statuses in the statestore. + """Get the number of statuses in the statestore. :return: The number of statuses. :rtype: dict """ - response = requests.get(self._get_url_api_v1('statuses/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("statuses/count"), verify=self.verify, headers=self.headers) _json = response.json() @@ -688,14 +698,14 @@ def get_statuses_count(self): # --- Validations --- # def get_validation(self, id: str): - """ Get a validation from the statestore. + """Get a validation from the statestore. :param id: The id of the validation to get. :type id: str :return: Validation. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'validations/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"validations/{id}"), verify=self.verify, headers=self.headers) _json = response.json() @@ -710,9 +720,9 @@ def get_validations( sender_role: str = None, receiver_name: str = None, receiver_role: str = None, - n_max: int = None + n_max: int = None, ): - """ Get validations from the statestore. 
Filter by input parameters. + """Get validations from the statestore. Filter by input parameters. :param session_id: The session id to get validations for. :type session_id: str @@ -759,21 +769,21 @@ def get_validations( _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('validations/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("validations/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json def get_validations_count(self): - """ Get the number of validations in the statestore. + """Get the number of validations in the statestore. :return: The number of validations. :rtype: dict """ - response = requests.get(self._get_url_api_v1('validations/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("validations/count"), verify=self.verify, headers=self.headers) _json = response.json() diff --git a/fedn/network/api/interface.py b/fedn/network/api/interface.py index 38d6c8bf6..718cd8a18 100644 --- a/fedn/network/api/interface.py +++ b/fedn/network/api/interface.py @@ -10,8 +10,7 @@ from fedn.common.config import get_controller_config, get_network_config from fedn.common.log_config import logger -from fedn.network.combiner.interfaces import (CombinerInterface, - CombinerUnavailableError) +from fedn.network.combiner.interfaces import CombinerInterface, CombinerUnavailableError from fedn.network.state import ReducerState, ReducerStateToString from fedn.utils.checksum import sha from fedn.utils.plots import Plot @@ -36,9 +35,7 @@ def _to_dict(self): data = {"name": self.name} return data - def _allowed_file_extension( - self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"} - ): + def _allowed_file_extension(self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"}): """Check if file extension is allowed. :param filename: The filename to check. @@ -170,11 +167,10 @@ def get_session(self, session_id): info = session_object["session_config"][0] status = session_object["status"] payload[id] = info - payload['status'] = status + payload["status"] = status return jsonify(payload) def set_active_compute_package(self, id: str): - success = self.statestore.set_active_compute_package(id) if not success: @@ -199,16 +195,12 @@ def set_compute_package(self, file, helper_type: str, name: str = None, descript :rtype: :class:`flask.Response` """ - if ( - self.control.state() == ReducerState.instructing - or self.control.state() == ReducerState.monitoring - ): + if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: return ( jsonify( { "success": False, - "message": "Reducer is in instructing or monitoring state." - "Cannot set compute package.", + "message": "Reducer is in instructing or monitoring state." 
"Cannot set compute package.", } ), 400, @@ -288,9 +280,7 @@ def get_compute_package(self): result = self.statestore.get_compute_package() if result is None: return ( - jsonify( - {"success": False, "message": "No compute package found."} - ), + jsonify({"success": False, "message": "No compute package found."}), 404, ) @@ -327,9 +317,7 @@ def list_compute_packages(self, limit: str = None, skip: str = None, include_act result = self.statestore.list_compute_packages(limit, skip) if result is None: return ( - jsonify( - {"success": False, "message": "No compute packages found."} - ), + jsonify({"success": False, "message": "No compute packages found."}), 404, ) @@ -386,9 +374,7 @@ def download_compute_package(self, name): mutex = threading.Lock() mutex.acquire() # TODO: make configurable, perhaps in config.py or package.py - return send_from_directory( - "/app/client/package/", name, as_attachment=True - ) + return send_from_directory("/app/client/package/", name, as_attachment=True) except Exception: try: data = self.control.get_compute_package(name) @@ -397,9 +383,7 @@ def download_compute_package(self, name): with open(file_path, "wb") as fh: fh.write(data) # TODO: make configurable, perhaps in config.py or package.py - return send_from_directory( - "/app/client/package/", name, as_attachment=True - ) + return send_from_directory("/app/client/package/", name, as_attachment=True) except Exception: raise finally: @@ -418,9 +402,7 @@ def _create_checksum(self, name=None): name, message = self._get_compute_package_name() if name is None: return False, message, "" - file_path = os.path.join( - "/app/client/package/", name - ) # TODO: make configurable, perhaps in config.py or package.py + file_path = os.path.join("/app/client/package/", name) # TODO: make configurable, perhaps in config.py or package.py try: sum = str(sha(file_path)) except FileNotFoundError: @@ -505,9 +487,7 @@ def get_all_validations(self, **kwargs): payload[id] = info return jsonify(payload) - def add_combiner( - self, combiner_id, secure_grpc, address, remote_addr, fqdn, port - ): + def add_combiner(self, combiner_id, secure_grpc, address, remote_addr, fqdn, port): """Add a combiner to the network. :param combiner_id: The combiner id to add. 
@@ -540,9 +520,7 @@ def add_combiner( combiner = self.control.network.get_combiner(combiner_id) if not combiner: if secure_grpc == "True": - certificate, key = self.certificate_manager.get_or_create( - address - ).get_keypair_raw() + certificate, key = self.certificate_manager.get_or_create(address).get_keypair_raw() _ = base64.b64encode(certificate) _ = base64.b64encode(key) @@ -566,9 +544,7 @@ def add_combiner( # Check combiner now exists combiner = self.control.network.get_combiner(combiner_id) if not combiner: - return jsonify( - {"success": False, "message": "Combiner not added."} - ) + return jsonify({"success": False, "message": "Combiner not added."}) payload = { "success": True, @@ -623,9 +599,7 @@ def add_client(self, client_id, preferred_combiner, remote_addr): combiner = self.control.network.find_available_combiner() if combiner is None: return ( - jsonify( - {"success": False, "message": "No combiner available."} - ), + jsonify({"success": False, "message": "No combiner available."}), 400, ) @@ -691,9 +665,7 @@ def set_initial_model(self, file): logger.debug(e) return jsonify({"success": False, "message": e}) - return jsonify( - {"success": True, "message": "Initial model added successfully."} - ) + return jsonify({"success": True, "message": "Initial model added successfully."}) def get_latest_model(self): """Get the latest model from the statestore. @@ -706,9 +678,7 @@ def get_latest_model(self): payload = {"model_id": model_id} return jsonify(payload) else: - return jsonify( - {"success": False, "message": "No initial model set."} - ) + return jsonify({"success": False, "message": "No initial model set."}) def set_current_model(self, model_id: str): """Set the active model in the statestore. @@ -745,7 +715,6 @@ def get_models(self, session_id: str = None, limit: str = None, skip: str = None include_active: bool = include_active == "true" if include_active: - latest_model = self.statestore.get_latest_model() arr = [ @@ -801,9 +770,7 @@ def get_model_trail(self): if model_info: return jsonify(model_info) else: - return jsonify( - {"success": False, "message": "No model trail available."} - ) + return jsonify({"success": False, "message": "No model trail available."}) def get_model_ancestors(self, model_id: str, limit: str = None): """Get the model ancestors for a given model. @@ -816,15 +783,12 @@ def get_model_ancestors(self, model_id: str, limit: str = None): :rtype: :class:`flask.Response` """ if model_id is None: - return jsonify( - {"success": False, "message": "No model id provided."} - ) + return jsonify({"success": False, "message": "No model id provided."}) limit: int = int(limit) if limit is not None else 10 # if limit is None, default to 10 response = self.statestore.get_model_ancestors(model_id, limit) if response: - arr: list = [] for element in response: @@ -840,9 +804,7 @@ def get_model_ancestors(self, model_id: str, limit: str = None): return jsonify(result) else: - return jsonify( - {"success": False, "message": "No model ancestors available."} - ) + return jsonify({"success": False, "message": "No model ancestors available."}) def get_model_descendants(self, model_id: str, limit: str = None): """Get the model descendants for a given model. 
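The ancestor lookup above and the descendant lookup in the next hunk back the model-trail endpoints that APIClient.get_model_trail() consumes (see client.py earlier in this diff). A usage sketch against a controller on the CLI defaults from fedn/cli/shared.py; the model id is a placeholder, and the "result" key is assumed to follow the same response shape the other list endpoints use:

from fedn.network.api.client import APIClient

client = APIClient("localhost", 8092)  # CONTROLLER_DEFAULTS host/port

# Walk the trail ending at a given model (placeholder id).
trail = client.get_model_trail(id="<model_id>", include_self=True, reverse=True)
for model in trail.get("result", []):
    print(model)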
@@ -856,16 +818,13 @@ def get_model_descendants(self, model_id: str, limit: str = None): """ if model_id is None: -        return jsonify( -            {"success": False, "message": "No model id provided."} -        ) +        return jsonify({"success": False, "message": "No model id provided."}) limit: int = int(limit) if limit is not None else 10 response: list = self.statestore.get_model_descendants(model_id, limit) if response: -        arr: list = [] for element in response: @@ -881,9 +840,7 @@ def get_model_descendants(self, model_id: str, limit: str = None): return jsonify(result) else: -        return jsonify( -            {"success": False, "message": "No model descendants available."} -        ) +        return jsonify({"success": False, "message": "No model descendants available."}) def get_all_rounds(self): """Get all rounds. @@ -926,8 +883,8 @@ def get_round(self, round_id): if round_object is None: return jsonify({"success": False, "message": "Round not found."}) payload = { -        'round_id': round_object['round_id'], -        'combiners': round_object['combiners'], +        "round_id": round_object["round_id"], +        "combiners": round_object["combiners"], } return jsonify(payload) @@ -992,7 +949,6 @@ def list_combiners_data(self, combiners): # order list by combiner name for element in response: -        obj = { "combiner": element["_id"], "count": element["count"], @@ -1007,7 +963,7 @@ def start_session( self, session_id, -        aggregator='fedavg', +        aggregator="fedavg", aggregator_kwargs=None, model_id=None, rounds=5, @@ -1047,15 +1003,11 @@ def start_session( # Check if session already exists session = self.statestore.get_session(session_id) if session: -        return jsonify( -            {"success": False, "message": "Session already exists."} -        ) +        return jsonify({"success": False, "message": "Session already exists."}) # Check if session is running if self.control.state() == ReducerState.monitoring: -        return jsonify( -            {"success": False, "message": "A session is already running."} -        ) +        return jsonify({"success": False, "message": "A session is already running."}) # Check if compute package is set if not self.statestore.get_compute_package(): @@ -1123,9 +1075,7 @@ def start_session( } # Start session -    threading.Thread( -        target=self.control.session, args=(session_config,) -    ).start() +    threading.Thread(target=self.control.session, args=(session_config,)).start() # Return success response return jsonify( diff --git a/fedn/network/api/network.py b/fedn/network/api/network.py index c6a3c3838..cb105f10a 100644 --- a/fedn/network/api/network.py +++ b/fedn/network/api/network.py @@ -4,14 +4,14 @@ from fedn.network.combiner.interfaces import CombinerInterface from fedn.network.loadbalancer.leastpacked import LeastPacked -__all__ = 'Network', +__all__ = ("Network",) class Network: -    """ FEDn network interface. This class is used to interact with the network. -    Note: This class contain redundant code, which is not used in the current version of FEDn. -    Some methods has been moved to :class:`fedn.network.api.interface.API`. -    """ +    """FEDn network interface. This class is used to interact with the network. +    Note: This class contains redundant code, which is not used in the current version of FEDn. +    Some methods have been moved to :class:`fedn.network.api.interface.API`. +    """ def __init__(self, control, statestore, load_balancer=None): """ """ @@ -25,7 +25,7 @@ def __init__(self, control, statestore, load_balancer=None): self.load_balancer = load_balancer def get_combiner(self, name): -        """ Get combiner by name. +        """Get combiner by name.
:param name: name of combiner :type name: str @@ -39,7 +39,7 @@ def get_combiner(self, name): return None def get_combiners(self): - """ Get all combiners in the network. + """Get all combiners in the network. :return: list of combiners objects :rtype: list(:class:`fedn.network.combiner.interfaces.CombinerInterface`) @@ -47,21 +47,19 @@ def get_combiners(self): data = self.statestore.get_combiners() combiners = [] for c in data["result"]: - if c['certificate']: - cert = base64.b64decode(c['certificate']) - key = base64.b64decode(c['key']) + if c["certificate"]: + cert = base64.b64decode(c["certificate"]) + key = base64.b64decode(c["key"]) else: cert = None key = None - combiners.append( - CombinerInterface(c['parent'], c['name'], c['address'], c['fqdn'], c['port'], - certificate=cert, key=key, ip=c['ip'])) + combiners.append(CombinerInterface(c["parent"], c["name"], c["address"], c["fqdn"], c["port"], certificate=cert, key=key, ip=c["ip"])) return combiners def add_combiner(self, combiner): - """ Add a new combiner to the network. + """Add a new combiner to the network. :param combiner: The combiner instance object :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` @@ -78,7 +76,7 @@ def add_combiner(self, combiner): self.statestore.set_combiner(combiner.to_dict()) def remove_combiner(self, combiner): - """ Remove a combiner from the network. + """Remove a combiner from the network. :param combiner: The combiner instance object :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` @@ -90,7 +88,7 @@ def remove_combiner(self, combiner): self.statestore.delete_combiner(combiner.name) def find_available_combiner(self): - """ Find an available combiner in the network. + """Find an available combiner in the network. :return: The combiner instance object :rtype: :class:`fedn.network.combiner.interfaces.CombinerInterface` @@ -99,32 +97,31 @@ def find_available_combiner(self): return combiner def handle_unavailable_combiner(self, combiner): - """ This callback is triggered if a combiner is found to be unresponsive. + """This callback is triggered if a combiner is found to be unresponsive. :param combiner: The combiner instance object :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` :return: None """ # TODO: Implement strategy to handle an unavailable combiner. - logger.warning("REDUCER CONTROL: Combiner {} unavailable.".format( - combiner.name)) + logger.warning("REDUCER CONTROL: Combiner {} unavailable.".format(combiner.name)) def add_client(self, client): - """ Add a new client to the network. + """Add a new client to the network. :param client: The client instance object :type client: dict :return: None """ - if self.get_client(client['name']): + if self.get_client(client["name"]): return - logger.info("adding client {}".format(client['name'])) + logger.info("adding client {}".format(client["name"])) self.statestore.set_client(client) def get_client(self, name): - """ Get client by name. + """Get client by name. :param name: name of client :type name: str @@ -135,7 +132,7 @@ def get_client(self, name): return ret def update_client_data(self, client_data, status, role): - """ Update client status in statestore. + """Update client status in statestore. 
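
As `get_combiners` above shows, combiner certificates and keys are stored base64-encoded in the statestore and decoded back to bytes when a `CombinerInterface` is rebuilt. A minimal sketch of that round-trip with the standard library (the PEM bytes below are placeholders, not real key material):

import base64

cert_pem = b"-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----\n"  # placeholder bytes

encoded = base64.b64encode(cert_pem)   # what gets written to the statestore document
decoded = base64.b64decode(encoded)    # what get_combiners recovers on read
assert decoded == cert_pem
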
:param client_data: The client instance object :type client_data: dict @@ -148,7 +145,7 @@ def update_client_data(self, client_data, status, role): self.statestore.update_client_status(client_data, status, role) def get_client_info(self): - """ list available client in statestore. + """List available client in statestore. :return: list of client objects :rtype: list(ObjectId) diff --git a/fedn/network/api/server.py b/fedn/network/api/server.py index 45cf410ce..5f645e4e2 100644 --- a/fedn/network/api/server.py +++ b/fedn/network/api/server.py @@ -2,8 +2,7 @@ from flask import Flask, jsonify, request -from fedn.common.config import (get_controller_config, get_modelstorage_config, - get_network_config, get_statestore_config) +from fedn.common.config import get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config from fedn.network.api.auth import jwt_auth_required from fedn.network.api.interface import API from fedn.network.api.v1 import _routes @@ -23,14 +22,12 @@ for bp in _routes: app.register_blueprint(bp) if custom_url_prefix: - app.register_blueprint(bp, - name=f"{bp.name}_custom", - url_prefix=f"{custom_url_prefix}{bp.url_prefix}") + app.register_blueprint(bp, name=f"{bp.name}_custom", url_prefix=f"{custom_url_prefix}{bp.url_prefix}") -@app.route('/health', methods=["GET"]) +@app.route("/health", methods=["GET"]) def health_check(): - return 'OK', 200 + return "OK", 200 if custom_url_prefix: @@ -364,9 +361,7 @@ def set_package(): file = request.files["file"] except KeyError: return jsonify({"success": False, "message": "Missing file."}), 400 - return api.set_compute_package( - file=file, helper_type=helper_type, name=name, description=description - ) + return api.set_compute_package(file=file, helper_type=helper_type, name=name, description=description) if custom_url_prefix: @@ -399,9 +394,7 @@ def list_compute_packages(): skip = request.args.get("skip", None) include_active = request.args.get("include_active", None) - return api.list_compute_packages( - limit=limit, skip=skip, include_active=include_active - ) + return api.list_compute_packages(limit=limit, skip=skip, include_active=include_active) if custom_url_prefix: diff --git a/fedn/network/api/v1/client_routes.py b/fedn/network/api/v1/client_routes.py index 30322a9b7..d5ccc58ee 100644 --- a/fedn/network/api/v1/client_routes.py +++ b/fedn/network/api/v1/client_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.client_store import ClientStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -115,9 +114,7 @@ def get_clients(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - clients = client_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + clients = client_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = clients["result"] @@ -202,9 +199,7 @@ def list_clients(): try: limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - clients = client_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + clients = client_store.list(limit, 
skip, sort_key, sort_order, use_typing=False, **kwargs) result = clients["result"] diff --git a/fedn/network/api/v1/combiner_routes.py b/fedn/network/api/v1/combiner_routes.py index 7d1761bee..1f9360461 100644 --- a/fedn/network/api/v1/combiner_routes.py +++ b/fedn/network/api/v1/combiner_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.combiner_store import CombinerStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -107,9 +106,7 @@ def get_combiners(): kwargs = request.args.to_dict() - combiners = combiner_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + combiners = combiner_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = combiners["result"] @@ -192,9 +189,7 @@ def list_combiners(): kwargs = get_post_data_to_kwargs(request) - combiners = combiner_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + combiners = combiner_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = combiners["result"] diff --git a/fedn/network/api/v1/model_routes.py b/fedn/network/api/v1/model_routes.py index 8e9308408..f9708a149 100644 --- a/fedn/network/api/v1/model_routes.py +++ b/fedn/network/api/v1/model_routes.py @@ -4,10 +4,7 @@ from flask import Blueprint, jsonify, request, send_file from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_limit, - get_post_data_to_kwargs, get_reverse, - get_typed_list_headers, mdb, - modelstorage_config) +from fedn.network.api.v1.shared import api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers, mdb, modelstorage_config from fedn.network.storage.s3.base import RepositoryBase from fedn.network.storage.s3.miniorepository import MINIORepository from fedn.network.storage.statestore.stores.model_store import ModelStore @@ -112,9 +109,7 @@ def get_models(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - models = model_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + models = model_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = models["result"] @@ -199,9 +194,7 @@ def list_models(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - models = model_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + models = model_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = models["result"] @@ -466,7 +459,7 @@ def get_ancestors(id: str): try: limit = get_limit(request.headers) reverse = get_reverse(request.headers) - include_self_param: str = request.args.get('include_self') + include_self_param: str = request.args.get("include_self") include_self: bool = include_self_param and include_self_param.lower() == "true" diff --git a/fedn/network/api/v1/package_routes.py b/fedn/network/api/v1/package_routes.py index 30ac4d51e..65783f54b 100644 --- a/fedn/network/api/v1/package_routes.py +++ b/fedn/network/api/v1/package_routes.py @@ -1,9 +1,7 @@ from flask import 
Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, get_use_typing, - mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb from fedn.network.storage.statestore.stores.package_store import PackageStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -120,9 +118,7 @@ def get_packages(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - packages = package_store.list( - limit, skip, sort_key, sort_order, use_typing=True, **kwargs - ) + packages = package_store.list(limit, skip, sort_key, sort_order, use_typing=True, **kwargs) result = [package.__dict__ for package in packages["result"]] @@ -210,9 +206,7 @@ def list_packages(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - packages = package_store.list( - limit, skip, sort_key, sort_order, use_typing=True, **kwargs - ) + packages = package_store.list(limit, skip, sort_key, sort_order, use_typing=True, **kwargs) result = [package.__dict__ for package in packages["result"]] diff --git a/fedn/network/api/v1/round_routes.py b/fedn/network/api/v1/round_routes.py index 8890c510a..4c2eb0c44 100644 --- a/fedn/network/api/v1/round_routes.py +++ b/fedn/network/api/v1/round_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.round_store import RoundStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -95,9 +94,7 @@ def get_rounds(): kwargs = request.args.to_dict() - rounds = round_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + rounds = round_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = rounds["result"] @@ -176,9 +173,7 @@ def list_rounds(): kwargs = get_post_data_to_kwargs(request) - rounds = round_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + rounds = round_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = rounds["result"] diff --git a/fedn/network/api/v1/session_routes.py b/fedn/network/api/v1/session_routes.py index 99c52d8db..ccfde590a 100644 --- a/fedn/network/api/v1/session_routes.py +++ b/fedn/network/api/v1/session_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.session_store import SessionStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -87,9 +86,7 @@ def get_sessions(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - sessions = session_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + sessions = session_store.list(limit, skip, 
sort_key, sort_order, use_typing=False, **kwargs) result = sessions["result"] @@ -167,9 +164,7 @@ def list_sessions(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - sessions = session_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + sessions = session_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = sessions["result"] diff --git a/fedn/network/api/v1/shared.py b/fedn/network/api/v1/shared.py index 753414324..2fb6063c0 100644 --- a/fedn/network/api/v1/shared.py +++ b/fedn/network/api/v1/shared.py @@ -3,8 +3,7 @@ import pymongo from pymongo.database import Database -from fedn.common.config import (get_modelstorage_config, get_network_config, - get_statestore_config) +from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config api_version = "v1" @@ -58,9 +57,7 @@ def get_typed_list_headers( use_typing: bool = get_use_typing(headers) if sort_order is not None: - sort_order = ( - pymongo.ASCENDING if sort_order.lower() == "asc" else pymongo.DESCENDING - ) + sort_order = pymongo.ASCENDING if sort_order.lower() == "asc" else pymongo.DESCENDING else: sort_order = pymongo.DESCENDING diff --git a/fedn/network/api/v1/status_routes.py b/fedn/network/api/v1/status_routes.py index e78c18533..b88772b01 100644 --- a/fedn/network/api/v1/status_routes.py +++ b/fedn/network/api/v1/status_routes.py @@ -1,9 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, get_use_typing, - mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb from fedn.network.storage.statestore.stores.shared import EntityNotFound from fedn.network.storage.statestore.stores.status_store import StatusStore @@ -123,20 +121,12 @@ def get_statuses(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - statuses = status_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [status.__dict__ for status in statuses["result"]] - if use_typing - else statuses["result"] - ) + result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"] response = {"count": statuses["count"], "result": result} @@ -226,20 +216,12 @@ def list_statuses(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - statuses = status_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [status.__dict__ for status in statuses["result"]] - if use_typing - else statuses["result"] - ) + result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"] response = {"count": statuses["count"], "result": result} diff --git 
a/fedn/network/api/v1/validation_routes.py b/fedn/network/api/v1/validation_routes.py index 96fbac55c..59767e3e8 100644 --- a/fedn/network/api/v1/validation_routes.py +++ b/fedn/network/api/v1/validation_routes.py @@ -1,12 +1,9 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, get_use_typing, - mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb from fedn.network.storage.statestore.stores.shared import EntityNotFound -from fedn.network.storage.statestore.stores.validation_store import \ - ValidationStore +from fedn.network.storage.statestore.stores.validation_store import ValidationStore bp = Blueprint("validation", __name__, url_prefix=f"/api/{api_version}/validations") @@ -131,20 +128,12 @@ def get_validations(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - validations = validation_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + validations = validation_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [validation.__dict__ for validation in validations["result"]] - if use_typing - else validations["result"] - ) + result = [validation.__dict__ for validation in validations["result"]] if use_typing else validations["result"] response = {"count": validations["count"], "result": result} @@ -237,20 +226,12 @@ def list_validations(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - validations = validation_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + validations = validation_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [validation.__dict__ for validation in validations["result"]] - if use_typing - else validations["result"] - ) + result = [validation.__dict__ for validation in validations["result"]] if use_typing else validations["result"] response = {"count": validations["count"], "result": result} diff --git a/fedn/network/clients/client.py b/fedn/network/clients/client.py index e830006e4..c8a5afc4f 100644 --- a/fedn/network/clients/client.py +++ b/fedn/network/clients/client.py @@ -22,18 +22,16 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_PACKAGE_EXTRACT_DIR -from fedn.common.log_config import (logger, set_log_level_from_string, - set_log_stream) +from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream from fedn.network.clients.connect import ConnectorClient, Status from fedn.network.clients.package import PackageRuntime from fedn.network.clients.state import ClientState, ClientStateToString -from fedn.network.combiner.modelservice import (get_tmp_path, - upload_request_generator) +from fedn.network.combiner.modelservice import get_tmp_path, upload_request_generator from fedn.utils.dispatcher import Dispatcher from fedn.utils.helpers.helpers import get_helper CHUNK_SIZE = 
1024 * 1024 -VALID_NAME_REGEX = '^[a-zA-Z0-9_-]*$' +VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$" class GrpcAuth(grpc.AuthMetadataPlugin): @@ -41,7 +39,7 @@ def __init__(self, key): self._key = key def __call__(self, context, callback): - callback((('authorization', f'{FEDN_AUTH_SCHEME} {self._key}'),), None) + callback((("authorization", f"{FEDN_AUTH_SCHEME} {self._key}"),), None) class Client: @@ -61,30 +59,32 @@ def __init__(self, config): self._missed_heartbeat = 0 self.config = config self.trace_attribs = False - set_log_level_from_string(config.get('verbosity', "INFO")) - set_log_stream(config.get('logfile', None)) - - self.connector = ConnectorClient(host=config['discover_host'], - port=config['discover_port'], - token=config['token'], - name=config['name'], - remote_package=config['remote_compute_context'], - force_ssl=config['force_ssl'], - verify=config['verify'], - combiner=config['preferred_combiner'], - id=config['client_id']) + set_log_level_from_string(config.get("verbosity", "INFO")) + set_log_stream(config.get("logfile", None)) + + self.connector = ConnectorClient( + host=config["discover_host"], + port=config["discover_port"], + token=config["token"], + name=config["name"], + remote_package=config["remote_compute_context"], + force_ssl=config["force_ssl"], + verify=config["verify"], + combiner=config["preferred_combiner"], + id=config["client_id"], + ) # Validate client name - match = re.search(VALID_NAME_REGEX, config['name']) + match = re.search(VALID_NAME_REGEX, config["name"]) if not match: - raise ValueError('Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.') + raise ValueError("Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.") # Folder where the client will store downloaded compute package and logs - self.name = config['name'] + self.name = config["name"] if FEDN_PACKAGE_EXTRACT_DIR: self.run_path = os.path.join(os.getcwd(), FEDN_PACKAGE_EXTRACT_DIR) else: - dirname = self.name+"-"+time.strftime("%Y%m%d-%H%M%S") + dirname = self.name + "-" + time.strftime("%Y%m%d-%H%M%S") self.run_path = os.path.join(os.getcwd(), dirname) if not os.path.exists(self.run_path): os.mkdir(self.run_path) @@ -102,8 +102,7 @@ def __init__(self, config): self._initialize_helper(combiner_config) if not self.helper: - logger.warning("Failed to retrieve helper class settings: {}".format( - combiner_config)) + logger.warning("Failed to retrieve helper class settings: {}".format(combiner_config)) self._subscribe_to_combiner(self.config) @@ -147,14 +146,14 @@ def _add_grpc_metadata(self, key, value): :type value: str """ # Check if metadata exists and add if not - if not hasattr(self, 'metadata'): + if not hasattr(self, "metadata"): self.metadata = () # Check if metadata key already exists and replace value if so for i, (k, v) in enumerate(self.metadata): if k == key: # Replace value - self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1:] + self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1 :] return # Set metadata using tuple concatenation @@ -186,39 +185,38 @@ def connect(self, combiner_config): return None # TODO use the combiner_config['certificate'] for setting up secure comms' - host = combiner_config['host'] + host = combiner_config["host"] # Add host to gRPC metadata - self._add_grpc_metadata('grpc-server', host) + self._add_grpc_metadata("grpc-server", host) logger.debug("Client using metadata: {}.".format(self.metadata)) - port = combiner_config['port'] + port = combiner_config["port"] secure = 
False - if combiner_config['fqdn'] is not None: - host = combiner_config['fqdn'] + if combiner_config["fqdn"] is not None: + host = combiner_config["fqdn"] # assuming https if fqdn is used port = 443 logger.info(f"Initiating connection to combiner host at: {host}:{port}") - if combiner_config['certificate']: + if combiner_config["certificate"]: logger.info("Utilizing CA certificate for GRPC channel authentication.") secure = True - cert = base64.b64decode( - combiner_config['certificate']) # .decode('utf-8') + cert = base64.b64decode(combiner_config["certificate"]) # .decode('utf-8') credentials = grpc.ssl_channel_credentials(root_certificates=cert) channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) elif os.getenv("FEDN_GRPC_ROOT_CERT_PATH"): secure = True logger.info("Using root certificate from environment variable for GRPC channel.") - with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], 'rb') as f: + with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], "rb") as f: credentials = grpc.ssl_channel_credentials(f.read()) channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) - elif self.config['secure']: + elif self.config["secure"]: secure = True logger.info("Using CA certificate for GRPC channel.") cert = self._get_ssl_certificate(host, port=port) - credentials = grpc.ssl_channel_credentials(cert.encode('utf-8')) - if self.config['token']: - token = self.config['token'] + credentials = grpc.ssl_channel_credentials(cert.encode("utf-8")) + if self.config["token"]: + token = self.config["token"] auth_creds = grpc.metadata_call_credentials(GrpcAuth(token)) channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) else: @@ -227,9 +225,7 @@ def connect(self, combiner_config): logger.info("Using insecure GRPC channel.") if port == 443: port = 80 - channel = grpc.insecure_channel("{}:{}".format( - host, - str(port))) + channel = grpc.insecure_channel("{}:{}".format(host, str(port))) self.channel = channel @@ -237,12 +233,9 @@ def connect(self, combiner_config): self.combinerStub = rpc.CombinerStub(channel) self.modelStub = rpc.ModelServiceStub(channel) - logger.info("Successfully established {} connection to {}:{}".format("secure" if secure else "insecure", - host, - port)) + logger.info("Successfully established {} connection to {}:{}".format("secure" if secure else "insecure", host, port)) - logger.info("Using {} compute package.".format( - combiner_config["package"])) + logger.info("Using {} compute package.".format(combiner_config["package"])) self._connected = True @@ -265,8 +258,8 @@ def _initialize_helper(self, combiner_config): :return: """ - if 'helper_type' in combiner_config.keys(): - self.helper = get_helper(combiner_config['helper_type']) + if "helper_type" in combiner_config.keys(): + self.helper = get_helper(combiner_config["helper_type"]) def _subscribe_to_combiner(self, config): """Listen to combiner message stream and start all processing threads. @@ -277,12 +270,10 @@ def _subscribe_to_combiner(self, config): """ # Start sending heartbeats to the combiner. 
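
The `connect` branches above all end in the same composition: TLS channel credentials, optionally wrapped together with per-call token metadata. A condensed sketch of that pattern, assuming grpcio is installed and reusing the `GrpcAuth` plugin defined earlier in this file (the host, port, certificate, and token values are hypothetical):

import grpc

from fedn.network.clients.client import GrpcAuth

def make_secure_channel(host: str, port: int, root_cert: bytes, token: str) -> grpc.Channel:
    # TLS layer: trust the provided root certificate.
    ssl_creds = grpc.ssl_channel_credentials(root_certificates=root_cert)
    # Auth layer: GrpcAuth attaches "authorization: <scheme> <token>" to every call.
    call_creds = grpc.metadata_call_credentials(GrpcAuth(token))
    # Composite credentials carry both on a single channel.
    creds = grpc.composite_channel_credentials(ssl_creds, call_creds)
    return grpc.secure_channel("{}:{}".format(host, port), creds)
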
- threading.Thread(target=self._send_heartbeat, kwargs={ - 'update_frequency': config['heartbeat_interval']}, daemon=True).start() + threading.Thread(target=self._send_heartbeat, kwargs={"update_frequency": config["heartbeat_interval"]}, daemon=True).start() # Start listening for combiner training and validation messages - threading.Thread( - target=self._listen_to_task_stream, daemon=True).start() + threading.Thread(target=self._listen_to_task_stream, daemon=True).start() self._connected = True # Start processing the client message inbox @@ -294,7 +285,7 @@ def untar_package(self, package_runtime): return package_runpath def _initialize_dispatcher(self, config): - """ Initialize the dispatcher for the client. + """Initialize the dispatcher for the client. :param config: A configuration dictionary containing connection information for | the discovery service (controller) and settings governing e.g. @@ -310,11 +301,7 @@ def _initialize_dispatcher(self, config): while tries > 0: retval = pr.download( - host=config['discover_host'], - port=config['discover_port'], - token=config['token'], - force_ssl=config['force_ssl'], - secure=config['secure'] + host=config["discover_host"], port=config["discover_port"], token=config["token"], force_ssl=config["force_ssl"], secure=config["secure"] ) if retval: break @@ -323,10 +310,10 @@ def _initialize_dispatcher(self, config): tries -= 1 if retval: - if 'checksum' not in config: + if "checksum" not in config: logger.warning("Bypassing validation of package checksum. Ensure the package source is trusted.") else: - checks_out = pr.validate(config['checksum']) + checks_out = pr.validate(config["checksum"]) if not checks_out: logger.critical("Validation of local package failed. Client terminating.") self.error_state = True @@ -348,11 +335,14 @@ def _initialize_dispatcher(self, config): else: # TODO: Deprecate - dispatch_config = {'entry_points': - {'predict': {'command': 'python3 predict.py'}, - 'train': {'command': 'python3 train.py'}, - 'validate': {'command': 'python3 validate.py'}}} - from_path = os.path.join(os.getcwd(), 'client') + dispatch_config = { + "entry_points": { + "predict": {"command": "python3 predict.py"}, + "train": {"command": "python3 train.py"}, + "validate": {"command": "python3 validate.py"}, + } + } + from_path = os.path.join(os.getcwd(), "client") copy_tree(from_path, self.run_path) self.dispatcher = Dispatcher(dispatch_config, self.run_path) @@ -378,7 +368,6 @@ def get_model_from_combiner(self, id, timeout=20): try: for part in self.modelStub.Download(request, metadata=self.metadata): - if part.status == fedn.ModelStatus.IN_PROGRESS: data.write(part.data) @@ -436,7 +425,7 @@ def _listen_to_task_stream(self): r.sender.name = self.name r.sender.role = fedn.WORKER # Add client to metadata - self._add_grpc_metadata('client', self.name) + self._add_grpc_metadata("client", self.name) while self._connected: try: @@ -445,14 +434,19 @@ def _listen_to_task_stream(self): logger.debug("Received model update request from combiner: {}.".format(request)) if request.sender.role == fedn.COMBINER: # Process training request - self.send_status("Received model update request.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request, sesssion_id=request.session_id) + self.send_status( + "Received model update request.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE_REQUEST, + request=request, + sesssion_id=request.session_id, + ) logger.info("Received model update request of type {} for model_id 
{}".format(request.type, request.model_id)) - if request.type == fedn.StatusType.MODEL_UPDATE and self.config['trainer']: - self.inbox.put(('train', request)) - elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config['validator']: - self.inbox.put(('validate', request)) + if request.type == fedn.StatusType.MODEL_UPDATE and self.config["trainer"]: + self.inbox.put(("train", request)) + elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]: + self.inbox.put(("validate", request)) else: logger.error("Unknown request type: {}".format(request.type)) @@ -465,7 +459,7 @@ def _listen_to_task_stream(self): time.sleep(5) if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() - if details == 'Token expired': + if details == "Token expired": logger.warning("GRPC TaskStream: Token expired. Reconnecting.") self.detach() @@ -496,8 +490,7 @@ def _process_training_request(self, model_id: str, session_id: str = None): :rtype: tuple """ - self.send_status( - "\t Starting processing of training request for model_id {}".format(model_id), sesssion_id=session_id) + self.send_status("\t Starting processing of training request for model_id {}".format(model_id), sesssion_id=session_id) self.state = ClientState.training try: @@ -507,10 +500,10 @@ def _process_training_request(self, model_id: str, session_id: str = None): if mdl is None: logger.error("Could not retrieve model from combiner. Aborting training request.") return None, None - meta['fetch_model'] = time.time() - tic + meta["fetch_model"] = time.time() - tic inpath = self.helper.get_tmp_path() - with open(inpath, 'wb') as fh: + with open(inpath, "wb") as fh: fh.write(mdl.getbuffer()) outpath = self.helper.get_tmp_path() @@ -519,7 +512,7 @@ def _process_training_request(self, model_id: str, session_id: str = None): self.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) - meta['exec_training'] = time.time() - tic + meta["exec_training"] = time.time() - tic tic = time.time() out_model = None @@ -530,21 +523,21 @@ def _process_training_request(self, model_id: str, session_id: str = None): # Stream model update to combiner server updated_model_id = uuid.uuid4() self.send_model_to_combiner(out_model, str(updated_model_id)) - meta['upload_model'] = time.time() - tic + meta["upload_model"] = time.time() - tic # Read the metadata file - with open(outpath+'-metadata', 'r') as fh: + with open(outpath + "-metadata", "r") as fh: training_metadata = json.loads(fh.read()) - meta['training_metadata'] = training_metadata + meta["training_metadata"] = training_metadata os.unlink(inpath) os.unlink(outpath) - os.unlink(outpath+'-metadata') + os.unlink(outpath + "-metadata") except Exception as e: logger.error("Could not process training request due to error: {}".format(e)) updated_model_id = None - meta = {'status': 'failed', 'error': str(e)} + meta = {"status": "failed", "error": str(e)} self.state = ClientState.idle @@ -564,12 +557,11 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session """ # Figure out cmd if is_inference: - cmd = 'infer' + cmd = "infer" else: - cmd = 'validate' + cmd = "validate" - self.send_status( - f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) + self.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) self.state = ClientState.validating try: model = self.get_model_from_combiner(str(model_id)) @@ -599,25 +591,22 @@ def _process_validation_request(self, model_id: str, is_inference: bool, 
session return validation def process_request(self): - """Process training and validation tasks. """ + """Process training and validation tasks.""" while True: - if not self._connected: return try: (task_type, request) = self.inbox.get(timeout=1.0) - if task_type == 'train': - + if task_type == "train": tic = time.time() self.state = ClientState.training - model_id, meta = self._process_training_request( - request.model_id, session_id=request.session_id) + model_id, meta = self._process_training_request(request.model_id, session_id=request.session_id) if meta is not None: - processing_time = time.time()-tic - meta['processing_time'] = processing_time - meta['config'] = request.data + processing_time = time.time() - tic + meta["processing_time"] = processing_time + meta["config"] = request.data if model_id is not None: # Send model update to combiner @@ -634,28 +623,31 @@ def process_request(self): try: _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) - self.send_status("Model update completed.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE, request=update, sesssion_id=request.session_id) + self.send_status( + "Model update completed.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE, + request=update, + sesssion_id=request.session_id, + ) except grpc.RpcError as e: status_code = e.code() - logger.error("GRPC error, {}.".format( - status_code.name)) + logger.error("GRPC error, {}.".format(status_code.name)) logger.debug(e) except ValueError as e: logger.error("GRPC error, RPC channel closed. {}".format(e)) logger.debug(e) else: - self.send_status("Client {} failed to complete model update.", - log_level=fedn.Status.WARNING, - request=request, sesssion_id=request.session_id) + self.send_status( + "Client {} failed to complete model update.", log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id + ) self.state = ClientState.idle self.inbox.task_done() - elif task_type == 'validate': + elif task_type == "validate": self.state = ClientState.validating - metrics = self._process_validation_request( - request.model_id, False, request.session_id) + metrics = self._process_validation_request(request.model_id, False, request.session_id) if metrics is not None: # Send validation @@ -671,23 +663,26 @@ def process_request(self): validation.session_id = request.session_id try: - _ = self.combinerStub.SendModelValidation( - validation, metadata=self.metadata) + _ = self.combinerStub.SendModelValidation(validation, metadata=self.metadata) status_type = fedn.StatusType.MODEL_VALIDATION - self.send_status("Model validation completed.", log_level=fedn.Status.AUDIT, - type=status_type, request=validation, sesssion_id=request.session_id) + self.send_status( + "Model validation completed.", log_level=fedn.Status.AUDIT, type=status_type, request=validation, sesssion_id=request.session_id + ) except grpc.RpcError as e: status_code = e.code() - logger.error("GRPC error, {}.".format( - status_code.name)) + logger.error("GRPC error, {}.".format(status_code.name)) logger.debug(e) except ValueError as e: logger.error("GRPC error, RPC channel closed. 
{}".format(e)) logger.debug(e) else: - self.send_status("Client {} failed to complete model validation.".format(self.name), - log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id) + self.send_status( + "Client {} failed to complete model validation.".format(self.name), + log_level=fedn.Status.WARNING, + request=request, + sesssion_id=request.session_id, + ) self.state = ClientState.idle self.inbox.task_done() @@ -705,8 +700,7 @@ def _send_heartbeat(self, update_frequency=2.0): :rtype: None """ while True: - heartbeat = fedn.Heartbeat(sender=fedn.Client( - name=self.name, role=fedn.WORKER)) + heartbeat = fedn.Heartbeat(sender=fedn.Client(name=self.name, role=fedn.WORKER)) try: self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata) self._missed_heartbeat = 0 @@ -714,13 +708,16 @@ def _send_heartbeat(self, update_frequency=2.0): status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: self._missed_heartbeat += 1 - logger.error("GRPC hearbeat: combiner unavailable, retrying (attempt {}/{}).".format(self._missed_heartbeat, - self.config['reconnect_after_missed_heartbeat'])) - if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']: + logger.error( + "GRPC hearbeat: combiner unavailable, retrying (attempt {}/{}).".format( + self._missed_heartbeat, self.config["reconnect_after_missed_heartbeat"] + ) + ) + if self._missed_heartbeat > self.config["reconnect_after_missed_heartbeat"]: self.disconnect() if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() - if details == 'Token expired': + if details == "Token expired": logger.error("GRPC hearbeat: Token expired. Disconnecting.") self.disconnect() sys.exit("Unauthorized. Token expired. Please obtain a new token.") @@ -761,9 +758,7 @@ def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, if request is not None: status.data = MessageToJson(request) - self.logs.append( - "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, - status.status)) + self.logs.append("{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) try: _ = self.connectorStub.SendStatus(status, metadata=self.metadata) except grpc.RpcError as e: @@ -772,11 +767,11 @@ def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, logger.warning("GRPC SendStatus: server unavailable during send status.") if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() - if details == 'Token expired': + if details == "Token expired": logger.warning("GRPC SendStatus: Token expired.") def run(self): - """ Run the client. 
""" + """Run the client.""" try: cnt = 0 old_state = self.state diff --git a/fedn/network/clients/connect.py b/fedn/network/clients/connect.py index bbe2599b8..efeb3d1e9 100644 --- a/fedn/network/clients/connect.py +++ b/fedn/network/clients/connect.py @@ -8,14 +8,13 @@ import requests -from fedn.common.config import (FEDN_AUTH_REFRESH_TOKEN, - FEDN_AUTH_REFRESH_TOKEN_URI, FEDN_AUTH_SCHEME, - FEDN_CUSTOM_URL_PREFIX) +from fedn.common.config import FEDN_AUTH_REFRESH_TOKEN, FEDN_AUTH_REFRESH_TOKEN_URI, FEDN_AUTH_SCHEME, FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger class Status(enum.Enum): - """ Enum for representing the status of a client assignment.""" + """Enum for representing the status of a client assignment.""" + Unassigned = 0 Assigned = 1 TryAgain = 2 @@ -24,7 +23,7 @@ class Status(enum.Enum): class ConnectorClient: - """ Connector for assigning client to a combiner in the FEDn network. + """Connector for assigning client to a combiner in the FEDn network. :param host: host of discovery service :type host: str @@ -46,7 +45,6 @@ class ConnectorClient: """ def __init__(self, host, port, token, name, remote_package, force_ssl=False, verify=False, combiner=None, id=None): - self.host = host self.port = port self.token = token @@ -54,7 +52,7 @@ def __init__(self, host, port, token, name, remote_package, force_ssl=False, ver self.verify = verify self.preferred_combiner = combiner self.id = id - self.package = 'remote' if remote_package else 'local' + self.package = "remote" if remote_package else "local" # for https we assume a an ingress handles permanent redirect (308) if force_ssl: @@ -62,11 +60,9 @@ def __init__(self, host, port, token, name, remote_package, force_ssl=False, ver else: self.prefix = "http://" if self.port: - self.connect_string = "{}{}:{}".format( - self.prefix, self.host, self.port) + self.connect_string = "{}{}:{}".format(self.prefix, self.host, self.port) else: - self.connect_string = "{}{}".format( - self.prefix, self.host) + self.connect_string = "{}{}".format(self.prefix, self.host) logger.info("Setting connection string to {}.".format(self.connect_string)) @@ -79,26 +75,28 @@ def assign(self): """ try: retval = None - payload = {'client_id': self.name, 'preferred_combiner': self.preferred_combiner} - retval = requests.post(self.connect_string + FEDN_CUSTOM_URL_PREFIX + '/add_client', - json=payload, - verify=self.verify, - allow_redirects=True, - headers={'Authorization': f"{FEDN_AUTH_SCHEME} {self.token}"}) + payload = {"client_id": self.name, "preferred_combiner": self.preferred_combiner} + retval = requests.post( + self.connect_string + FEDN_CUSTOM_URL_PREFIX + "/add_client", + json=payload, + verify=self.verify, + allow_redirects=True, + headers={"Authorization": f"{FEDN_AUTH_SCHEME} {self.token}"}, + ) except Exception as e: - logger.debug('***** {}'.format(e)) + logger.debug("***** {}".format(e)) return Status.Unassigned, {} if retval.status_code == 400: # Get error messange from response - reason = retval.json()['message'] + reason = retval.json()["message"] return Status.UnMatchedConfig, reason if retval.status_code == 401: - if 'message' in retval.json(): - reason = retval.json()['message'] + if "message" in retval.json(): + reason = retval.json()["message"] logger.warning(reason) - if reason == 'Token expired': + if reason == "Token expired": status_code = self.refresh_token() if status_code >= 200 and status_code < 204: logger.info("Token refreshed.") @@ -109,19 +107,19 @@ def assign(self): return Status.UnAuthorized, reason if 
retval.status_code >= 200 and retval.status_code < 204: - if retval.json()['status'] == 'retry': - if 'message' in retval.json(): - reason = retval.json()['message'] + if retval.json()["status"] == "retry": + if "message" in retval.json(): + reason = retval.json()["message"] else: reason = "Reducer was not ready. Try again later." return Status.TryAgain, reason - reducer_package = retval.json()['package'] + reducer_package = retval.json()["package"] if reducer_package != self.package: - reason = "Unmatched config of compute package between client and reducer.\n" +\ - "Reducer uses {} package and client uses {}.".format( - reducer_package, self.package) + reason = "Unmatched config of compute package between client and reducer.\n" + "Reducer uses {} package and client uses {}.".format( + reducer_package, self.package + ) return Status.UnMatchedConfig, reason return Status.Assigned, retval.json() @@ -139,9 +137,6 @@ def refresh_token(self): logger.error("No refresh token URI/Token set, cannot refresh token.") return 401 - payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, - verify=self.verify, - allow_redirects=True, - json={'refresh': FEDN_AUTH_REFRESH_TOKEN}) - self.token = payload.json()['access'] + payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, verify=self.verify, allow_redirects=True, json={"refresh": FEDN_AUTH_REFRESH_TOKEN}) + self.token = payload.json()["access"] return payload.status_code diff --git a/fedn/network/clients/package.py b/fedn/network/clients/package.py index c06d79fc5..54f45b883 100644 --- a/fedn/network/clients/package.py +++ b/fedn/network/clients/package.py @@ -14,7 +14,7 @@ class PackageRuntime: - """ PackageRuntime is used to download, validate and unpack compute packages. + """PackageRuntime is used to download, validate and unpack compute packages. 
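
Taken together, the `PackageRuntime` methods documented below form a download, validate, unpack, dispatch pipeline. A usage sketch under the assumption of a local, insecure controller (the host, port, token, and checksum values are hypothetical; the method names and signatures are the ones defined in this class):

import tempfile

pr = PackageRuntime(tempfile.mkdtemp())

# Fetch the active package and its server-side checksum from the controller.
ok = pr.download(host="localhost", port=8092, token="<api-token>", secure=False)

# Compare the local file hash against the checksum announced by the controller.
if ok and pr.validate(expected_checksum="<checksum-from-controller>"):
    pr.unpack()
    dispatcher = pr.dispatcher(run_path=pr.pkg_path)
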
:param package_path: path to compute package :type package_path: str @@ -23,11 +23,13 @@ class PackageRuntime: """ def __init__(self, package_path): - - self.dispatch_config = {'entry_points': - {'predict': {'command': 'python3 predict.py'}, - 'train': {'command': 'python3 train.py'}, - 'validate': {'command': 'python3 validate.py'}}} + self.dispatch_config = { + "entry_points": { + "predict": {"command": "python3 predict.py"}, + "train": {"command": "python3 train.py"}, + "validate": {"command": "python3 validate.py"}, + } + } self.pkg_path = package_path self.pkg_name = None @@ -35,7 +37,7 @@ def __init__(self, package_path): self.expected_checksum = None def download(self, host, port, token, force_ssl=False, secure=False, name=None): - """ Download compute package from controller + """Download compute package from controller :param host: host of controller :param port: port of controller @@ -56,18 +58,16 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): if name: path = path + "?name={}".format(name) - with requests.get(path, stream=True, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: + with requests.get(path, stream=True, verify=False, headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}) as r: if 200 <= r.status_code < 204: - - params = cgi.parse_header( - r.headers.get('Content-Disposition', ''))[-1] + params = cgi.parse_header(r.headers.get("Content-Disposition", ""))[-1] try: - self.pkg_name = params['filename'] + self.pkg_name = params["filename"] except KeyError: logger.error("No package returned.") return None r.raise_for_status() - with open(os.path.join(self.pkg_path, self.pkg_name), 'wb') as f: + with open(os.path.join(self.pkg_path, self.pkg_name), "wb") as f: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) if port: @@ -77,19 +77,18 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): if name: path = path + "?name={}".format(name) - with requests.get(path, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: + with requests.get(path, verify=False, headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}) as r: if 200 <= r.status_code < 204: - data = r.json() try: - self.checksum = data['checksum'] + self.checksum = data["checksum"] except Exception: logger.error("Could not extract checksum.") return True def validate(self, expected_checksum): - """ Validate the package against the checksum provided by the controller + """Validate the package against the checksum provided by the controller :param expected_checksum: checksum provided by the controller :return: True if checksums match, False otherwise @@ -107,37 +106,27 @@ def validate(self, expected_checksum): return False def unpack(self): - """ Unpack the compute package + """Unpack the compute package :return: True if unpacking was successful, False otherwise :rtype: bool """ if self.pkg_name: f = None - if self.pkg_name.endswith('tar.gz'): - f = tarfile.open(os.path.join( - self.pkg_path, self.pkg_name), 'r:gz') - if self.pkg_name.endswith('.tgz'): - f = tarfile.open(os.path.join( - self.pkg_path, self.pkg_name), 'r:gz') - if self.pkg_name.endswith('tar.bz2'): - f = tarfile.open(os.path.join( - self.pkg_path, self.pkg_name), 'r:bz2') + if self.pkg_name.endswith("tar.gz"): + f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:gz") + if self.pkg_name.endswith(".tgz"): + f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:gz") + if self.pkg_name.endswith("tar.bz2"): + 
f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:bz2") else: -            logger.error( -                "Failed to unpack compute package, no pkg_name set." -                "Has the reducer been configured with a compute package?" -            ) +            logger.error("Failed to unpack compute package, no pkg_name set. Has the reducer been configured with a compute package?") return False try: if f: f.extractall(self.pkg_path) -                logger.info( -                    "Successfully extracted compute package content in {}".format( -                        self.pkg_path -                    ) -                ) +                logger.info("Successfully extracted compute package content in {}".format(self.pkg_path)) # delete the tarball logger.info("Deleting temporary package tarball file.") os.remove(os.path.join(self.pkg_path, self.pkg_name)) @@ -157,7 +146,7 @@ def unpack(self): return False, "" def dispatcher(self, run_path): -        """ Dispatch the compute package +        """Dispatch the compute package :param run_path: path to dispatch the compute package :type run_path: str diff --git a/fedn/network/clients/state.py b/fedn/network/clients/state.py index 262f5862e..a349f846e 100644 --- a/fedn/network/clients/state.py +++ b/fedn/network/clients/state.py @@ -2,14 +2,15 @@ class ClientState(Enum): -    """ Enum for representing the state of a client.""" +    """Enum for representing the state of a client.""" + idle = 1 training = 2 validating = 3 def ClientStateToString(state): -    """ Convert a ClientState to a string representation. +    """Convert a ClientState to a string representation. :param state: the state to convert :type state: :class:`fedn.network.clients.state.ClientState` diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index b3b97fd6a..f19674b73 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -12,8 +12,7 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc -from fedn.common.log_config import (logger, set_log_level_from_string, -                                    set_log_stream) +from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream from fedn.network.combiner.connect import ConnectorCombiner, Status from fedn.network.combiner.modelservice import ModelService from fedn.network.combiner.roundhandler import RoundHandler @@ -21,11 +20,12 @@ from fedn.network.storage.s3.repository import Repository from fedn.network.storage.statestore.mongostatestore import MongoStateStore -VALID_NAME_REGEX = '^[a-zA-Z0-9_-]*$' +VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$" class Role(Enum): -    """ Enum for combiner roles. """ +    """Enum for combiner roles.""" + WORKER = 1 COMBINER = 2 REDUCER = 3 @@ -33,7 +33,7 @@ def role_to_proto_role(role): -    """ Convert a Role to a proto Role. +    """Convert a Role to a proto Role. :param role: the role to convert :type role: :class:`fedn.network.combiner.server.Role` @@ -51,17 +51,17 @@ class Combiner(rpc.CombinerServicer, rpc.ReducerServicer, rpc.ConnectorServicer, rpc.ControlServicer): -    """ Combiner gRPC server. +    """Combiner gRPC server.
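
Both the client and the combiner gate their names on `VALID_NAME_REGEX`, as shown above. A small sketch of that check; note that because the pattern uses `*`, an empty string also matches, so the explicit emptiness guard below is an addition, not something the source enforces:

import re

VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$"

def is_valid_name(name: str) -> bool:
    # search() with ^...$ anchors behaves like a full-string match here.
    return bool(name) and re.search(VALID_NAME_REGEX, name) is not None

assert is_valid_name("combiner-01")
assert not is_valid_name("bad name!")
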
:param config: configuration for the combiner :type config: dict """ def __init__(self, config): - """ Initialize Combiner server.""" + """Initialize Combiner server.""" - set_log_level_from_string(config.get('verbosity', "INFO")) - set_log_stream(config.get('logfile', None)) + set_log_level_from_string(config.get("verbosity", "INFO")) + set_log_stream(config.get("logfile", None)) # Client queues self.clients = {} @@ -69,24 +69,26 @@ def __init__(self, config): self.modelservice = ModelService() # Validate combiner name - match = re.search(VALID_NAME_REGEX, config['name']) + match = re.search(VALID_NAME_REGEX, config["name"]) if not match: - raise ValueError('Unallowed character in combiner name. Allowed characters: a-z, A-Z, 0-9, _, -.') + raise ValueError("Unallowed character in combiner name. Allowed characters: a-z, A-Z, 0-9, _, -.") - self.id = config['name'] + self.id = config["name"] self.role = Role.COMBINER - self.max_clients = config['max_clients'] + self.max_clients = config["max_clients"] # Connector to announce combiner to discover service (reducer) - announce_client = ConnectorCombiner(host=config['discover_host'], - port=config['discover_port'], - myhost=config['host'], - fqdn=config['fqdn'], - myport=config['port'], - token=config['token'], - name=config['name'], - secure=config['secure'], - verify=config['verify']) + announce_client = ConnectorCombiner( + host=config["discover_host"], + port=config["discover_port"], + myhost=config["host"], + fqdn=config["fqdn"], + myport=config["port"], + token=config["token"], + name=config["name"], + secure=config["secure"], + verify=config["verify"], + ) while True: # Announce combiner to discover service @@ -107,27 +109,20 @@ def __init__(self, config): logger.info("Status.UnMatchedConfig") sys.exit("Exiting: Missing config") - cert = announce_config['certificate'] - key = announce_config['key'] + cert = announce_config["certificate"] + key = announce_config["key"] - if announce_config['certificate']: - cert = base64.b64decode(announce_config['certificate']) # .decode('utf-8') - key = base64.b64decode(announce_config['key']) # .decode('utf-8') + if announce_config["certificate"]: + cert = base64.b64decode(announce_config["certificate"]) # .decode('utf-8') + key = base64.b64decode(announce_config["key"]) # .decode('utf-8') # Set up gRPC server configuration - grpc_config = {'port': config['port'], - 'secure': config['secure'], - 'certificate': cert, - 'key': key} + grpc_config = {"port": config["port"], "secure": config["secure"], "certificate": cert, "key": key} # Set up model repository - self.repository = Repository( - announce_config['storage']['storage_config']) + self.repository = Repository(announce_config["storage"]["storage_config"]) - self.statestore = MongoStateStore( - announce_config['statestore']['network_id'], - announce_config['statestore']['mongo_config'] - ) + self.statestore = MongoStateStore(announce_config["statestore"]["network_id"], announce_config["statestore"]["mongo_config"]) # Create gRPC server self.server = Server(self, self.modelservice, grpc_config) @@ -144,7 +139,7 @@ def __init__(self, config): self.server.start() def __whoami(self, client, instance): - """ Set the client id and role in a proto message. + """Set the client id and role in a proto message. 
:param client: the client to set the id and role for :type client: :class:`fedn.network.grpc.fedn_pb2.Client` @@ -158,7 +153,7 @@ def __whoami(self, client, instance): return client def request_model_update(self, config, clients=[]): - """ Ask clients to update the current global model. + """Ask clients to update the current global model. :param config: the model configuration to send to clients :type config: dict @@ -168,12 +163,12 @@ def request_model_update(self, config, clients=[]): """ # The request to be added to the client queue request = fedn.TaskRequest() - request.model_id = config['model_id'] + request.model_id = config["model_id"] request.correlation_id = str(uuid.uuid4()) request.timestamp = str(datetime.now()) request.data = json.dumps(config) request.type = fedn.StatusType.MODEL_UPDATE - request.session_id = config['session_id'] + request.session_id = config["session_id"] request.sender.name = self.id request.sender.role = fedn.COMBINER @@ -187,14 +182,12 @@ def request_model_update(self, config, clients=[]): self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE) if len(clients) < 20: - logger.info("Sent model update request for model {} to clients {}".format( - request.model_id, clients)) + logger.info("Sent model update request for model {} to clients {}".format(request.model_id, clients)) else: - logger.info("Sent model update request for model {} to {} clients".format( - request.model_id, len(clients))) + logger.info("Sent model update request for model {} to {} clients".format(request.model_id, len(clients))) def request_model_validation(self, model_id, config, clients=[]): - """ Ask clients to validate the current global model. + """Ask clients to validate the current global model. :param model_id: the model id to validate :type model_id: str @@ -225,14 +218,12 @@ def request_model_validation(self, model_id, config, clients=[]): self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE) if len(clients) < 20: - logger.info("Sent model validation request for model {} to clients {}".format( - request.model_id, clients)) + logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients)) else: - logger.info("Sent model validation request for model {} to {} clients".format( - request.model_id, len(clients))) + logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients))) def get_active_trainers(self): - """ Get a list of active trainers. + """Get a list of active trainers. :return: the list of active trainers :rtype: list @@ -241,7 +232,7 @@ def get_active_trainers(self): return trainers def get_active_validators(self): - """ Get a list of active validators. + """Get a list of active validators. :return: the list of active validators :rtype: list @@ -250,7 +241,7 @@ def get_active_validators(self): return validators def nr_active_trainers(self): - """ Get the number of active trainers. + """Get the number of active trainers. :return: the number of active trainers :rtype: int @@ -260,7 +251,7 @@ def nr_active_trainers(self): #################################################################################################################### def __join_client(self, client): - """ Add a client to the list of active clients. + """Add a client to the list of active clients. 
:param client: the client to add :type client: :class:`fedn.network.grpc.fedn_pb2.Client` @@ -270,7 +261,7 @@ self.clients[client.name] = {"lastseen": datetime.now(), "status": "offline"} def _subscribe_client_to_queue(self, client, queue_name): - """ Subscribe a client to the queue. + """Subscribe a client to the queue. :param client: the client to subscribe :type client: :class:`fedn.network.grpc.fedn_pb2.Client` @@ -282,7 +273,7 @@ self.clients[client.name][queue_name] = queue.Queue() def __get_queue(self, client, queue_name): - """ Get the queue for a client. + """Get the queue for a client. :param client: the client to get the queue for :type client: :class:`fedn.network.grpc.fedn_pb2.Client` @@ -299,7 +290,7 @@ raise def _list_subscribed_clients(self, queue_name): - """ List all clients subscribed to a queue. + """List all clients subscribed to a queue. :param queue_name: the name of the queue :type queue_name: str @@ -313,7 +304,7 @@ return subscribed_clients def _list_active_clients(self, channel): - """ List all clients that have sent a status message in the last 10 seconds. + """List all clients that have sent a status message in the last 10 seconds. :param channel: the name of the channel :type channel: str @@ -350,14 +341,14 @@ return clients["active_clients"] def _deamon_thread_client_status(self, timeout=10): - """ Deamon thread that checks for inactive clients and updates statestore. """ + """Daemon thread that checks for inactive clients and updates statestore.""" while True: time.sleep(timeout) # TODO: Also update validation clients self._list_active_clients(fedn.Queue.TASK_QUEUE) def _put_request_to_client_queue(self, request, queue_name): - """ Get a client specific queue and add a request to it. + """Get a client specific queue and add a request to it. The client is identified by the request.receiver. :param request: the request to send @@ -369,14 +360,11 @@ q = self.__get_queue(request.receiver, queue_name) q.put(request) except Exception as e: - logger.error("Failed to put request to client queue {} for client {}: {}".format( - queue_name, - request.receiver.name, - str(e))) + logger.error("Failed to put request to client queue {} for client {}: {}".format(queue_name, request.receiver.name, str(e))) raise def _send_status(self, status): - """ Report a status to backend db. + """Report a status to backend db. :param status: the status to report :type status: :class:`fedn.network.grpc.fedn_pb2.Status` @@ -406,7 +394,7 @@ def _flush_model_update_queue(self): # Controller Service def Start(self, control: fedn.ControlRequest, context): - """ Start a round of federated learning" + """Start a round of federated learning. :param control: the control request :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest` @@ -434,7 +422,7 @@ return response def SetAggregator(self, control: fedn.ControlRequest, context): - """ Set the active aggregator. + """Set the active aggregator.
:param control: the control request :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest` @@ -451,14 +439,14 @@ response = fedn.ControlResponse() if status: - response.message = 'Success' + response.message = "Success" else: - response.message = 'Failed' + response.message = "Failed" return response def FlushAggregationQueue(self, control: fedn.ControlRequest, context): - """ Flush the queue. + """Flush the queue. :param control: the control request :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest` @@ -472,16 +460,16 @@ response = fedn.ControlResponse() if status: - response.message = 'Success' + response.message = "Success" else: - response.message = 'Failed' + response.message = "Failed" return response ############################################################################## def Stop(self, control: fedn.ControlRequest, context): - """ TODO: Not yet implemented. + """TODO: Not yet implemented. :param control: the control request :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest` @@ -497,7 +485,7 @@ ##################################################################################################################### def SendStatus(self, status: fedn.Status, context): - """ A client RPC endpoint that accepts status messages. + """A client RPC endpoint that accepts status messages. :param status: the status message :type status: :class:`fedn.network.grpc.fedn_pb2.Status` @@ -514,7 +502,7 @@ return response def ListActiveClients(self, request: fedn.ListClientsRequest, context): - """ RPC endpoint that returns a ClientList containing the names of all active clients. + """RPC endpoint that returns a ClientList containing the names of all active clients. An active client has sent a status message / responded to a heartbeat request in the last 10 seconds. @@ -538,7 +526,7 @@ return clients def AcceptingClients(self, request: fedn.ConnectionRequest, context): - """ RPC endpoint that returns a ConnectionResponse indicating whether the server + """RPC endpoint that returns a ConnectionResponse indicating whether the server is accepting clients or not. :param request: the request (unused) @@ -549,8 +537,7 @@ :rtype: :class:`fedn.network.grpc.fedn_pb2.ConnectionResponse` """ response = fedn.ConnectionResponse() - active_clients = self._list_active_clients( - fedn.Queue.TASK_QUEUE) + active_clients = self._list_active_clients(fedn.Queue.TASK_QUEUE) try: requested = int(self.max_clients) @@ -569,7 +556,7 @@ return response def SendHeartbeat(self, heartbeat: fedn.Heartbeat, context): - """ RPC that lets clients send a hearbeat, notifying the server that + """RPC that lets clients send a heartbeat, notifying the server that the client is available. :param heartbeat: the heartbeat @@ -594,7 +581,7 @@ # Combiner Service def TaskStream(self, response, context): - """ A server stream RPC endpoint (Update model). Messages from client stream. + """A server stream RPC endpoint (Update model). Messages from client stream.
:param response: the response :type response: :class:`fedn.network.grpc.fedn_pb2.ModelUpdateRequest` @@ -606,17 +593,15 @@ metadata = context.invocation_metadata() if metadata: metadata = dict(metadata) - logger.info("grpc.Combiner.TaskStream: Client connected: {}\n".format(metadata['client'])) + logger.info("grpc.Combiner.TaskStream: Client connected: {}\n".format(metadata["client"])) - status = fedn.Status( - status="Client {} connecting to TaskStream.".format(client.name)) + status = fedn.Status(status="Client {} connecting to TaskStream.".format(client.name)) status.log_level = fedn.Status.INFO status.timestamp.GetCurrentTime() self.__whoami(status.sender, self) - self._subscribe_client_to_queue( - client, fedn.Queue.TASK_QUEUE) + self._subscribe_client_to_queue(client, fedn.Queue.TASK_QUEUE) q = self.__get_queue(client, fedn.Queue.TASK_QUEUE) self._send_status(status) @@ -637,7 +622,7 @@ logger.error("Error in ModelUpdateRequestStream: {}".format(e)) def SendModelUpdate(self, request, context): - """ Send a model update response. + """Send a model update response. :param request: the request :type request: :class:`fedn.network.grpc.fedn_pb2.ModelUpdate` @@ -649,8 +634,7 @@ self.round_handler.aggregator.on_model_update(request) response = fedn.Response() - response.response = "RECEIVED ModelUpdate {} from client {}".format( - response, response.sender.name) + response.response = "RECEIVED ModelUpdate {} from client {}".format(request, request.sender.name) return response # TODO Fill later def register_model_validation(self, validation): @@ -663,7 +647,7 @@ self.statestore.report_validation(validation) def SendModelValidation(self, request, context): - """ Send a model validation response. + """Send a model validation response. :param request: the request :type request: :class:`fedn.network.grpc.fedn_pb2.ModelValidation` @@ -677,17 +661,15 @@ self.register_model_validation(request) response = fedn.Response() - response.response = "RECEIVED ModelValidation {} from client {}".format( - response, response.sender.name) + response.response = "RECEIVED ModelValidation {} from client {}".format(request, request.sender.name) return response #################################################################################################################### def run(self): - """ Start the server.""" + """Start the server.""" - logger.info("COMBINER: {} started, ready for gRPC requests.".format( - self.id)) + logger.info("COMBINER: {} started, ready for gRPC requests.".format(self.id)) try: while True: signal.pause() diff --git a/fedn/network/combiner/connect.py b/fedn/network/combiner/connect.py index 7dc388261..e144baa94 100644 --- a/fedn/network/combiner/connect.py +++ b/fedn/network/combiner/connect.py @@ -13,7 +13,8 @@ class Status(enum.Enum): - """ Enum for representing the status of a combiner announcement.""" + """Enum for representing the status of a combiner announcement.""" + Unassigned = 0 Assigned = 1 TryAgain = 2 @@ -22,7 +23,7 @@ class ConnectorCombiner: - """ Connector for annnouncing combiner to the FEDn network. + """Connector for announcing combiner to the FEDn network.
:param host: host of discovery service :type host: str @@ -45,7 +46,7 @@ """ def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, verify=False): - """ Initialize the ConnectorCombiner. + """Initialize the ConnectorCombiner. :param host: The host of the discovery service. :type host: str @@ -73,22 +74,20 @@ self.myhost = myhost self.myport = myport self.token = token - self.token_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') + self.token_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Bearer") self.name = name self.secure = secure self.verify = verify if not self.token: - self.token = os.environ.get('FEDN_AUTH_TOKEN', None) + self.token = os.environ.get("FEDN_AUTH_TOKEN", None) # for https we assume an ingress handles permanent redirect (308) self.prefix = "http://" if port: - self.connect_string = "{}{}:{}".format( - self.prefix, self.host, self.port) + self.connect_string = "{}{}:{}".format(self.prefix, self.host, self.port) else: - self.connect_string = "{}{}".format( - self.prefix, self.host) + self.connect_string = "{}{}".format(self.prefix, self.host) logger.info("Setting connection string to {}".format(self.connect_string)) @@ -99,24 +98,21 @@ def announce(self): :return: Tuple with announcement Status, FEDn network configuration if successful, else None. :rtype: :class:`fedn.network.combiner.connect.Status`, str """ - payload = { - "combiner_id": self.name, - "address": self.myhost, - "fqdn": self.fqdn, - "port": self.myport, - "secure_grpc": self.secure - } - url_prefix = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') + payload = {"combiner_id": self.name, "address": self.myhost, "fqdn": self.fqdn, "port": self.myport, "secure_grpc": self.secure} + url_prefix = os.environ.get("FEDN_CUSTOM_URL_PREFIX", "") try: - retval = requests.post(self.connect_string + url_prefix + '/add_combiner', json=payload, - verify=self.verify, - headers={'Authorization': f'{self.token_scheme} {self.token}'}) + retval = requests.post( + self.connect_string + url_prefix + "/add_combiner", + json=payload, + verify=self.verify, + headers={"Authorization": f"{self.token_scheme} {self.token}"}, + ) except Exception: return Status.Unassigned, {} if retval.status_code == 400: # Get error message from response - reason = retval.json()['message'] + reason = retval.json()["message"] return Status.UnMatchedConfig, reason if retval.status_code == 401: @@ -124,8 +120,8 @@ return Status.UnAuthorized, reason if retval.status_code >= 200 and retval.status_code < 204: - if retval.json()['status'] == 'retry': - reason = retval.json()['message'] + if retval.json()["status"] == "retry": + reason = retval.json()["message"] return Status.TryAgain, reason return Status.Assigned, retval.json() diff --git a/fedn/network/combiner/interfaces.py b/fedn/network/combiner/interfaces.py index 69b987be0..f247c2bc1 100644 --- a/fedn/network/combiner/interfaces.py +++ b/fedn/network/combiner/interfaces.py @@ -15,7 +15,7 @@ class CombinerUnavailableError(Exception): class Channel: - """ Wrapper for a gRPC channel. + """Wrapper for a gRPC channel. :param address: The address for the gRPC server. :type address: str @@ -26,7 +26,7 @@ class Channel: """ def __init__(self, address, port, certificate=None): - """ Create a channel. + """Create a channel. If a valid certificate is given, a secure channel is created, else insecure.
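A minimal usage sketch for the Channel wrapper above, assuming a hypothetical combiner host and port; `rpc` is fedn.network.grpc.fedn_pb2_grpc, as imported by this module:

    # Hypothetical address/port. With certificate=None this falls through
    # to the insecure branch described in the docstring above.
    channel = Channel("combiner.example.com", 12080, certificate=None).get_channel()
    control = rpc.ControlStub(channel)  # the same stub the methods below construct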
@@ -42,16 +42,13 @@ def __init__(self, address, port, certificate=None): self.certificate = certificate if self.certificate: - credentials = grpc.ssl_channel_credentials( - root_certificates=copy.deepcopy(certificate)) - self.channel = grpc.secure_channel('{}:{}'.format( - self.address, str(self.port)), credentials) + credentials = grpc.ssl_channel_credentials(root_certificates=copy.deepcopy(certificate)) + self.channel = grpc.secure_channel("{}:{}".format(self.address, str(self.port)), credentials) else: - self.channel = grpc.insecure_channel( - '{}:{}'.format(self.address, str(self.port))) + self.channel = grpc.insecure_channel("{}:{}".format(self.address, str(self.port))) def get_channel(self): - """ Get a channel. + """Get a channel. :return: An instance of a gRPC channel :rtype: :class:`grpc.Channel` @@ -60,7 +57,7 @@ def get_channel(self): class CombinerInterface: - """ Interface for the Combiner (aggregation server). + """Interface for the Combiner (aggregation server). Abstraction on top of the gRPC server servicer. :param parent: The parent combiner (controller) @@ -84,7 +81,7 @@ class CombinerInterface: """ def __init__(self, parent, name, address, fqdn, port, certificate=None, key=None, ip=None, config=None): - """ Initialize the combiner interface.""" + """Initialize the combiner interface.""" self.parent = parent self.name = name self.address = address @@ -95,15 +92,13 @@ def __init__(self, parent, name, address, fqdn, port, certificate=None, key=None self.ip = ip if not config: - self.config = { - 'max_clients': 8 - } + self.config = {"max_clients": 8} else: self.config = config @classmethod def from_json(combiner_config): - """ Initialize the combiner config from a json document. + """Initialize the combiner config from a json document. :parameter combiner_config: The combiner configuration. :type combiner_config: dict @@ -113,34 +108,34 @@ def from_json(combiner_config): return CombinerInterface(**combiner_config) def to_dict(self): - """ Export combiner configuration to a dictionary. + """Export combiner configuration to a dictionary. :return: A dictionary with the combiner configuration. :rtype: dict """ data = { - 'parent': self.parent, - 'name': self.name, - 'address': self.address, - 'fqdn': self.fqdn, - 'port': self.port, - 'ip': self.ip, - 'certificate': None, - 'key': None, - 'config': self.config + "parent": self.parent, + "name": self.name, + "address": self.address, + "fqdn": self.fqdn, + "port": self.port, + "ip": self.ip, + "certificate": None, + "key": None, + "config": self.config, } if self.certificate: cert_b64 = base64.b64encode(self.certificate) key_b64 = base64.b64encode(self.key) - data['certificate'] = str(cert_b64).split('\'')[1] - data['key'] = str(key_b64).split('\'')[1] + data["certificate"] = str(cert_b64).split("'")[1] + data["key"] = str(key_b64).split("'")[1] return data def to_json(self): - """ Export combiner configuration to json. + """Export combiner configuration to json. :return: A json document with the combiner configuration. :rtype: str @@ -148,34 +143,33 @@ def to_json(self): return json.dumps(self.to_dict()) def get_certificate(self): - """ Get combiner certificate. + """Get combiner certificate. :return: The combiner certificate. :rtype: str, None if no certificate is set. """ if self.certificate: cert_b64 = base64.b64encode(self.certificate) - return str(cert_b64).split('\'')[1] + return str(cert_b64).split("'")[1] else: return None def get_key(self): - """ Get combiner key. + """Get combiner key. :return: The combiner key. 
:rtype: str, None if no key is set. """ if self.key: key_b64 = base64.b64encode(self.key) - return str(key_b64).split('\'')[1] + return str(key_b64).split("'")[1] else: return None def flush_model_update_queue(self): - """ Reset the model update queue on the combiner. """ + """Reset the model update queue on the combiner.""" - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ControlStub(channel) request = fedn.ControlRequest() @@ -189,14 +183,13 @@ def flush_model_update_queue(self): raise def set_aggregator(self, aggregator): - """ Set the active aggregator module. + """Set the active aggregator module. :param aggregator: The name of the aggregator module. :type config: str """ - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ControlStub(channel) request = fedn.ControlRequest() @@ -213,15 +206,14 @@ def set_aggregator(self, aggregator): raise def submit(self, config): - """ Submit a compute plan to the combiner. + """Submit a compute plan to the combiner. :param config: The job configuration. :type config: dict :return: Server ControlResponse object. :rtype: :class:`fedn.network.grpc.fedn_pb2.ControlResponse` """ - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ControlStub(channel) request = fedn.ControlRequest() request.command = fedn.Command.START @@ -241,7 +233,7 @@ def submit(self, config): return response def get_model(self, id, timeout=10): - """ Download a model from the combiner server. + """Download a model from the combiner server. :param id: The model id. :type id: str @@ -249,8 +241,7 @@ def get_model(self, id, timeout=10): :rtype: :class:`io.BytesIO`, None if the model is not available. """ - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() modelservice = rpc.ModelServiceStub(channel) data = BytesIO() @@ -276,13 +267,12 @@ def get_model(self, id, timeout=10): continue def allowing_clients(self): - """ Check if the combiner is allowing additional client connections. + """Check if the combiner is allowing additional client connections. :return: True if accepting, else False. :rtype: bool """ - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() connector = rpc.ConnectorStub(channel) request = fedn.ConnectionRequest() @@ -303,7 +293,7 @@ def allowing_clients(self): return False def list_active_clients(self, queue=1): - """ List active clients. + """List active clients. :param queue: The channel (queue) to use (optional). Default is 1 = MODEL_UPDATE_REQUESTS channel. see :class:`fedn.network.grpc.fedn_pb2.Channel` @@ -311,8 +301,7 @@ def list_active_clients(self, queue=1): :return: A list of active clients. 
:rtype: json """ - channel = Channel(self.address, self.port, - self.certificate).get_channel() + channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ConnectorStub(channel) request = fedn.ListClientsRequest() request.channel = queue diff --git a/fedn/network/combiner/modelservice.py b/fedn/network/combiner/modelservice.py index 59c63108b..0b50edbc7 100644 --- a/fedn/network/combiner/modelservice.py +++ b/fedn/network/combiner/modelservice.py @@ -21,11 +21,9 @@ def upload_request_generator(mdl, id): while True: b = mdl.read(CHUNK_SIZE) if b: - result = fedn.ModelRequest( - data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) + result = fedn.ModelRequest(data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) else: - result = fedn.ModelRequest( - id=id, data=None, status=fedn.ModelStatus.OK) + result = fedn.ModelRequest(id=id, data=None, status=fedn.ModelStatus.OK) yield result if not b: break @@ -47,14 +45,14 @@ def model_as_bytesIO(model): def get_tmp_path(): - """ Return a temporary output path compatible with save_model, load_model. """ + """Return a temporary output path compatible with save_model, load_model.""" fd, path = tempfile.mkstemp() os.close(fd) return path def load_model_from_BytesIO(model_bytesio, helper): - """ Load a model from a BytesIO object. + """Load a model from a BytesIO object. :param model_bytesio: A BytesIO object containing the model. :type model_bytesio: :class:`io.BytesIO` :param helper: The helper object for the model. @@ -63,7 +61,7 @@ :rtype: return type of helper.load """ path = get_tmp_path() - with open(path, 'wb') as fh: + with open(path, "wb") as fh: fh.write(model_bytesio) fh.flush() model = helper.load(path) @@ -72,7 +70,7 @@ def serialize_model_to_BytesIO(model, helper): - """ Serialize a model to a BytesIO object. + """Serialize a model to a BytesIO object. :param model: The model object. :type model: return type of helper.load @@ -85,7 +83,7 @@ a = BytesIO() a.seek(0, 0) - with open(outfile_name, 'rb') as f: + with open(outfile_name, "rb") as f: a.write(f.read()) a.seek(0) os.unlink(outfile_name) @@ -93,15 +91,13 @@ class ModelService(rpc.ModelServiceServicer): - """ Service for handling download and upload of models to the server. - - """ + """Service for handling download and upload of models to the server.""" def __init__(self): self.temp_model_storage = TempModelStorage() def exist(self, model_id): - """ Check if a model exists on the server. + """Check if a model exists on the server. :param model_id: The model id. :return: True if the model exists, else False. @@ -109,7 +105,7 @@ return self.temp_model_storage.exist(model_id) def get_model(self, id): - """ Download model with id 'id' from server. + """Download model with id 'id' from server. :param id: The model id. :type id: str @@ -131,7 +127,7 @@ return None def set_model(self, model, id): - """ Upload model to server. + """Upload model to server. :param model: A model object (BytesIO) :type model: :class:`io.BytesIO` @@ -144,7 +140,7 @@ # Model Service def Upload(self, request_iterator, context): - """ RPC endpoints for uploading a model. + """RPC endpoint for uploading a model. :param request_iterator: The model request iterator.
:type request_iterator: :class:`fedn.network.grpc.fedn_pb2.ModelRequest` @@ -161,8 +157,7 @@ self.temp_model_storage.set_model_metadata(request.id, fedn.ModelStatus.IN_PROGRESS) if request.status == fedn.ModelStatus.OK and not request.data: - result = fedn.ModelResponse(id=request.id, status=fedn.ModelStatus.OK, - message="Got model successfully.") + result = fedn.ModelResponse(id=request.id, status=fedn.ModelStatus.OK, message="Got model successfully.") # self.temp_model_storage_metadata.update({request.id: fedn.ModelStatus.OK}) self.temp_model_storage.set_model_metadata(request.id, fedn.ModelStatus.OK) self.temp_model_storage.get_ptr(request.id).flush() @@ -170,7 +165,7 @@ return result def Download(self, request, context): - """ RPC endpoints for downloading a model. + """RPC endpoint for downloading a model. :param request: The model request object. :type request: :class:`fedn.network.grpc.fedn_pb2.ModelRequest` @@ -179,11 +174,11 @@ :return: A model response iterator. :rtype: :class:`fedn.network.grpc.fedn_pb2.ModelResponse` """ - logger.info(f'grpc.ModelService.Download: {request.sender.role}:{request.sender.name} requested model {request.id}') + logger.info(f"grpc.ModelService.Download: {request.sender.role}:{request.sender.name} requested model {request.id}") try: status = self.temp_model_storage.get_model_metadata(request.id) if status != fedn.ModelStatus.OK: - logger.error(f'model file is not ready: {request.id}, status: {status}') + logger.error(f"model file is not ready: {request.id}, status: {status}") yield fedn.ModelResponse(id=request.id, data=None, status=status) except Exception: logger.error("Error file does not exist: {}".format(request.id)) @@ -192,7 +187,7 @@ try: obj = self.temp_model_storage.get(request.id) if obj is None: - raise Exception(f'File not found: {request.id}') + raise Exception(f"File not found: {request.id}") with obj as f: while True: piece = f.read(CHUNK_SIZE) diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index 848142efc..2a8436e01 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -7,8 +7,7 @@ from fedn.common.log_config import logger from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator -from fedn.network.combiner.modelservice import (load_model_from_BytesIO, - serialize_model_to_BytesIO) +from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO from fedn.utils.helpers.helpers import get_helper from fedn.utils.parameters import Parameters @@ -18,7 +17,7 @@ class ModelUpdateError(Exception): class RoundHandler: - """ Round handler. + """Round handler. The round handler processes requests from the global controller to produce model updates and perform model validations.
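A minimal sketch of the round config such a request carries, using only keys this module reads; the values are hypothetical and `round_handler` stands in for an instance wired to storage, server and modelservice:

    # Hypothetical values; push_round_config generates "_job_id" itself.
    config = {
        "task": "training",
        "round_id": "1",
        "session_id": "sess-1",            # hypothetical session identifier
        "model_id": "model-123",           # hypothetical global model id
        "round_timeout": "180",            # parsed with float() in waitforit
        "buffer_size": "-1",               # -1: wait for all participating clients
        "clients_required": "1",
        "helper_type": "numpyhelper",
        "delete_models_storage": "True",
    }
    job_id = round_handler.push_round_config(config)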
@@ -34,7 +33,7 @@ class RoundHandler: """ def __init__(self, storage, server, modelservice): - """ Initialize the RoundHandler.""" + """Initialize the RoundHandler.""" self.round_configs = queue.Queue() self.storage = storage @@ -53,12 +52,12 @@ def push_round_config(self, round_config): :rtype: str """ try: - round_config['_job_id'] = str(uuid.uuid4()) + round_config["_job_id"] = str(uuid.uuid4()) self.round_configs.put(round_config) except Exception: logger.error("Failed to push round config.") raise - return round_config['_job_id'] + return round_config["_job_id"] def load_model_update(self, helper, model_id): """Load model update with id model_id into its memory representation. @@ -74,8 +73,7 @@ def load_model_update(self, helper, model_id): try: model = load_model_from_BytesIO(model_str.getbuffer(), helper) except IOError: - logger.warning( - "AGGREGATOR({}): Failed to load model!".format(self.name)) + logger.warning("AGGREGATOR({}): Failed to load model!".format(self.name)) else: raise ModelUpdateError("Failed to load model.") @@ -108,7 +106,7 @@ def load_model_update_str(self, model_id, retry=3): return model_str def waitforit(self, config, buffer_size=100, polling_interval=0.1): - """ Defines the policy for how long the server should wait before starting to aggregate models. + """Defines the policy for how long the server should wait before starting to aggregate models. The policy is as follows: 1. Wait a maximum of time_window time until the round times out. @@ -122,7 +120,7 @@ def waitforit(self, config, buffer_size=100, polling_interval=0.1): :type polling_interval: float, optional """ - time_window = float(config['round_timeout']) + time_window = float(config["round_timeout"]) tt = 0.0 while tt < time_window: @@ -143,22 +141,21 @@ def _training_round(self, config, clients): :rtype: model, dict """ - logger.info( - "ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients)) + logger.info("ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients)) meta = {} - meta['nr_expected_updates'] = len(clients) - meta['nr_required_updates'] = int(config['clients_required']) - meta['timeout'] = float(config['round_timeout']) + meta["nr_expected_updates"] = len(clients) + meta["nr_required_updates"] = int(config["clients_required"]) + meta["timeout"] = float(config["round_timeout"]) # Request model updates from all active clients. self.server.request_model_update(config, clients=clients) # If buffer_size is -1 (default), the round terminates when/if all clients have completed. - if int(config['buffer_size']) == -1: + if int(config["buffer_size"]) == -1: buffer_size = len(clients) else: - buffer_size = int(config['buffer_size']) + buffer_size = int(config["buffer_size"]) # Wait / block until the round termination policy has been met. 
self.waitforit(config, buffer_size=buffer_size) @@ -168,27 +165,25 @@ data = None try: - helper = get_helper(config['helper_type']) - logger.info("Config delete_models_storage: {}".format(config['delete_models_storage'])) - if config['delete_models_storage'] == 'True': + helper = get_helper(config["helper_type"]) + logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"])) + if config["delete_models_storage"] == "True": delete_models = True else: delete_models = False if "aggregator_kwargs" in config.keys(): - dict_parameters = ast.literal_eval(config['aggregator_kwargs']) + dict_parameters = ast.literal_eval(config["aggregator_kwargs"]) parameters = Parameters(dict_parameters) else: parameters = None - model, data = self.aggregator.combine_models(helper=helper, - delete_models=delete_models, - parameters=parameters) + model, data = self.aggregator.combine_models(helper=helper, delete_models=delete_models, parameters=parameters) except Exception as e: logger.warning("AGGREGATION FAILED AT COMBINER! {}".format(e)) - meta['time_combination'] = time.time() - tic - meta['aggregation_time'] = data + meta["time_combination"] = time.time() - tic + meta["aggregation_time"] = data return model, meta def _validation_round(self, config, clients, model_id): @@ -237,7 +232,7 @@ def stage_model(self, model_id, timeout_retry=3, retry=2): self.modelservice.set_model(model, model_id) def _assign_round_clients(self, n, type="trainers"): - """ Obtain a list of clients(trainers or validators) to ask for updates in this round. + """Obtain a list of clients (trainers or validators) to ask for updates in this round. :param n: Size of a random set taken from active trainers(clients), if n > "active trainers" all is used :type n: int @@ -276,30 +271,27 @@ """ active = self.server.nr_active_trainers() - if active >= int(config['clients_required']): - logger.info("Number of clients required ({0}) to start round met {1}.".format( - config['clients_required'], active)) + if active >= int(config["clients_required"]): + logger.info("Number of clients required ({0}) to start round met {1}.".format(config["clients_required"], active)) return True else: logger.info("Too few clients to start round.") return False def execute_validation_round(self, round_config): - """ Coordinate validation rounds as specified in config. + """Coordinate validation rounds as specified in config. :param round_config: The round config object. :type round_config: dict """ - model_id = round_config['model_id'] - logger.info( - "COMBINER orchestrating validation of model {}".format(model_id)) + model_id = round_config["model_id"] + logger.info("COMBINER orchestrating validation of model {}".format(model_id)) self.stage_model(model_id) - validators = self._assign_round_clients( - self.server.max_clients, type="validators") + validators = self._assign_round_clients(self.server.max_clients, type="validators") self._validation_round(round_config, validators, model_id) def execute_training_round(self, config): - """ Coordinates clients to execute training tasks. + """Coordinates clients to execute training tasks. :param config: The round config object.
:type config: dict @@ -307,39 +299,37 @@ :rtype: dict """ - logger.info("Processing training round, job_id {}".format(config['_job_id'])) + logger.info("Processing training round, job_id {}".format(config["_job_id"])) data = {} - data['config'] = config - data['round_id'] = config['round_id'] + data["config"] = config + data["round_id"] = config["round_id"] # Download model to update and set in temp storage. - self.stage_model(config['model_id']) + self.stage_model(config["model_id"]) clients = self._assign_round_clients(self.server.max_clients) model, meta = self._training_round(config, clients) - data['data'] = meta + data["data"] = meta if model is None: - logger.warning( - "\t Failed to update global model in round {0}!".format(config['round_id'])) + logger.warning("\t Failed to update global model in round {0}!".format(config["round_id"])) if model is not None: - helper = get_helper(config['helper_type']) + helper = get_helper(config["helper_type"]) a = serialize_model_to_BytesIO(model, helper) model_id = self.storage.set_model(a.read(), is_file=False) a.close() - data['model_id'] = model_id + data["model_id"] = model_id - logger.info( - "TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config['_job_id'])) + logger.info("TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config["_job_id"])) # Delete temp model - self.modelservice.temp_model_storage.delete(config['model_id']) + self.modelservice.temp_model_storage.delete(config["model_id"]) return data def run(self, polling_interval=1.0): - """ Main control loop. Execute rounds based on round config on the queue. + """Main control loop. Execute rounds based on round config on the queue. :param polling_interval: The polling interval in seconds for checking if a new job/config is available. :type polling_interval: float @@ -354,23 +344,22 @@ round_meta = {} if ready: - if round_config['task'] == 'training': + if round_config["task"] == "training": tic = time.time() round_meta = self.execute_training_round(round_config) - round_meta['time_exec_training'] = time.time() - \ - tic - round_meta['status'] = "Success" - round_meta['name'] = self.server.id + round_meta["time_exec_training"] = time.time() - tic + round_meta["status"] = "Success" + round_meta["name"] = self.server.id self.server.statestore.set_round_combiner_data(round_meta) - elif round_config['task'] == 'validation' or round_config['task'] == 'inference': + elif round_config["task"] == "validation" or round_config["task"] == "inference": self.execute_validation_round(round_config) else: logger.warning("config contains unknown task type.") else: round_meta = {} - round_meta['status'] = "Failed" - round_meta['reason'] = "Failed to meet client allocation requirements for this round config." - logger.warning("{0}".format(round_meta['reason'])) + round_meta["status"] = "Failed" + round_meta["reason"] = "Failed to meet client allocation requirements for this round config." + logger.warning("{0}".format(round_meta["reason"])) self.round_configs.task_done() except queue.Empty: diff --git a/fedn/network/config.py b/fedn/network/config.py index a9e8773f4..0c32949b8 100644 --- a/fedn/network/config.py +++ b/fedn/network/config.py @@ -6,16 +6,14 @@ class Config(ABC): class ReducerConfig(Config): - """ Configuration for the Reducer component.
""" + """Configuration for the Reducer component.""" + compute_bundle_dir = None models_dir = None initial_model = None - storage_backend = { - 'type': 's3', 'settings': - {'bucket': 'models'} - } + storage_backend = {"type": "s3", "settings": {"bucket": "models"}} def __init__(self): pass diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index 2b64098cf..99a59469c 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -3,8 +3,7 @@ import time import uuid -from tenacity import (retry, retry_if_exception_type, stop_after_delay, - wait_random) +from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_random from fedn.common.log_config import logger from fedn.network.combiner.interfaces import CombinerUnavailableError @@ -54,10 +53,10 @@ def __init__(self, message): class CombinersNotDoneException(Exception): - """ Exception class for when model is None """ + """Exception class for when model is None""" def __init__(self, message): - """ Constructor method. + """Constructor method. :param message: The exception message. :type message: str @@ -108,9 +107,9 @@ def session(self, config): last_round = int(self.get_latest_round_id()) for combiner in self.network.get_combiners(): - combiner.set_aggregator(config['aggregator']) + combiner.set_aggregator(config["aggregator"]) - self.set_session_status(config['session_id'], 'Started') + self.set_session_status(config["session_id"], "Started") # Execute the rounds in this session for round in range(1, int(config["rounds"] + 1)): # Increment the round number @@ -129,11 +128,11 @@ def session(self, config): config["model_id"] = self.statestore.get_latest_model() # TODO: Report completion of session - self.set_session_status(config['session_id'], 'Finished') + self.set_session_status(config["session_id"], "Finished") self._state = ReducerState.idle def round(self, session_config, round_id): - """ Execute one global round. + """Execute one global round. : param session_config: The session config. 
: type session_config: dict @@ -142,11 +141,11 @@ def round(self, session_config, round_id): """ - self.create_round({'round_id': round_id, 'status': "Pending"}) + self.create_round({"round_id": round_id, "status": "Pending"}) if len(self.network.get_combiners()) < 1: logger.warning("Round cannot start, no combiners connected!") - self.set_round_status(round_id, 'Failed') + self.set_round_status(round_id, "Failed") return None, self.statestore.get_round(round_id) # Assemble round config for this global round @@ -165,11 +164,10 @@ def round(self, session_config, round_id): round_start = self.evaluate_round_start_policy(participating_combiners) if round_start: - logger.info("round start policy met, {} participating combiners.".format( - len(participating_combiners))) + logger.info("round start policy met, {} participating combiners.".format(len(participating_combiners))) else: logger.warning("Round start policy not met, skipping round!") - self.set_round_status(round_id, 'Failed') + self.set_round_status(round_id, "Failed") return None, self.statestore.get_round(round_id) # Ask participating combiners to coordinate model updates @@ -181,18 +179,19 @@ def round(self, session_config, round_id): def do_if_round_times_out(result): logger.warning("Round timed out!") - @ retry(wait=wait_random(min=1.0, max=2.0), - stop=stop_after_delay(session_config['round_timeout']), - retry_error_callback=do_if_round_times_out, - retry=retry_if_exception_type(CombinersNotDoneException)) + @retry( + wait=wait_random(min=1.0, max=2.0), + stop=stop_after_delay(session_config["round_timeout"]), + retry_error_callback=do_if_round_times_out, + retry=retry_if_exception_type(CombinersNotDoneException), + ) def combiners_done(): - round = self.statestore.get_round(round_id) - if 'combiners' not in round: + if "combiners" not in round: logger.info("Waiting for combiners to update model...") raise CombinersNotDoneException("Combiners have not yet reported.") - if len(round['combiners']) < len(participating_combiners): + if len(round["combiners"]) < len(participating_combiners): logger.info("Waiting for combiners to update model...") raise CombinersNotDoneException("All combiners have not yet reported.") @@ -203,11 +202,10 @@ def combiners_done(): # Due to the distributed nature of the computation, there might be a # delay before combiners have reported the round data to the db, # so we need some robustness here. - @ retry(wait=wait_random(min=0.1, max=1.0), - retry=retry_if_exception_type(KeyError)) + @retry(wait=wait_random(min=0.1, max=1.0), retry=retry_if_exception_type(KeyError)) def check_combiners_done_reporting(): round = self.statestore.get_round(round_id) - combiners = round['combiners'] + combiners = round["combiners"] return combiners _ = check_combiners_done_reporting() @@ -216,7 +214,7 @@ def check_combiners_done_reporting(): round_valid = self.evaluate_round_validity_policy(round) if not round_valid: logger.error("Round failed. 
Invalid - evaluate_round_validity_policy: False") - self.set_round_status(round_id, 'Failed') + self.set_round_status(round_id, "Failed") return None, self.statestore.get_round(round_id) logger.info("Reducing combiner level models...") @@ -224,12 +222,12 @@ def check_combiners_done_reporting(): round_data = {} try: round = self.statestore.get_round(round_id) - model, data = self.reduce(round['combiners']) - round_data['reduce'] = data + model, data = self.reduce(round["combiners"]) + round_data["reduce"] = data logger.info("Done reducing models from combiners!") except Exception as e: logger.error("Failed to reduce models from combiners, reason: {}".format(e)) - self.set_round_status(round_id, 'Failed') + self.set_round_status(round_id, "Failed") return None, self.statestore.get_round(round_id) # Commit the new global model to the model trail @@ -237,20 +235,16 @@ def check_combiners_done_reporting(): logger.info("Committing global model to model trail...") tic = time.time() model_id = uuid.uuid4() - session_id = ( - session_config["session_id"] - if "session_id" in session_config - else None - ) + session_id = session_config["session_id"] if "session_id" in session_config else None self.commit(model_id, model, session_id) round_data["time_commit"] = time.time() - tic logger.info("Done committing global model to model trail.") else: logger.error("Failed to commit model to global model trail.") - self.set_round_status(round_id, 'Failed') + self.set_round_status(round_id, "Failed") return None, self.statestore.get_round(round_id) - self.set_round_status(round_id, 'Success') + self.set_round_status(round_id, "Success") # 4. Trigger participating combiner nodes to execute a validation round for the current model validate = session_config["validate"] @@ -261,8 +255,7 @@ def check_combiners_done_reporting(): combiner_config["task"] = "validation" combiner_config["helper_type"] = self.statestore.get_helper() - validating_combiners = self.get_participating_combiners( - combiner_config) + validating_combiners = self.get_participating_combiners(combiner_config) for combiner, combiner_config in validating_combiners: try: @@ -273,7 +266,7 @@ def check_combiners_done_reporting(): pass self.set_round_data(round_id, round_data) - self.set_round_status(round_id, 'Finished') + self.set_round_status(round_id, "Finished") return model_id, self.statestore.get_round(round_id) def reduce(self, combiners): @@ -292,15 +285,15 @@ def reduce(self, combiners): model = None for combiner in combiners: - name = combiner['name'] - model_id = combiner['model_id'] + name = combiner["name"] + model_id = combiner["model_id"] logger.info("Fetching model ({}) from model repository".format(model_id)) try: tic = time.time() data = self.model_repository.get_model(model_id) - meta['time_fetch_model'] += (time.time() - tic) + meta["time_fetch_model"] += time.time() - tic except Exception as e: logger.error("Failed to fetch model from model repository {}: {}".format(name, e)) data = None @@ -373,8 +366,7 @@ def inference_round(self, config): combiner_config["helper_type"] = self.statestore.get_framework() # Select combiners - validating_combiners = self.get_participating_combiners( - combiner_config) + validating_combiners = self.get_participating_combiners(combiner_config) # Test round start policy round_start = self.check_round_start_policy(validating_combiners) diff --git a/fedn/network/controller/controlbase.py b/fedn/network/controller/controlbase.py index e825e8e8b..d99bae40a 100644 --- 
a/fedn/network/controller/controlbase.py +++ b/fedn/network/controller/controlbase.py @@ -61,9 +61,7 @@ def __init__(self, statestore): raise MisconfiguredStorageBackend() if storage_config["storage_type"] == "S3": - self.model_repository = Repository( - storage_config["storage_config"] - ) + self.model_repository = Repository(storage_config["storage_config"]) else: logger.error("Unsupported storage backend, exiting.") raise UnsupportedStorageBackend() @@ -92,11 +90,7 @@ def get_helper(self): helper_type = self.statestore.get_helper() helper = fedn.utils.helpers.helpers.get_helper(helper_type) if not helper: - raise MisconfiguredHelper( - "Unsupported helper type {}, please configure compute_package.helper !".format( - helper_type - ) - ) + raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type)) return helper def get_state(self): @@ -177,8 +171,8 @@ def get_compute_package(self, compute_package=""): else: return None - def create_session(self, config, status='Initialized'): - """ Initialize a new session in backend db. """ + def create_session(self, config, status="Initialized"): + """Initialize a new session in backend db.""" if "session_id" not in config.keys(): session_id = uuid.uuid4() @@ -191,7 +185,7 @@ self.statestore.set_session_status(session_id, status) def set_session_status(self, session_id, status): - """ Set the round round stats. + """Set the session status. :param round_id: The round unique identifier :type round_id: str @@ -201,12 +195,12 @@ self.statestore.set_session_status(session_id, status) def create_round(self, round_data): - """Initialize a new round in backend db. """ + """Initialize a new round in backend db.""" self.statestore.create_round(round_data) def set_round_data(self, round_id, round_data): - """ Set round data. + """Set round data. :param round_id: The round unique identifier :type round_id: str @@ -216,7 +210,7 @@ self.statestore.set_round_data(round_id, round_data) def set_round_status(self, round_id, status): - """ Set the round round stats. + """Set the round status. :param round_id: The round unique identifier :type round_id: str @@ -226,7 +220,7 @@ self.statestore.set_round_status(round_id, status) def set_round_config(self, round_id, round_config): - """ Upate round in backend db. + """Update round in backend db.
:param round_id: The round unique identifier :type round_id: str @@ -263,9 +257,7 @@ def commit(self, model_id, model=None, session_id=None): logger.info("Saving model file temporarily to disk...") outfile_name = helper.save(model) logger.info("CONTROL: Uploading model to Minio...") - model_id = self.model_repository.set_model( - outfile_name, is_file=True - ) + model_id = self.model_repository.set_model(outfile_name, is_file=True) logger.info("CONTROL: Deleting temporary model file...") os.unlink(outfile_name) @@ -292,16 +284,12 @@ def get_participating_combiners(self, combiner_round_config): self._handle_unavailable_combiner(combiner) continue - is_participating = self.evaluate_round_participation_policy( - combiner_round_config, nr_active_clients - ) + is_participating = self.evaluate_round_participation_policy(combiner_round_config, nr_active_clients) if is_participating: combiners.append((combiner, combiner_round_config)) return combiners - def evaluate_round_participation_policy( - self, compute_plan, nr_active_clients - ): + def evaluate_round_participation_policy(self, compute_plan, nr_active_clients): """Evaluate policy for combiner round-participation. A combiner participates if it is responsive and reports enough active clients to participate in the round. @@ -325,7 +313,7 @@ def evaluate_round_start_policy(self, combiners): return False def evaluate_round_validity_policy(self, round): - """ Check if the round is valid. + """Check if the round is valid. At the end of the round, before committing a model to the global model trail, we check if the round validity policy has been met. This can involve @@ -338,9 +326,9 @@ def evaluate_round_validity_policy(self, round): :rtype: bool """ model_ids = [] - for combiner in round['combiners']: + for combiner in round["combiners"]: try: - model_ids.append(combiner['model_id']) + model_ids.append(combiner["model_id"]) except KeyError: pass @@ -350,7 +338,7 @@ def evaluate_round_validity_policy(self, round): return True def state(self): - """ Get the current state of the controller. + """Get the current state of the controller. 
:return: The state :rtype: str diff --git a/fedn/network/grpc/auth.py b/fedn/network/grpc/auth.py index d879cd812..e57926cd7 100644 --- a/fedn/network/grpc/auth.py +++ b/fedn/network/grpc/auth.py @@ -6,35 +6,33 @@ from fedn.network.api.auth import check_custom_claims ENDPOINT_ROLES_MAPPING = { - '/fedn.Combiner/TaskStream': ['client'], - '/fedn.Combiner/SendModelUpdate': ['client'], - '/fedn.Combiner/SendModelValidation': ['client'], - '/fedn.Connector/SendHeartbeat': ['client'], - '/fedn.Connector/SendStatus': ['client'], - '/fedn.ModelService/Download': ['client'], - '/fedn.ModelService/Upload': ['client'], - '/fedn.Control/Start': ['controller'], - '/fedn.Control/Stop': ['controller'], - '/fedn.Control/FlushAggregationQueue': ['controller'], - '/fedn.Control/SetAggregator': ['controller'], + "/fedn.Combiner/TaskStream": ["client"], + "/fedn.Combiner/SendModelUpdate": ["client"], + "/fedn.Combiner/SendModelValidation": ["client"], + "/fedn.Connector/SendHeartbeat": ["client"], + "/fedn.Connector/SendStatus": ["client"], + "/fedn.ModelService/Download": ["client"], + "/fedn.ModelService/Upload": ["client"], + "/fedn.Control/Start": ["controller"], + "/fedn.Control/Stop": ["controller"], + "/fedn.Control/FlushAggregationQueue": ["controller"], + "/fedn.Control/SetAggregator": ["controller"], } ENDPOINT_WHITELIST = [ - '/fedn.Connector/AcceptingClients', - '/fedn.Connector/ListActiveClients', - '/fedn.Control/Start', - '/fedn.Control/Stop', - '/fedn.Control/FlushAggregationQueue', - '/fedn.Control/SetAggregator', + "/fedn.Connector/AcceptingClients", + "/fedn.Connector/ListActiveClients", + "/fedn.Control/Start", + "/fedn.Control/Stop", + "/fedn.Control/FlushAggregationQueue", + "/fedn.Control/SetAggregator", ] -USER_AGENT_WHITELIST = [ - 'grpc_health_probe' -] +USER_AGENT_WHITELIST = ["grpc_health_probe"] def check_role_claims(payload, endpoint): - user_role = payload.get('role', '') + user_role = payload.get("role", "") # Perform endpoint-specific RBAC check allowed_roles = ENDPOINT_ROLES_MAPPING.get(endpoint) @@ -63,33 +61,33 @@ def intercept_service(self, continuation, handler_call_details): if handler_call_details.method in ENDPOINT_WHITELIST: return continuation(handler_call_details) # Pass if the request comes from whitelisted user agents - user_agent = metadata.get('user-agent').split(' ')[0] + user_agent = metadata.get("user-agent").split(" ")[0] if user_agent in USER_AGENT_WHITELIST: return continuation(handler_call_details) - token = metadata.get('authorization') + token = metadata.get("authorization") if token is None: - return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token is missing') + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, "Token is missing") if not token.startswith(FEDN_AUTH_SCHEME): - return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}') + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f"Invalid token scheme, expected {FEDN_AUTH_SCHEME}") - token = token.split(' ')[1] + token = token.split(" ")[1] try: payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) if not check_role_claims(payload, handler_call_details.method): - return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') if not check_custom_claims(payload): - return
_unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') + return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, "Insufficient permissions") return continuation(handler_call_details) except jwt.InvalidTokenError: - return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Invalid token') + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, "Invalid token") except jwt.ExpiredSignatureError: - return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token expired') + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, "Token expired") except Exception as e: logger.error(str(e)) return _unary_unary_rpc_terminator(grpc.StatusCode.UNKNOWN, str(e)) diff --git a/fedn/network/loadbalancer/firstavailable.py b/fedn/network/loadbalancer/firstavailable.py index 9d44d3fbd..13dc766b2 100644 --- a/fedn/network/loadbalancer/firstavailable.py +++ b/fedn/network/loadbalancer/firstavailable.py @@ -2,7 +2,7 @@ class LeastPacked(LoadBalancerBase): - """ Load balancer that selects the first available combiner. + """Load balancer that selects the first available combiner. :param network: A handle to the network. :type network: class: `fedn.network.api.network.Network` @@ -12,7 +12,7 @@ def __init__(self, network): super().__init__(network) def find_combiner(self): - """ Find the first available combiner. """ + """Find the first available combiner.""" for combiner in self.network.get_combiners(): if combiner.allowing_clients(): diff --git a/fedn/network/loadbalancer/leastpacked.py b/fedn/network/loadbalancer/leastpacked.py index cac7bba54..a762701b0 100644 --- a/fedn/network/loadbalancer/leastpacked.py +++ b/fedn/network/loadbalancer/leastpacked.py @@ -3,7 +3,7 @@ class LeastPacked(LoadBalancerBase): - """ Load balancer that selects the combiner with the least number of attached training clients. + """Load balancer that selects the combiner with the least number of attached training clients. :param network: A handle to the network. :type network: class: `fedn.network.api.network.Network` @@ -14,7 +14,7 @@ def __init__(self, network): def find_combiner(self): """ - Find the combiner with the least number of attached clients. + Find the combiner with the least number of attached clients. """ min_clients = None diff --git a/fedn/network/loadbalancer/loadbalancerbase.py b/fedn/network/loadbalancer/loadbalancerbase.py index ff1edfa9b..ddbfa2231 100644 --- a/fedn/network/loadbalancer/loadbalancerbase.py +++ b/fedn/network/loadbalancer/loadbalancerbase.py @@ -2,7 +2,7 @@ class LoadBalancerBase(ABC): - """ Abstract base class for load balancers. + """Abstract base class for load balancers. :param network: A handle to the network. :type network: class: `fedn.network.api.network.Network` @@ -14,5 +14,5 @@ def __init__(self, network): @abstractmethod def find_combiner(self): - """ Find a combiner to connect to. """ + """Find a combiner to connect to.""" pass diff --git a/fedn/network/state.py b/fedn/network/state.py index 9d18bc924..fcbd391eb 100644 --- a/fedn/network/state.py +++ b/fedn/network/state.py @@ -2,7 +2,8 @@ class ReducerState(Enum): - """ Enum for representing the state of a reducer.""" + """Enum for representing the state of a reducer.""" + setup = 1 idle = 2 instructing = 3 @@ -10,7 +11,7 @@ class ReducerState(Enum): def ReducerStateToString(state): - """ Convert ReducerState to string. + """Convert ReducerState to string. :param state: The state. 
:type state: :class:`fedn.network.state.ReducerState` @@ -30,7 +31,7 @@ def ReducerStateToString(state): def StringToReducerState(state): - """ Convert string to ReducerState. + """Convert string to ReducerState. :param state: The state as string. :type state: str diff --git a/fedn/network/storage/models/memorymodelstorage.py b/fedn/network/storage/models/memorymodelstorage.py index e6ef9b07f..6a40a7ae0 100644 --- a/fedn/network/storage/models/memorymodelstorage.py +++ b/fedn/network/storage/models/memorymodelstorage.py @@ -8,25 +8,22 @@ class MemoryModelStorage(ModelStorage): - """ Class for in-memory storage of model artifacts. + """Class for in-memory storage of model artifacts. Models are stored as BytesIO objects in a dictionary. """ def __init__(self): - self.models = defaultdict(io.BytesIO) self.models_metadata = {} def exist(self, model_id): - if model_id in self.models.keys(): return True return False def get(self, model_id): - obj = self.models[model_id] obj.seek(0, 0) # Have to copy object to not mix up the file pointers when sending... fix in better way. @@ -42,9 +39,7 @@ def get_ptr(self, model_id): return self.models[model_id] def get_model_metadata(self, model_id): - return self.models_metadata[model_id] def set_model_metadata(self, model_id, model_metadata): - self.models_metadata.update({model_id: model_metadata}) diff --git a/fedn/network/storage/models/modelstorage.py b/fedn/network/storage/models/modelstorage.py index 4945ab37b..3062db36e 100644 --- a/fedn/network/storage/models/modelstorage.py +++ b/fedn/network/storage/models/modelstorage.py @@ -2,10 +2,9 @@ class ModelStorage(ABC): - @abstractmethod def exist(self, model_id): - """ Check if model exists in storage + """Check if model exists in storage :param model_id: The model id :type model_id: str @@ -16,7 +15,7 @@ def exist(self, model_id): @abstractmethod def get(self, model_id): - """ Get model from storage + """Get model from storage :param model_id: The model id :type model_id: str @@ -27,7 +26,7 @@ def get(self, model_id): @abstractmethod def get_model_metadata(self, model_id): - """ Get model metadata from storage + """Get model metadata from storage :param model_id: The model id :type model_id: str @@ -38,7 +37,7 @@ def get_model_metadata(self, model_id): @abstractmethod def set_model_metadata(self, model_id, model_metadata): - """ Set model metadata in storage + """Set model metadata in storage :param model_id: The model id :type model_id: str @@ -51,7 +50,7 @@ def set_model_metadata(self, model_id, model_metadata): @abstractmethod def delete(self, model_id): - """ Delete model from storage + """Delete model from storage :param model_id: The model id :type model_id: str @@ -62,7 +61,7 @@ def delete(self, model_id): @abstractmethod def delete_all(self): - """ Delete all models from storage + """Delete all models from storage :return: True if successful, False otherwise :rtype: bool diff --git a/fedn/network/storage/models/tempmodelstorage.py b/fedn/network/storage/models/tempmodelstorage.py index 492cee5ab..214fac4d7 100644 --- a/fedn/network/storage/models/tempmodelstorage.py +++ b/fedn/network/storage/models/tempmodelstorage.py @@ -9,12 +9,10 @@ class TempModelStorage(ModelStorage): - """ Class for managing local temporary models on file on combiners.""" + """Class for managing local temporary models on file on combiners.""" def __init__(self): - - self.default_dir = os.environ.get( - 'FEDN_MODEL_DIR', '/tmp/models') # set default to tmp + self.default_dir = os.environ.get("FEDN_MODEL_DIR", 
"/tmp/models") # set default to tmp if not os.path.exists(self.default_dir): os.makedirs(self.default_dir) @@ -22,13 +20,11 @@ def __init__(self): self.models_metadata = {} def exist(self, model_id): - if model_id in self.models.keys(): return True return False def get(self, model_id): - try: if self.models_metadata[model_id] != fedn.ModelStatus.OK: logger.warning("File not ready! Try again") @@ -38,7 +34,7 @@ def get(self, model_id): return None obj = BytesIO() - with open(os.path.join(self.default_dir, str(model_id)), 'rb') as f: + with open(os.path.join(self.default_dir, str(model_id)), "rb") as f: obj.write(f.read()) obj.seek(0, 0) @@ -51,15 +47,14 @@ def get_ptr(self, model_id): :return: """ try: - f = self.models[model_id]['file'] + f = self.models[model_id]["file"] except KeyError: f = open(os.path.join(self.default_dir, str(model_id)), "wb") - self.models[model_id] = {'file': f} - return self.models[model_id]['file'] + self.models[model_id] = {"file": f} + return self.models[model_id]["file"] def get_model_metadata(self, model_id): - try: status = self.models_metadata[model_id] except KeyError: @@ -67,12 +62,10 @@ def get_model_metadata(self, model_id): return status def set_model_metadata(self, model_id, model_metadata): - self.models_metadata.update({model_id: model_metadata}) # Delete model from disk def delete(self, model_id): - try: os.remove(os.path.join(self.default_dir, str(model_id))) logger.info("TEMPMODELSTORAGE: Deleted model with id: {}".format(model_id)) @@ -86,7 +79,6 @@ def delete(self, model_id): # Delete all models from disk def delete_all(self): - ids_pop = [] for model_id in self.models.keys(): try: diff --git a/fedn/network/storage/s3/base.py b/fedn/network/storage/s3/base.py index 8f22a2485..671b75d7d 100644 --- a/fedn/network/storage/s3/base.py +++ b/fedn/network/storage/s3/base.py @@ -6,7 +6,7 @@ class RepositoryBase(object): @abc.abstractmethod def set_artifact(self, instance_name, instance, bucket): - """ Set object with name object_name + """Set object with name object_name :param instance_name: The name of the object :tyep insance_name: str @@ -18,7 +18,7 @@ def set_artifact(self, instance_name, instance, bucket): @abc.abstractmethod def get_artifact(self, instance_name, bucket): - """ Retrive object with name instance_name. + """Retrive object with name instance_name. :param instance_name: The name of the object to retrieve :type instance_name: str @@ -29,7 +29,7 @@ def get_artifact(self, instance_name, bucket): @abc.abstractmethod def get_artifact_stream(self, instance_name, bucket): - """ Return a stream handler for object with name instance_name. + """Return a stream handler for object with name instance_name. :param instance_name: The name if the object :type instance_name: str diff --git a/fedn/network/storage/s3/miniorepository.py b/fedn/network/storage/s3/miniorepository.py index f6082b158..9c86b8997 100644 --- a/fedn/network/storage/s3/miniorepository.py +++ b/fedn/network/storage/s3/miniorepository.py @@ -9,12 +9,12 @@ class MINIORepository(RepositoryBase): - """ Class implementing Repository for MinIO. """ + """Class implementing Repository for MinIO.""" client = None def __init__(self, config): - """ Initialize object. + """Initialize object. :param config: Dictionary containing configuration for credentials and bucket names. 
:type config: dict @@ -23,34 +23,35 @@ def __init__(self, config): super().__init__() self.name = "MINIORepository" - if config['storage_secure_mode']: - manager = PoolManager( - num_pools=100, cert_reqs='CERT_NONE', assert_hostname=False) - self.client = Minio("{0}:{1}".format(config['storage_hostname'], config['storage_port']), - access_key=config['storage_access_key'], - secret_key=config['storage_secret_key'], - secure=config['storage_secure_mode'], http_client=manager) + if config["storage_secure_mode"]: + manager = PoolManager(num_pools=100, cert_reqs="CERT_NONE", assert_hostname=False) + self.client = Minio( + "{0}:{1}".format(config["storage_hostname"], config["storage_port"]), + access_key=config["storage_access_key"], + secret_key=config["storage_secret_key"], + secure=config["storage_secure_mode"], + http_client=manager, + ) else: - self.client = Minio("{0}:{1}".format(config['storage_hostname'], config['storage_port']), - access_key=config['storage_access_key'], - secret_key=config['storage_secret_key'], - secure=config['storage_secure_mode']) + self.client = Minio( + "{0}:{1}".format(config["storage_hostname"], config["storage_port"]), + access_key=config["storage_access_key"], + secret_key=config["storage_secret_key"], + secure=config["storage_secure_mode"], + ) def set_artifact(self, instance_name, instance, bucket, is_file=False): - if is_file: self.client.fput_object(bucket, instance_name, instance) else: try: - self.client.put_object( - bucket, instance_name, io.BytesIO(instance), len(instance)) + self.client.put_object(bucket, instance_name, io.BytesIO(instance), len(instance)) except Exception as e: raise Exception("Could not load data into bytes {}".format(e)) return True def get_artifact(self, instance_name, bucket): - try: data = self.client.get_object(bucket, instance_name) return data.read() @@ -61,7 +62,6 @@ def get_artifact(self, instance_name, bucket): data.release_conn() def get_artifact_stream(self, instance_name, bucket): - try: data = self.client.get_object(bucket, instance_name) return data @@ -69,7 +69,7 @@ def get_artifact_stream(self, instance_name, bucket): raise Exception("Could not fetch data from bucket, {}".format(e)) def list_artifacts(self, bucket): - """ List all objects in bucket. + """List all objects in bucket. :param bucket: Name of the bucket :type bucket: str @@ -81,12 +81,11 @@ def list_artifacts(self, bucket): for obj in objs: objects.append(obj.object_name) except Exception: - raise Exception( - "Could not list models in bucket {}".format(bucket)) + raise Exception("Could not list models in bucket {}".format(bucket)) return objects def delete_artifact(self, instance_name, bucket): - """ Delete object with name instance_name from buckets. + """Delete object with name instance_name from buckets. :param instance_name: The object name :param bucket: Buckets to delete from @@ -96,11 +95,11 @@ def delete_artifact(self, instance_name, bucket): try: self.client.remove_object(bucket, instance_name) except InvalidResponseError as err: - logger.error('Could not delete artifact: {0} err: {1}'.format(instance_name, err)) + logger.error("Could not delete artifact: {0} err: {1}".format(instance_name, err)) pass def create_bucket(self, bucket_name): - """ Create a new bucket. If bucket exists, do nothing. + """Create a new bucket. If bucket exists, do nothing. 
:param bucket_name: The name of the bucket :type bucket_name: str diff --git a/fedn/network/storage/s3/repository.py b/fedn/network/storage/s3/repository.py index d7d455341..18d36cdbb 100644 --- a/fedn/network/storage/s3/repository.py +++ b/fedn/network/storage/s3/repository.py @@ -5,12 +5,11 @@ class Repository: - """ Interface for storing model objects and compute packages in S3 compatible storage. """ + """Interface for storing model objects and compute packages in S3 compatible storage.""" def __init__(self, config): - - self.model_bucket = config['storage_bucket'] - self.context_bucket = config['context_bucket'] + self.model_bucket = config["storage_bucket"] + self.context_bucket = config["context_bucket"] # TODO: Make a plug-in solution self.client = MINIORepository(config) @@ -19,27 +18,25 @@ self.client.create_bucket(self.model_bucket) def get_model(self, model_id): - """ Retrieve a model with id model_id. + """Retrieve a model with id model_id. :param model_id: Unique identifier for model to retrieve. :return: The model object """ - logger.info("Client {} trying to get model with id: {}".format( - self.client.name, model_id)) + logger.info("Client {} trying to get model with id: {}".format(self.client.name, model_id)) return self.client.get_artifact(model_id, self.model_bucket) def get_model_stream(self, model_id): - """ Retrieve a stream handle to model with id model_id. + """Retrieve a stream handle to model with id model_id. :param model_id: :return: Handle to model object """ - logger.info("Client {} trying to get model with id: {}".format( - self.client.name, model_id)) + logger.info("Client {} trying to get model with id: {}".format(self.client.name, model_id)) return self.client.get_artifact_stream(model_id, self.model_bucket) def set_model(self, model, is_file=True): - """ Upload model object. + """Upload model object. :param model: The model object :type model: BytesIO or str file name. @@ -49,15 +46,14 @@ model_id = uuid.uuid4() try: - self.client.set_artifact(str(model_id), model, - bucket=self.model_bucket, is_file=is_file) + self.client.set_artifact(str(model_id), model, bucket=self.model_bucket, is_file=is_file) except Exception: logger.error("Failed to upload model with ID {} to repository.".format(model_id)) raise return str(model_id) def delete_model(self, model_id): - """ Delete model. + """Delete model. :param model_id: The id of the model to delete :type model_id: str @@ -69,7 +65,7 @@ raise def set_compute_package(self, name, compute_package, is_file=True): - """ Upload compute package. + """Upload compute package. :param name: The name of the compute package. :type name: str @@ -79,14 +75,13 @@ """ try: - self.client.set_artifact(str(name), compute_package, - bucket=self.context_bucket, is_file=is_file) + self.client.set_artifact(str(name), compute_package, bucket=self.context_bucket, is_file=is_file) except Exception: logger.error("Failed to write compute_package to repository.") raise def get_compute_package(self, compute_package): - """ Retrieve compute package from object store. + """Retrieve compute package from object store. :param compute_package: The name of the compute package. :type compute_package: str @@ -100,7 +95,7 @@ return data def delete_compute_package(self, compute_package): - """ Delete a compute package from storage. 
+ """Delete a compute package from storage. :param compute_package: The name of the compute_package :type compute_package: str diff --git a/fedn/network/storage/statestore/mongostatestore.py b/fedn/network/storage/statestore/mongostatestore.py index c4fdf3fe7..a53a6e4d5 100644 --- a/fedn/network/storage/statestore/mongostatestore.py +++ b/fedn/network/storage/statestore/mongostatestore.py @@ -61,7 +61,7 @@ def __init__(self, network_id, config): self.init_index() def connect(self): - """ Establish client connection to MongoDB. + """Establish client connection to MongoDB. :param config: Dictionary containing connection strings and security credentials. :type config: dict @@ -125,11 +125,7 @@ def transition(self, state): True, ) else: - logger.info( - "Not updating state, already in {}".format( - ReducerStateToString(state) - ) - ) + logger.info("Not updating state, already in {}".format(ReducerStateToString(state))) def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo.DESCENDING): """Get all sessions. @@ -151,13 +147,9 @@ def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo limit = int(limit) skip = int(skip) - result = self.sessions.find().limit(limit).skip(skip).sort( - sort_key, sort_order - ) + result = self.sessions.find().limit(limit).skip(skip).sort(sort_key, sort_order) else: - result = self.sessions.find().sort( - sort_key, sort_order - ) + result = self.sessions.find().sort(sort_key, sort_order) count = self.sessions.count_documents({}) @@ -204,9 +196,7 @@ def set_latest_model(self, model_id, session_id=None): } ) - self.model.update_one( - {"key": "current_model"}, {"$set": {"model": model_id}}, True - ) + self.model.update_one({"key": "current_model"}, {"$set": {"model": model_id}}, True) self.model.update_one( {"key": "model_trail"}, { @@ -225,9 +215,7 @@ def get_initial_model(self): :rtype: str """ - result = self.model.find_one( - {"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)] - ) + result = self.model.find_one({"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)]) if result is None: return None @@ -266,16 +254,12 @@ def set_current_model(self, model_id: str): """ try: - committed_at = datetime.now() existing_model = self.model.find_one({"key": "models", "model": model_id}) if existing_model is not None: - - self.model.update_one( - {"key": "current_model"}, {"$set": {"model": model_id, "committed_at": committed_at, "session_id": None}}, True - ) + self.model.update_one({"key": "current_model"}, {"$set": {"model": model_id, "committed_at": committed_at, "session_id": None}}, True) return True except Exception as e: @@ -334,7 +318,6 @@ def set_active_compute_package(self, id: str): """ try: - find = {"id": id} projection = {"_id": False, "key": False} @@ -345,9 +328,7 @@ def set_active_compute_package(self, id: str): doc["key"] = "active" - self.control.package.replace_one( - {"key": "active"}, doc - ) + self.control.package.replace_one({"key": "active"}, doc) except Exception as e: logger.error("ERROR: {}".format(e)) @@ -376,9 +357,7 @@ def set_compute_package(self, file_name: str, storage_file_name: str, helper_typ self.control.package.update_one( {"key": "active"}, - { - "$set": obj - }, + {"$set": obj}, True, ) @@ -395,7 +374,6 @@ def get_compute_package(self): :rtype: ObjectID """ try: - find = {"key": "active"} projection = {"key": False, "_id": False} ret = self.control.package.find_one(find, projection) @@ -449,9 +427,7 @@ def set_helper(self, helper): :type helper: str :return: 
""" - self.control.package.update_one( - {"key": "active"}, {"$set": {"helper": helper}}, True - ) + self.control.package.update_one({"key": "active"}, {"$set": {"helper": helper}}, True) def get_helper(self): """Get the active helper package. @@ -466,9 +442,7 @@ def get_helper(self): # ret = self.control.config.find_one({'key': 'round_config'}) try: retcheck = ret["helper"] - if ( - retcheck == "" or retcheck == " " - ): # ugly check for empty string + if retcheck == "" or retcheck == " ": # ugly check for empty string return None return retcheck except (KeyError, IndexError): @@ -495,11 +469,7 @@ def list_models( """ result = None - find_option = ( - {"key": "models"} - if session_id is None - else {"key": "models", "session_id": session_id} - ) + find_option = {"key": "models"} if session_id is None else {"key": "models", "session_id": session_id} projection = {"_id": False, "key": False} @@ -507,17 +477,10 @@ def list_models( limit = int(limit) skip = int(skip) - result = ( - self.model.find(find_option, projection) - .limit(limit) - .skip(skip) - .sort(sort_key, sort_order) - ) + result = self.model.find(find_option, projection).limit(limit).skip(skip).sort(sort_key, sort_order) else: - result = self.model.find(find_option, projection).sort( - sort_key, sort_order - ) + result = self.model.find(find_option, projection).sort(sort_key, sort_order) count = self.model.count_documents(find_option) @@ -625,9 +588,7 @@ def get_events(self, **kwargs): projection = {"_id": False} if not kwargs: - result = self.control.status.find({}, projection).sort( - "timestamp", pymongo.DESCENDING - ) + result = self.control.status.find({}, projection).sort("timestamp", pymongo.DESCENDING) count = self.control.status.count_documents({}) else: limit = kwargs.pop("limit", None) @@ -636,16 +597,9 @@ def get_events(self, **kwargs): if limit is not None and skip is not None: limit = int(limit) skip = int(skip) - result = ( - self.control.status.find(kwargs, projection) - .sort("timestamp", pymongo.DESCENDING) - .limit(limit) - .skip(skip) - ) + result = self.control.status.find(kwargs, projection).sort("timestamp", pymongo.DESCENDING).limit(limit).skip(skip) else: - result = self.control.status.find(kwargs, projection).sort( - "timestamp", pymongo.DESCENDING - ) + result = self.control.status.find(kwargs, projection).sort("timestamp", pymongo.DESCENDING) count = self.control.status.count_documents(kwargs) @@ -661,9 +615,7 @@ def get_storage_backend(self): :rtype: ObjectID """ try: - ret = self.storage.find( - {"status": "enabled"}, projection={"_id": False} - ) + ret = self.storage.find({"status": "enabled"}, projection={"_id": False}) return ret[0] except (KeyError, IndexError): return None @@ -678,9 +630,7 @@ def set_storage_backend(self, config): config = copy.deepcopy(config) config["updated_at"] = str(datetime.now()) config["status"] = "enabled" - self.storage.update_one( - {"storage_type": config["storage_type"]}, {"$set": config}, True - ) + self.storage.update_one({"storage_type": config["storage_type"]}, {"$set": config}, True) def set_reducer(self, reducer_data): """Set the reducer in the statestore. @@ -690,9 +640,7 @@ def set_reducer(self, reducer_data): :return: """ reducer_data["updated_at"] = str(datetime.now()) - self.reducer.update_one( - {"name": reducer_data["name"]}, {"$set": reducer_data}, True - ) + self.reducer.update_one({"name": reducer_data["name"]}, {"$set": reducer_data}, True) def get_reducer(self): """Get reducer.config. 
@@ -767,9 +715,7 @@ def set_combiner(self, combiner_data): """ combiner_data["updated_at"] = str(datetime.now()) - self.combiners.update_one( - {"name": combiner_data["name"]}, {"$set": combiner_data}, True - ) + self.combiners.update_one({"name": combiner_data["name"]}, {"$set": combiner_data}, True) def delete_combiner(self, combiner): """Delete a combiner from statestore. @@ -793,9 +739,7 @@ def set_client(self, client_data): :return: """ client_data["updated_at"] = str(datetime.now()) - self.clients.update_one( - {"name": client_data["name"]}, {"$set": client_data}, True - ) + self.clients.update_one({"name": client_data["name"]}, {"$set": client_data}, True) def get_client(self, name): """Get client by name. @@ -866,15 +810,15 @@ def list_combiners_data(self, combiners, sort_key="count", sort_order=pymongo.DE result = None try: - - pipeline = [ - {"$match": {"combiner": {"$in": combiners}, "status": "online"}}, - {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, - {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} - ] if combiners is not None else [ - {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, - {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} - ] + pipeline = ( + [ + {"$match": {"combiner": {"$in": combiners}, "status": "online"}}, + {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, + {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}}, + ] + if combiners is not None + else [{"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}}] + ) result = self.clients.aggregate(pipeline) @@ -911,7 +855,7 @@ def drop_status(self): self.status.drop() def create_session(self, id=None): - """ Create a new session object. + """Create a new session object. :param id: The ID of the created session. :type id: uuid, str @@ -919,11 +863,11 @@ def create_session(self, id=None): """ if not id: id = uuid.uuid4() - data = {'session_id': str(id)} + data = {"session_id": str(id)} self.sessions.insert_one(data) def create_round(self, round_data): - """ Create a new round. + """Create a new round. :param round_data: Dictionary with round data. :type round_data: dict @@ -939,8 +883,7 @@ def set_session_config(self, id, config): :param config: Session configuration :type config: dict """ - self.sessions.update_one({'session_id': str(id)}, { - '$push': {'session_config': config}}, True) + self.sessions.update_one({"session_id": str(id)}, {"$push": {"session_config": config}}, True) def set_session_status(self, id, status): """Set session status. @@ -949,8 +892,7 @@ def set_session_status(self, id, status): :type round_id: str :param round_status: The status of the session. """ - self.sessions.update_one({'session_id': str(id)}, { - '$set': {'status': status}}, True) + self.sessions.update_one({"session_id": str(id)}, {"$set": {"status": status}}, True) def set_round_combiner_data(self, data): """Set combiner round controller data. @@ -958,8 +900,7 @@ def set_round_combiner_data(self, data): :param data: The combiner data :type data: dict """ - self.rounds.update_one({'round_id': str(data['round_id'])}, { - '$push': {'combiners': data}}, True) + self.rounds.update_one({"round_id": str(data["round_id"])}, {"$push": {"combiners": data}}, True) def set_round_config(self, round_id, round_config): """Set round configuration. 
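# --- Annotation (not part of the diff): the pipeline reformatted in
# list_combiners_data above counts online clients per combiner, optionally
# restricted to a given set of combiners. A standalone sketch assuming the
# same document layout; connection and collection names are illustrative.
from pymongo import ASCENDING, DESCENDING, MongoClient

clients = MongoClient("mongodb://localhost:27017")["fedn-test"]["network.clients"]
pipeline = [
    {"$match": {"combiner": {"$in": ["combiner0", "combiner1"]}, "status": "online"}},
    {"$group": {"_id": "$combiner", "count": {"$sum": 1}}},
    {"$sort": {"count": DESCENDING, "_id": ASCENDING}},
]
for doc in clients.aggregate(pipeline):
    print(doc["_id"], doc["count"])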
@@ -969,8 +910,7 @@ def set_round_config(self, round_id, round_config): :param round_config: The round configuration :type round_config: dict """ - self.rounds.update_one({'round_id': round_id}, { - '$set': {'round_config': round_config}}, True) + self.rounds.update_one({"round_id": round_id}, {"$set": {"round_config": round_config}}, True) def set_round_status(self, round_id, round_status): """Set round status. @@ -979,8 +919,7 @@ def set_round_status(self, round_id, round_status): :type round_id: str :param round_status: The status of the round. """ - self.rounds.update_one({'round_id': round_id}, { - '$set': {'status': round_status}}, True) + self.rounds.update_one({"round_id": round_id}, {"$set": {"status": round_status}}, True) def set_round_data(self, round_id, round_data): """Update round metadata @@ -990,11 +929,10 @@ def set_round_data(self, round_id, round_data): :param round_data: The round metadata :type round_data: dict """ - self.rounds.update_one({'round_id': round_id}, { - '$set': {'round_data': round_data}}, True) + self.rounds.update_one({"round_id": round_id}, {"$set": {"round_data": round_data}}, True) def update_client_status(self, clients, status): - """ Update client status in statestore. + """Update client status in statestore. :param client_name: The client name :type client_name: str :param status: The client status diff --git a/fedn/network/storage/statestore/statestorebase.py b/fedn/network/storage/statestore/statestorebase.py index f41e3c025..7c6681682 100644 --- a/fedn/network/storage/statestore/statestorebase.py +++ b/fedn/network/storage/statestore/statestorebase.py @@ -2,22 +2,19 @@ class StateStoreBase(ABC): - """ - - """ + """ """ def __init__(self): pass @abstractmethod def state(self): - """ Return the current state of the statestore. - """ + """Return the current state of the statestore.""" pass @abstractmethod def transition(self, state): - """ Transition the statestore to a new state. + """Transition the statestore to a new state. :param state: The new state. :type state: str @@ -26,7 +23,7 @@ def transition(self, state): @abstractmethod def set_latest_model(self, model_id): - """ Set the latest model id in the statestore. + """Set the latest model id in the statestore. :param model_id: The model id. :type model_id: str @@ -35,7 +32,7 @@ def set_latest_model(self, model_id): @abstractmethod def get_latest_model(self): - """ Get the latest model id from the statestore. + """Get the latest model id from the statestore. :return: The model object. :rtype: ObjectId @@ -44,7 +41,7 @@ def get_latest_model(self): @abstractmethod def is_inited(self): - """ Check if the statestore is initialized. + """Check if the statestore is initialized. :return: True if initialized, else False. :rtype: bool diff --git a/fedn/utils/checksum.py b/fedn/utils/checksum.py index 3c7bbd3ec..8ca678597 100644 --- a/fedn/utils/checksum.py +++ b/fedn/utils/checksum.py @@ -2,7 +2,7 @@ def sha(fname): - """ Calculate the sha256 checksum of a file. Used for computing checksums of compute packages. + """Calculate the sha256 checksum of a file. Used for computing checksums of compute packages. :param fname: The file path. :type fname: str diff --git a/fedn/utils/dispatcher.py b/fedn/utils/dispatcher.py index 77f357249..5d00021b1 100644 --- a/fedn/utils/dispatcher.py +++ b/fedn/utils/dispatcher.py @@ -72,10 +72,7 @@ def _validate_virtualenv_is_available(): on how to install virtualenv. """ if not _is_virtualenv_available(): - raise Exception( - "Could not find the virtualenv binary. 
Run `pip install virtualenv` to install " - "virtualenv." - ) + raise Exception("Could not find the virtualenv binary. Run `pip install virtualenv` to install " "virtualenv.") def _get_virtualenv_extra_env_vars(env_root_dir=None): @@ -95,9 +92,7 @@ def _get_python_env(python_env_file): return _PythonEnv.from_yaml(python_env_file) -def _create_virtualenv( - python_bin_path, env_dir, python_env, extra_env=None, capture_output=False -): +def _create_virtualenv(python_bin_path, env_dir, python_env, extra_env=None, capture_output=False): # Created a command to activate the environment paths = ("bin", "activate") if _IS_UNIX else ("Scripts", "activate.bat") activate_cmd = env_dir.joinpath(*paths) @@ -110,8 +105,7 @@ def _create_virtualenv( with remove_on_error( env_dir, onerror=lambda e: logger.warning( - "Encountered an unexpected error: %s while creating a virtualenv environment in %s, " - "removing the environment directory...", + "Encountered an unexpected error: %s while creating a virtualenv environment in %s, " "removing the environment directory...", repr(e), env_dir, ), @@ -127,9 +121,7 @@ def _create_virtualenv( with tempfile.TemporaryDirectory() as tmpdir: tmp_req_file = f"requirements.{uuid.uuid4().hex}.txt" Path(tmpdir).joinpath(tmp_req_file).write_text("\n".join(deps)) - cmd = _join_commands( - activate_cmd, f"python -m pip install -r {tmp_req_file}" - ) + cmd = _join_commands(activate_cmd, f"python -m pip install -r {tmp_req_file}") _exec_cmd(cmd, capture_output=capture_output, cwd=tmpdir, extra_env=extra_env) return activate_cmd @@ -138,20 +130,17 @@ def _create_virtualenv( def _read_yaml_file(file_path): try: cfg = None - with open(file_path, 'rb') as config_file: - + with open(file_path, "rb") as config_file: cfg = yaml.safe_load(config_file.read()) except Exception as e: - logger.error( - f"Error trying to read yaml file: {file_path}" - ) + logger.error(f"Error trying to read yaml file: {file_path}") raise e return cfg class Dispatcher: - """ Dispatcher class for compute packages. + """Dispatcher class for compute packages. :param config: The configuration. :type config: dict @@ -160,7 +149,7 @@ class Dispatcher: """ def __init__(self, config, project_dir): - """ Initialize the dispatcher.""" + """Initialize the dispatcher.""" self.config = config self.project_dir = project_dir self.activate_cmd = "" @@ -174,10 +163,7 @@ def _get_or_create_python_env(self, capture_output=False, pip_requirements_overr else: python_env_path = os.path.join(self.project_dir, python_env) if not os.path.exists(python_env_path): - raise Exception( - "Compute package specified python_env file %s, but no such " - "file was found." % python_env_path - ) + raise Exception("Compute package specified python_env file %s, but no such " "file was found." 
% python_env_path) python_env = _get_python_env(python_env_path) extra_env = _get_virtualenv_extra_env_vars() @@ -201,10 +187,7 @@ def _get_or_create_python_env(self, capture_output=False, pip_requirements_overr ) # Install additional dependencies specified by `requirements_override` if pip_requirements_override: - logger.info( - "Installing additional dependencies specified by " - f"pip_requirements_override: {pip_requirements_override}" - ) + logger.info("Installing additional dependencies specified by " f"pip_requirements_override: {pip_requirements_override}") cmd = _join_commands( activate_cmd, f"python -m pip install --quiet -U {' '.join(pip_requirements_override)}", @@ -222,29 +205,23 @@ def _get_or_create_python_env(self, capture_output=False, pip_requirements_overr raise - def run_cmd(self, - cmd_type, - capture_output=False, - extra_env=None, - synchronous=True, - stream_output=False - ): - """ Run a command. + def run_cmd(self, cmd_type, capture_output=False, extra_env=None, synchronous=True, stream_output=False): + """Run a command. :param cmd_type: The command type. :type cmd_type: str :return: """ try: - cmdsandargs = cmd_type.split(' ') + cmdsandargs = cmd_type.split(" ") - entry_point = self.config['entry_points'][cmdsandargs[0]]['command'] + entry_point = self.config["entry_points"][cmdsandargs[0]]["command"] # remove the first element, that is not a file but a command args = cmdsandargs[1:] # Join entry point and arguments into a single command as a string - entry_point_args = ' '.join(args) + entry_point_args = " ".join(args) entry_point = f"{entry_point} {entry_point_args}" if self.activate_cmd: @@ -252,7 +229,7 @@ def run_cmd(self, else: cmd = _join_commands(entry_point) - logger.info('Running command: {}'.format(cmd)) + logger.info("Running command: {}".format(cmd)) _exec_cmd( cmd, cwd=self.project_dir, @@ -263,7 +240,7 @@ def run_cmd(self, stream_output=stream_output, ) - logger.info('Done executing {}'.format(cmd_type)) + logger.info("Done executing {}".format(cmd_type)) except IndexError: message = "No such argument or configuration to run." 
logger.error(message) diff --git a/fedn/utils/flowercompat/client_app_adapter.py b/fedn/utils/flowercompat/client_app_adapter.py index 610e5ca9f..15de3ff6b 100644 --- a/fedn/utils/flowercompat/client_app_adapter.py +++ b/fedn/utils/flowercompat/client_app_adapter.py @@ -1,16 +1,27 @@ from typing import Tuple from flwr.client import ClientApp -from flwr.common import (Context, EvaluateIns, FitIns, GetParametersIns, - Message, MessageType, MessageTypeLegacy, Metadata, - NDArrays, ndarrays_to_parameters, - parameters_to_ndarrays) -from flwr.common.recordset_compat import (evaluateins_to_recordset, - fitins_to_recordset, - getparametersins_to_recordset, - recordset_to_evaluateres, - recordset_to_fitres, - recordset_to_getparametersres) +from flwr.common import ( + Context, + EvaluateIns, + FitIns, + GetParametersIns, + Message, + MessageType, + MessageTypeLegacy, + Metadata, + NDArrays, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.recordset_compat import ( + evaluateins_to_recordset, + fitins_to_recordset, + getparametersins_to_recordset, + recordset_to_evaluateres, + recordset_to_fitres, + recordset_to_getparametersres, +) class FlwrClientAppAdapter: @@ -21,23 +32,21 @@ def __init__(self, app: ClientApp) -> None: def init_parameters(self, partition_id: int, config: dict = {}): # Construct a get_parameters message for the ClientApp - message, context = self._construct_message( - MessageTypeLegacy.GET_PARAMETERS, [], partition_id, config - ) + message, context = self._construct_message(MessageTypeLegacy.GET_PARAMETERS, [], partition_id, config) # Call client app with train message client_return_message = self.app(message, context) # return NDArrays of clients parameters parameters = self._parse_get_parameters_message(client_return_message) if len(parameters) == 0: - raise ValueError("The 'parameters' list is empty. Ensure your flower \ - client has implemented a get_parameters() function.") + raise ValueError( + "The 'parameters' list is empty. Ensure your flower \ + client has implemented a get_parameters() function." + ) return parameters def train(self, parameters: NDArrays, partition_id: int, config: dict = {}): # Construct a train message for the ClientApp with given parameters - message, context = self._construct_message( - MessageType.TRAIN, parameters, partition_id, config - ) + message, context = self._construct_message(MessageType.TRAIN, parameters, partition_id, config) # Call client app with train message client_return_message = self.app(message, context) # Parse return message @@ -46,9 +55,7 @@ def train(self, parameters: NDArrays, partition_id: int, config: dict = {}): def evaluate(self, parameters: NDArrays, partition_id: int, config: dict = {}): # Construct an evaluate message for the ClientApp with given parameters - message, context = self._construct_message( - MessageType.EVALUATE, parameters, partition_id, config - ) + message, context = self._construct_message(MessageType.EVALUATE, parameters, partition_id, config) # Call client app with evaluate message client_return_message = self.app(message, context) # Parse return message diff --git a/fedn/utils/helpers/helperbase.py b/fedn/utils/helpers/helperbase.py index a59ee49d8..3377d0336 100644 --- a/fedn/utils/helpers/helperbase.py +++ b/fedn/utils/helpers/helperbase.py @@ -4,16 +4,16 @@ class HelperBase(ABC): - """ Abstract class defining helpers. """ + """Abstract class defining helpers.""" def __init__(self): - """ Initialize helper. 
""" + """Initialize helper.""" self.name = self.__class__.__name__ @abstractmethod def increment_average(self, m1, m2, a, W): - """ Compute one increment of incremental weighted averaging. + """Compute one increment of incremental weighted averaging. :param m1: Current model weights in array-like format. :param m2: New model weights in array-like format. @@ -25,7 +25,7 @@ def increment_average(self, m1, m2, a, W): @abstractmethod def save(self, model, path): - """ Serialize weights to file. The serialized model must be a single binary object. + """Serialize weights to file. The serialized model must be a single binary object. :param model: Weights in array-like format. :param path: Path to file. @@ -35,7 +35,7 @@ def save(self, model, path): @abstractmethod def load(self, fh): - """ Load weights from file or filelike. + """Load weights from file or filelike. :param fh: file path, filehandle, filelike. :return: Weights in array-like format. @@ -43,10 +43,10 @@ def load(self, fh): pass def get_tmp_path(self): - """ Return a temporary output path compatible with save_model, load_model. + """Return a temporary output path compatible with save_model, load_model. :return: Path to file. """ - fd, path = tempfile.mkstemp(suffix='.npz') + fd, path = tempfile.mkstemp(suffix=".npz") os.close(fd) return path diff --git a/fedn/utils/helpers/helpers.py b/fedn/utils/helpers/helpers.py index 841af58a0..7fbf83a81 100644 --- a/fedn/utils/helpers/helpers.py +++ b/fedn/utils/helpers/helpers.py @@ -5,7 +5,7 @@ def get_helper(helper_module_name): - """ Return an instance of the helper class. + """Return an instance of the helper class. :param helper_module_name: The name of the helper plugin module. :type helper_module_name: str @@ -18,24 +18,24 @@ def get_helper(helper_module_name): def save_metadata(metadata, filename): - """ Save metadata to file. + """Save metadata to file. :param metadata: The metadata to save. :type metadata: dict :param filename: The name of the file to save to. :type filename: str """ - with open(filename+'-metadata', 'w') as outfile: + with open(filename + "-metadata", "w") as outfile: json.dump(metadata, outfile) def save_metrics(metrics, filename): - """ Save metrics to file. + """Save metrics to file. :param metrics: The metrics to save. :type metrics: dict :param filename: The name of the file to save to. :type filename: str """ - with open(filename, 'w') as outfile: + with open(filename, "w") as outfile: json.dump(metrics, outfile) diff --git a/fedn/utils/helpers/plugins/androidhelper.py b/fedn/utils/helpers/plugins/androidhelper.py index 6a9fc7f9d..119801b8c 100644 --- a/fedn/utils/helpers/plugins/androidhelper.py +++ b/fedn/utils/helpers/plugins/androidhelper.py @@ -18,9 +18,7 @@ def __init__(self): # function to calculate an incremental weighted average of the weights - def increment_average( - self, model, model_next, num_examples, total_examples - ): + def increment_average(self, model, model_next, num_examples, total_examples): """Incremental weighted average of model weights. :param model: Current model weights. @@ -40,9 +38,7 @@ def increment_average( return (1 - w) * model + w * model_next # function to calculate an incremental weighted average of the weights using numpy.add - def increment_average_add( - self, model, model_next, num_examples, total_examples - ): + def increment_average_add(self, model, model_next, num_examples, total_examples): """Incremental weighted average of model weights. :param model: Current model weights. 
@@ -59,9 +55,7 @@ def increment_average_add( # Incremental weighted average w = np.add( model, - num_examples - * (np.array(model_next) - np.array(model)) - / total_examples, + num_examples * (np.array(model_next) - np.array(model)) / total_examples, ) return w @@ -75,7 +69,7 @@ def save(self, weights, path=None): if not path: path = self.get_tmp_path() - byte_array = struct.pack("f"*len(weights), *weights) + byte_array = struct.pack("f" * len(weights), *weights) with open(path, "wb") as file: file.write(byte_array) @@ -94,7 +88,7 @@ def load(self, fh): else: byte_data = fh.read() - weights = np.array(struct.unpack(f'{len(byte_data) // 4}f', byte_data)) + weights = np.array(struct.unpack(f"{len(byte_data) // 4}f", byte_data)) return weights diff --git a/fedn/utils/helpers/plugins/numpyhelper.py b/fedn/utils/helpers/plugins/numpyhelper.py index 9751e3903..ce6c29420 100644 --- a/fedn/utils/helpers/plugins/numpyhelper.py +++ b/fedn/utils/helpers/plugins/numpyhelper.py @@ -1,19 +1,18 @@ - import numpy as np from fedn.utils.helpers.helperbase import HelperBase class Helper(HelperBase): - """ FEDn helper class for models weights/parameters that can be transformed to numpy ndarrays. """ + """FEDn helper class for models weights/parameters that can be transformed to numpy ndarrays.""" def __init__(self): - """ Initialize helper. """ + """Initialize helper.""" super().__init__() self.name = "numpyhelper" def increment_average(self, m1, m2, n, N): - """ Update a weighted incremental average of model weights. + """Update a weighted incremental average of model weights. :param m1: Current parameters. :type model: list of numpy ndarray @@ -27,10 +26,10 @@ :rtype: list of numpy ndarray """ - return [np.add(x, n*(y-x)/N) for x, y in zip(m1, m2)] + return [np.add(x, n * (y - x) / N) for x, y in zip(m1, m2)] def add(self, m1, m2, a=1.0, b=1.0): - """ m1*a + m2*b + """m1*a + m2*b :param model: Current model weights. :type model: list of ndarrays @@ -40,10 +39,10 @@ :rtype: list of ndarrays """ - return [x*a+y*b for x, y in zip(m1, m2)] + return [x * a + y * b for x, y in zip(m1, m2)] def subtract(self, m1, m2, a=1.0, b=1.0): - """ m1*a - m2*b. + """m1*a - m2*b. :param m1: Current model weights. :type m1: list of ndarrays @@ -55,7 +54,7 @@ return self.add(m1, m2, a, -b) def divide(self, m1, m2): - """ Subtract weights. + """Divide m1 by m2, element-wise. :param m1: Current model weights. :type m1: list of ndarrays @@ -68,7 +67,7 @@ return [np.divide(x, y) for x, y in zip(m1, m2)] def multiply(self, m1, m2): - """ Multiply m1 by m2. + """Multiply m1 by m2. :param m1: Current model weights. :type m1: list of ndarrays @@ -81,7 +80,7 @@ return [np.multiply(x, y) for (x, y) in zip(m1, m2)] def sqrt(self, m1): - """ Sqrt of m1, element-wise. + """Sqrt of m1, element-wise. :param m1: Current model weights. :type model: list of ndarrays @@ -94,7 +93,7 @@ return [np.sqrt(x) for x in m1] def power(self, m1, a): - """ m1 raised to the power of m2. + """m1 raised to the power of a, element-wise. :param m1: Current model weights. :type m1: list of ndarrays @@ -107,7 +106,7 @@ return [np.power(x, a) for x in m1] def norm(self, m): - """ Return the norm (L1) of model weights. + """Return the norm (L1) of model weights. :param m: Current model weights. 
:type m: list of ndarrays @@ -120,7 +119,7 @@ def norm(self, m): return n def sign(self, m): - """ Sign of m. + """Sign of m. :param m: Model parameters. :type m: list of ndarrays @@ -131,7 +130,7 @@ def sign(self, m): return [np.sign(x) for x in m] def ones(self, m1, a): - """ Return a list of numpy arrays of the same shape as m1, filled with ones. + """Return a list of numpy arrays of the same shape as m1, filled with ones. :param m1: Current model weights. :type m1: list of ndarrays @@ -143,11 +142,11 @@ def ones(self, m1, a): res = [] for x in m1: - res.append(np.ones(np.shape(x))*a) + res.append(np.ones(np.shape(x)) * a) return res def save(self, weights, path=None): - """ Serialize weights to file. The serialized model must be a single binary object. + """Serialize weights to file. The serialized model must be a single binary object. :param weights: List of weights in numpy format. :param path: Path to file. @@ -165,7 +164,7 @@ def save(self, weights, path=None): return path def load(self, fh): - """ Load weights from file or filelike. + """Load weights from file or filelike. :param fh: file path, filehandle, filelike. :return: List of weights in numpy format. diff --git a/fedn/utils/plots.py b/fedn/utils/plots.py index 8ec81aca2..7901e2374 100644 --- a/fedn/utils/plots.py +++ b/fedn/utils/plots.py @@ -11,17 +11,14 @@ class Plot: - """ - - """ + """ """ def __init__(self, statestore): try: statestore_config = statestore.get_config() - statestore = MongoStateStore( - statestore_config['network_id'], statestore_config['mongo_config']) + statestore = MongoStateStore(statestore_config["network_id"], statestore_config["mongo_config"]) self.mdb = statestore.connect() - self.status = self.mdb['control.status'] + self.status = self.mdb["control.status"] self.round_time = self.mdb["control.round_time"] self.combiner_round_time = self.mdb["control.combiner_round_time"] self.psutil_usage = self.mdb["control.psutil_monitoring"] @@ -34,10 +31,10 @@ def __init__(self, statestore): # plot metrics from DB def _scalar_metrics(self, metrics): - """ Extract all scalar valued metrics from a MODEL_VALIDATON. 
""" + """Extract all scalar valued metrics from a MODEL_VALIDATON.""" - data = json.loads(metrics['data']) - data = json.loads(data['data']) + data = json.loads(metrics["data"]) + data = json.loads(data["data"]) valid_metrics = [] for metric, val in data.items(): @@ -55,18 +52,17 @@ def create_table_plot(self): :return: """ - metrics = self.status.find_one({'type': 'MODEL_VALIDATION'}) + metrics = self.status.find_one({"type": "MODEL_VALIDATION"}) if metrics is None: fig = go.Figure(data=[]) - fig.update_layout( - title_text='No data currently available for table mean metrics') + fig.update_layout(title_text="No data currently available for table mean metrics") table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return False valid_metrics = self._scalar_metrics(metrics) if valid_metrics == []: fig = go.Figure(data=[]) - fig.update_layout(title_text='No scalar metrics found') + fig.update_layout(title_text="No scalar metrics found") table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return False @@ -74,14 +70,12 @@ def create_table_plot(self): models = [] for metric in valid_metrics: validations = {} - for post in self.status.find({'type': 'MODEL_VALIDATION'}): - e = json.loads(post['data']) + for post in self.status.find({"type": "MODEL_VALIDATION"}): + e = json.loads(post["data"]) try: - validations[e['modelId']].append( - float(json.loads(e['data'])[metric])) + validations[e["modelId"]].append(float(json.loads(e["data"])[metric])) except KeyError: - validations[e['modelId']] = [ - float(json.loads(e['data'])[metric])] + validations[e["modelId"]] = [float(json.loads(e["data"])[metric])] vals = [] models = [] @@ -98,19 +92,21 @@ def create_table_plot(self): vals.reverse() values.append(vals) - fig = go.Figure(data=[go.Table( - header=dict(values=['Model ID'] + header_vals, - line_color='darkslategray', - fill_color='lightskyblue', - align='left'), - - cells=dict(values=values, # 2nd column - line_color='darkslategray', - fill_color='lightcyan', - align='left')) - ]) + fig = go.Figure( + data=[ + go.Table( + header=dict(values=["Model ID"] + header_vals, line_color="darkslategray", fill_color="lightskyblue", align="left"), + cells=dict( + values=values, # 2nd column + line_color="darkslategray", + fill_color="lightcyan", + align="left", + ), + ) + ] + ) - fig.update_layout(title_text='Summary: mean metrics') + fig.update_layout(title_text="Summary: mean metrics") table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return table @@ -123,63 +119,67 @@ def create_timeline_plot(self): x = [] y = [] base = [] - for p in self.status.find({'type': 'MODEL_UPDATE_REQUEST'}): - e = json.loads(p['data']) - cid = e['correlationId'] - for cc in self.status.find({'sender': p['sender'], 'type': 'MODEL_UPDATE'}): - da = json.loads(cc['data']) - if da['correlationId'] == cid: + for p in self.status.find({"type": "MODEL_UPDATE_REQUEST"}): + e = json.loads(p["data"]) + cid = e["correlationId"] + for cc in self.status.find({"sender": p["sender"], "type": "MODEL_UPDATE"}): + da = json.loads(cc["data"]) + if da["correlationId"] == cid: cp = cc - cd = json.loads(cp['data']) - tr = datetime.strptime(e['timestamp'], '%Y-%m-%d %H:%M:%S.%f') - tu = datetime.strptime(cd['timestamp'], '%Y-%m-%d %H:%M:%S.%f') + cd = json.loads(cp["data"]) + tr = datetime.strptime(e["timestamp"], "%Y-%m-%d %H:%M:%S.%f") + tu = datetime.strptime(cd["timestamp"], "%Y-%m-%d %H:%M:%S.%f") ts = tu - tr base.append(tr.timestamp()) x.append(ts.total_seconds() / 60.0) - y.append(p['sender']['name']) - - 
trace_data.append(go.Bar( - x=y, - y=x, - marker=dict(color='royalblue'), - name="Training", - )) + y.append(p["sender"]["name"]) + + trace_data.append( + go.Bar( + x=y, + y=x, + marker=dict(color="royalblue"), + name="Training", + ) + ) x = [] y = [] base = [] - for p in self.status.find({'type': 'MODEL_VALIDATION_REQUEST'}): - e = json.loads(p['data']) - cid = e['correlationId'] - for cc in self.status.find({'sender': p['sender'], 'type': 'MODEL_VALIDATION'}): - da = json.loads(cc['data']) - if da['correlationId'] == cid: + for p in self.status.find({"type": "MODEL_VALIDATION_REQUEST"}): + e = json.loads(p["data"]) + cid = e["correlationId"] + for cc in self.status.find({"sender": p["sender"], "type": "MODEL_VALIDATION"}): + da = json.loads(cc["data"]) + if da["correlationId"] == cid: cp = cc - cd = json.loads(cp['data']) - tr = datetime.strptime(e['timestamp'], '%Y-%m-%d %H:%M:%S.%f') - tu = datetime.strptime(cd['timestamp'], '%Y-%m-%d %H:%M:%S.%f') + cd = json.loads(cp["data"]) + tr = datetime.strptime(e["timestamp"], "%Y-%m-%d %H:%M:%S.%f") + tu = datetime.strptime(cd["timestamp"], "%Y-%m-%d %H:%M:%S.%f") ts = tu - tr base.append(tr.timestamp()) x.append(ts.total_seconds() / 60.0) - y.append(p['sender']['name']) - - trace_data.append(go.Bar( - x=y, - y=x, - marker=dict(color='lightskyblue'), - name="Validation", - )) + y.append(p["sender"]["name"]) + + trace_data.append( + go.Bar( + x=y, + y=x, + marker=dict(color="lightskyblue"), + name="Validation", + ) + ) layout = go.Layout( - barmode='stack', + barmode="stack", showlegend=True, ) fig = go.Figure(data=trace_data, layout=layout) - fig.update_xaxes(title_text='Alliance/client') - fig.update_yaxes(title_text='Time (Min)') - fig.update_layout(title_text='Alliance timeline') + fig.update_xaxes(title_text="Alliance/client") + fig.update_yaxes(title_text="Time (Min)") + fig.update_layout(title_text="Alliance timeline") timeline = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return timeline @@ -189,16 +189,15 @@ def create_client_training_distribution(self): :return: """ training = [] - for p in self.status.find({'type': 'MODEL_UPDATE'}): - e = json.loads(p['data']) - meta = json.loads(e['meta']) - training.append(meta['exec_training']) + for p in self.status.find({"type": "MODEL_UPDATE"}): + e = json.loads(p["data"]) + meta = json.loads(e["meta"]) + training.append(meta["exec_training"]) if not training: return False fig = go.Figure(data=go.Histogram(x=training)) - fig.update_layout( - title_text='Client model training time, mean: {}'.format(numpy.mean(training))) + fig.update_layout(title_text="Client model training time, mean: {}".format(numpy.mean(training))) histogram = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return histogram @@ -208,19 +207,18 @@ def create_client_histogram_plot(self): :return: """ training = [] - for p in self.status.find({'type': 'MODEL_UPDATE'}): - e = json.loads(p['data']) - meta = json.loads(e['meta']) - training.append(meta['exec_training']) + for p in self.status.find({"type": "MODEL_UPDATE"}): + e = json.loads(p["data"]) + meta = json.loads(e["meta"]) + training.append(meta["exec_training"]) fig = go.Figure() fig.update_layout( template="simple_white", xaxis=dict(title_text="Time (s)"), - yaxis=dict(title_text='Number of updates'), - title="Mean client training time: {}".format( - numpy.mean(training)), + yaxis=dict(title_text="Number of updates"), + title="Mean client training time: {}".format(numpy.mean(training)), # showlegend=True ) if not training: @@ -240,21 +238,16 @@ def 
create_client_plot(self): upload = [] download = [] training = [] - for p in self.status.find({'type': 'MODEL_UPDATE'}): - e = json.loads(p['data']) - meta = json.loads(e['meta']) - upload.append(meta['upload_model']) - download.append(meta['fetch_model']) - training.append(meta['exec_training']) - processing.append(meta['processing_time']) + for p in self.status.find({"type": "MODEL_UPDATE"}): + e = json.loads(p["data"]) + meta = json.loads(e["meta"]) + upload.append(meta["upload_model"]) + download.append(meta["fetch_model"]) + training.append(meta["exec_training"]) + processing.append(meta["processing_time"]) fig = go.Figure() - fig.update_layout( - template="simple_white", - title="Mean client processing time: {}".format( - numpy.mean(processing)), - showlegend=True - ) + fig.update_layout(template="simple_white", title="Mean client processing time: {}".format(numpy.mean(processing)), showlegend=True) if not processing: return False data = [numpy.mean(training), numpy.mean(upload), numpy.mean(download)] @@ -273,32 +266,25 @@ def create_combiner_plot(self): aggregation = [] model_load = [] combination = [] - for round in self.mdb['control.round'].find(): + for round in self.mdb["control.round"].find(): try: - for combiner in round['combiners']: + for combiner in round["combiners"]: data = combiner - stats = data['local_round']['1'] - ml = stats['aggregation_time']['time_model_load'] - ag = stats['aggregation_time']['time_model_aggregation'] - combination.append(stats['time_combination']) - waiting.append(stats['time_combination'] - ml - ag) + stats = data["local_round"]["1"] + ml = stats["aggregation_time"]["time_model_load"] + ag = stats["aggregation_time"]["time_model_aggregation"] + combination.append(stats["time_combination"]) + waiting.append(stats["time_combination"] - ml - ag) model_load.append(ml) aggregation.append(ag) except Exception: pass - labels = ['Waiting for client updates', - 'Aggregation', 'Loading model updates from disk'] - val = [numpy.mean(waiting), numpy.mean( - aggregation), numpy.mean(model_load)] + labels = ["Waiting for client updates", "Aggregation", "Loading model updates from disk"] + val = [numpy.mean(waiting), numpy.mean(aggregation), numpy.mean(model_load)] fig = go.Figure() - fig.update_layout( - template="simple_white", - title="Mean combiner round time: {}".format( - numpy.mean(combination)), - showlegend=True - ) + fig.update_layout(template="simple_white", title="Mean combiner round time: {}".format(numpy.mean(combination)), showlegend=True) if not combination: return False fig.add_trace(go.Pie(labels=labels, values=val)) @@ -310,7 +296,7 @@ def fetch_valid_metrics(self): :return: """ - metrics = self.status.find_one({'type': 'MODEL_VALIDATION'}) + metrics = self.status.find_one({"type": "MODEL_VALIDATION"}) valid_metrics = self._scalar_metrics(metrics) return valid_metrics @@ -320,34 +306,31 @@ def create_box_plot(self, metric): :param metric: :return: """ - metrics = self.status.find_one({'type': 'MODEL_VALIDATION'}) + metrics = self.status.find_one({"type": "MODEL_VALIDATION"}) if metrics is None: fig = go.Figure(data=[]) - fig.update_layout(title_text='No data currently available for metric distribution over ' - 'participants') + fig.update_layout(title_text="No data currently available for metric distribution over " "participants") box = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return box valid_metrics = self._scalar_metrics(metrics) if valid_metrics == []: fig = go.Figure(data=[]) - fig.update_layout(title_text='No scalar metrics 
found') + fig.update_layout(title_text="No scalar metrics found") box = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return box validations = {} - for post in self.status.find({'type': 'MODEL_VALIDATION'}): - e = json.loads(post['data']) + for post in self.status.find({"type": "MODEL_VALIDATION"}): + e = json.loads(post["data"]) try: - validations[e['modelId']].append( - float(json.loads(e['data'])[metric])) + validations[e["modelId"]].append(float(json.loads(e["data"])[metric])) except KeyError: - validations[e['modelId']] = [ - float(json.loads(e['data'])[metric])] + validations[e["modelId"]] = [float(json.loads(e["data"])[metric])] # Make sure validations are plotted in chronological order - model_trail = self.mdb.control.model.find_one({'key': 'model_trail'}) - model_trail_ids = model_trail['model'] + model_trail = self.mdb.control.model.find_one({"key": "model_trail"}) + model_trail_ids = model_trail["model"] validations_sorted = [] for model_id in model_trail_ids: try: @@ -364,24 +347,16 @@ def create_box_plot(self, metric): # x.append(j) y.append(numpy.mean([float(i) for i in acc])) if len(acc) >= 2: - box.add_trace(go.Box(y=acc, name=str(j), marker_color="royalblue", showlegend=False, - boxpoints=False)) + box.add_trace(go.Box(y=acc, name=str(j), marker_color="royalblue", showlegend=False, boxpoints=False)) else: - box.add_trace(go.Scatter( - x=[str(j)], y=[y[j]], showlegend=False)) + box.add_trace(go.Scatter(x=[str(j)], y=[y[j]], showlegend=False)) rounds = list(range(len(y))) - box.add_trace(go.Scatter( - x=rounds, - y=y, - name='Mean' - )) - - box.update_xaxes(title_text='Rounds') - box.update_yaxes( - tickvals=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) - box.update_layout(title_text='Metric distribution over clients: {}'.format(metric), - margin=dict(l=20, r=20, t=45, b=20)) + box.add_trace(go.Scatter(x=rounds, y=y, name="Mean")) + + box.update_xaxes(title_text="Rounds") + box.update_yaxes(tickvals=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) + box.update_layout(title_text="Metric distribution over clients: {}".format(metric), margin=dict(l=20, r=20, t=45, b=20)) box = json.dumps(box, cls=plotly.utils.PlotlyJSONEncoder) return box @@ -391,38 +366,27 @@ def create_round_plot(self): :return: """ trace_data = [] - metrics = self.round_time.find_one({'key': 'round_time'}) + metrics = self.round_time.find_one({"key": "round_time"}) if metrics is None: fig = go.Figure(data=[]) - fig.update_layout( - title_text='No data currently available for round time') + fig.update_layout(title_text="No data currently available for round time") return False - for post in self.round_time.find({'key': 'round_time'}): - rounds = post['round'] - traces_data = post['round_time'] + for post in self.round_time.find({"key": "round_time"}): + rounds = post["round"] + traces_data = post["round_time"] - trace_data.append(go.Scatter( - x=rounds, - y=traces_data, - mode='lines+markers', - name='Reducer' - )) + trace_data.append(go.Scatter(x=rounds, y=traces_data, mode="lines+markers", name="Reducer")) - for rec in self.combiner_round_time.find({'key': 'combiner_round_time'}): - c_traces_data = rec['round_time'] + for rec in self.combiner_round_time.find({"key": "combiner_round_time"}): + c_traces_data = rec["round_time"] - trace_data.append(go.Scatter( - x=rounds, - y=c_traces_data, - mode='lines+markers', - name='Combiner' - )) + trace_data.append(go.Scatter(x=rounds, y=c_traces_data, mode="lines+markers", name="Combiner")) fig = go.Figure(data=trace_data) - 
fig.update_xaxes(title_text='Round') - fig.update_yaxes(title_text='Time (s)') - fig.update_layout(title_text='Round time') + fig.update_xaxes(title_text="Round") + fig.update_yaxes(title_text="Time (s)") + fig.update_layout(title_text="Round time") round_t = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return round_t @@ -431,46 +395,38 @@ def create_cpu_plot(self): :return: """ - metrics = self.psutil_usage.find_one({'key': 'cpu_mem_usage'}) + metrics = self.psutil_usage.find_one({"key": "cpu_mem_usage"}) if metrics is None: fig = go.Figure(data=[]) - fig.update_layout( - title_text='No data currently available for MEM and CPU usage') + fig.update_layout(title_text="No data currently available for MEM and CPU usage") cpu = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return False - for post in self.psutil_usage.find({'key': 'cpu_mem_usage'}): - cpu = post['cpu'] - mem = post['mem'] - ps_time = post['time'] - round = post['round'] + for post in self.psutil_usage.find({"key": "cpu_mem_usage"}): + cpu = post["cpu"] + mem = post["mem"] + ps_time = post["time"] + round = post["round"] # Create figure with secondary y-axis fig = make_subplots(specs=[[{"secondary_y": True}]]) - fig.add_trace(go.Scatter( - x=ps_time, - y=cpu, - mode='lines+markers', - name='CPU (%)' - )) - - fig.add_trace(go.Scatter( - x=ps_time, - y=mem, - mode='lines+markers', - name='MEM (%)' - )) - - fig.add_trace(go.Scatter( - x=ps_time, - y=round, - mode='lines+markers', - name='Round', - ), secondary_y=True) - - fig.update_xaxes(title_text='Date Time') - fig.update_yaxes(title_text='Percentage (%)') + fig.add_trace(go.Scatter(x=ps_time, y=cpu, mode="lines+markers", name="CPU (%)")) + + fig.add_trace(go.Scatter(x=ps_time, y=mem, mode="lines+markers", name="MEM (%)")) + + fig.add_trace( + go.Scatter( + x=ps_time, + y=round, + mode="lines+markers", + name="Round", + ), + secondary_y=True, + ) + + fig.update_xaxes(title_text="Date Time") + fig.update_yaxes(title_text="Percentage (%)") fig.update_yaxes(title_text="Round", secondary_y=True) - fig.update_layout(title_text='CPU loads and memory usage') + fig.update_layout(title_text="CPU loads and memory usage") cpu = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return cpu
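One pattern recurs throughout plots.py above: every figure is assembled from plotly graph objects and handed to the UI as a JSON string via plotly.utils.PlotlyJSONEncoder, which knows how to serialize figure objects and numpy arrays. A minimal, self-contained sketch of that round trip (the data values are made up):

import json

import plotly
import plotly.graph_objs as go

# Build a small figure in the same style as create_round_plot above.
fig = go.Figure(data=[go.Scatter(x=[0, 1, 2], y=[1.0, 0.5, 0.25], mode="lines+markers", name="Reducer")])
fig.update_xaxes(title_text="Round")
fig.update_yaxes(title_text="Time (s)")
fig.update_layout(title_text="Round time")

# Serialize for the frontend, exactly as plots.py does.
payload = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
print(payload[:80])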