diff --git a/.ci/tests/examples/run.sh b/.ci/tests/examples/run.sh index e21ed2b2c..aa498de5f 100755 --- a/.ci/tests/examples/run.sh +++ b/.ci/tests/examples/run.sh @@ -11,7 +11,7 @@ helper="$2" python -m venv ".$example" source ".$example/bin/activate" -".$example/bin/pip" install ./fedn/ fire +".$example/bin/pip" install . fire >&2 echo "Start FEDn" pushd "examples/$example" diff --git a/.devcontainer/bin/init_venv.sh b/.devcontainer/bin/init_venv.sh index c76670e77..89546305d 100755 --- a/.devcontainer/bin/init_venv.sh +++ b/.devcontainer/bin/init_venv.sh @@ -10,8 +10,5 @@ python -m venv .venv sphinx==4.4.0 \ sphinx_press_theme==0.8.0 \ sphinx-autobuild==2021.3.14 \ - autopep8==1.5.7 \ - isort==5.10.1 \ - flake8==4.0.1 \ sphinx_rtd_theme==0.5.2 -.venv/bin/pip install -e fedn \ No newline at end of file +.venv/bin/pip install -e . \ No newline at end of file diff --git a/.devcontainer/devcontainer.json.tpl b/.devcontainer/devcontainer.json.tpl index cdf276df5..29348f392 100644 --- a/.devcontainer/devcontainer.json.tpl +++ b/.devcontainer/devcontainer.json.tpl @@ -5,23 +5,26 @@ "remoteUser": "default", // "workspaceFolder": "/fedn", // "workspaceMount": "source=/path/to/fedn,target=/fedn,type=bind,consistency=default", - "extensions": [ - "ms-azuretools.vscode-docker", - "ms-python.python", - "exiasr.hadolint", - "yzhang.markdown-all-in-one", - "ms-python.isort" - ], + "customizations": { + "vscode": { + "extensions": [ + "ms-azuretools.vscode-docker", + "ms-python.python", + "exiasr.hadolint", + "yzhang.markdown-all-in-one", + "charliermarsh.ruff" + ] + } + }, "mounts": [ - "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind,consistency=default", + "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind,consistency=default" ], "runArgs": [ "--net=host" ], "build": { "args": { - "BASE_IMG": "python:3.9" + "BASE_IMG": "python:3.11" } } -} - +} \ No newline at end of file diff --git a/.dockerignore b/.dockerignore index 0df29e059..24f906be5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -9,4 +9,5 @@ docs/* **/.mnist-pytorch **/*.npz **/data -**/*.tgz \ No newline at end of file +**/*.tgz +dist \ No newline at end of file diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 939be75e7..83d01d9ed 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -11,30 +11,9 @@ jobs: - name: init venv run: .devcontainer/bin/init_venv.sh - - - name: check Python imports - run: > - .venv/bin/isort . --check --diff - --skip .venv - --skip .mnist-keras - --skip .mnist-pytorch - --skip fedn_pb2.py - --skip fedn_pb2_grpc.py - - - name: check Python formatting - run: > - .venv/bin/autopep8 --recursive --diff - --exclude .venv - --exclude .mnist-keras - --exclude .mnist-pytorch - --exclude fedn_pb2.py - --exclude fedn_pb2_grpc.py - . - - name: run Python linter - run: > - .venv/bin/flake8 . 
-          --exclude ".venv,.mnist-keras,.mnist-pytorch,fedn_pb2.py,fedn_pb2_grpc.py"
+      - name: Ruff Linting
+        uses: chartboost/ruff-action@v1
 
       - name: check for floating imports
         run: >
diff --git a/.github/workflows/push-to-pypi.yaml b/.github/workflows/push-to-pypi.yaml
index 8cfa91cee..1b59835ad 100644
--- a/.github/workflows/push-to-pypi.yaml
+++ b/.github/workflows/push-to-pypi.yaml
@@ -16,15 +16,15 @@ jobs:
 
       - name: Install pypa/build
         run: python -m pip install build
-        working-directory: ./fedn
+        working-directory: ./
 
       - name: Build package
         run: python -m build
-        working-directory: ./fedn
+        working-directory: ./
 
       - name: Publish to Test PyPI
         uses: pypa/gh-action-pypi-publish@v1.8.14
         with:
           user: __token__
           password: ${{ secrets.PYPI_TOKEN }}
-          packages_dir: fedn/dist
+          packages_dir: ./dist
diff --git a/.gitignore b/.gitignore
index 0f4259be8..b595e49c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,8 +21,8 @@ __pycache__/
 # Distribution / packaging
 .Python
 build/
+package/
 develop-eggs/
-#dist/
 downloads/
 eggs/
 .eggs/
@@ -38,6 +38,7 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+dist
 
 # PyInstaller
 # Usually these files are written by a python script from a template
diff --git a/Dockerfile b/Dockerfile
index 5a7259859..49c91a3be 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,7 +8,7 @@ ARG GRPC_HEALTH_PROBE_VERSION=""
 ARG REQUIREMENTS=""
 
 # Add FEDn and default configs
-COPY fedn /app/fedn
+COPY . /app
 COPY config/settings-client.yaml.template /app/config/settings-client.yaml
 COPY config/settings-combiner.yaml.template /app/config/settings-combiner.yaml
 COPY config/settings-reducer.yaml.template /app/config/settings-reducer.yaml
@@ -27,6 +27,8 @@ RUN if [ ! -z "$GRPC_HEALTH_PROBE_VERSION" ]; then \
     echo "No grpc_health_probe version specified, skipping installation"; \
     fi
 
+# Setup working directory
+WORKDIR /app
 
 # Create FEDn app directory
 SHELL ["/bin/bash", "-c"]
@@ -39,7 +41,8 @@ RUN mkdir -p /app \
     # Install FEDn and requirements
     && python -m venv /venv \
     && /venv/bin/pip install --upgrade pip \
-    && /venv/bin/pip install --no-cache-dir -e /app/fedn \
+    && /venv/bin/pip install --no-cache-dir "setuptools>=65" \
+    && /venv/bin/pip install --no-cache-dir -e . \
    && if [[ ! -z "$REQUIREMENTS" ]]; then \
     /venv/bin/pip install --no-cache-dir -r /app/config/requirements.txt; \
     fi \
@@ -47,6 +50,4 @@ RUN mkdir -p /app \
     # Clean up
     && rm -r /app/config/requirements.txt
 
-# Setup working directory
-WORKDIR /app
 ENTRYPOINT [ "/venv/bin/fedn" ]
\ No newline at end of file
diff --git a/README.rst b/README.rst
index a5e73cdd5..2dab7813a 100644
--- a/README.rst
+++ b/README.rst
@@ -30,61 +30,66 @@ We develop the FEDn framework following these core design principles:
 Features
 =========
 
-Federated machine learning:
+Core FL framework (this repository):
 
-- Support for any ML framework (e.g. PyTorch, Tensforflow/Keras and Scikit-learn)
+- Tiered federated learning architecture enabling massive scalability and resilience.
+- Support for any ML framework (examples for PyTorch, TensorFlow/Keras and Scikit-learn)
 - Extendable via a plug-in architecture (aggregators, load balancers, object storage backends, databases etc.)
 - Built-in federated algorithms (FedAvg, FedAdam, FedYogi, FedAdaGrad, etc.)
-- CLI and Python API client for running FEDn networks and coordinating experiments.
+- CLI and Python API.
 - Implement clients in any language (Python, C++, Kotlin etc.)
 - No open ports needed client-side.
+- Flexible deployment of server-side components using Docker / docker compose.
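+
+As a minimal sketch of the Python API (based on the bundled examples; the host, port and file
+names below are placeholders for a local deployment):
+
+.. code-block:: python
+
+   from fedn import APIClient
+
+   # Connect to the controller of a local FEDn deployment (default port)
+   client = APIClient("127.0.0.1", 8092)
+
+   # Seed the federation with a compute package and an initial model
+   client.set_active_package("package.tgz", "numpyhelper")
+   client.set_active_model("seed.npz")
+
+   # Run a training session of 10 global rounds with the default aggregator (FedAvg)
+   session = client.start_session(helper="numpyhelper", id="experiment1", rounds=10)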
FEDn Studio - From development to FL in production:
 
-- Leverage Scaleout's free managed service for development and testing in real-world scenarios (SaaS).
-- Token-based authentication (JWT) and role-based access control (RBAC) for FL clients.
-- REST API and UI.
-- Data science dashboard for orchestrating experiments and visualizing results.
-- Admin dashboard for managing the FEDn network and users/clients.
-- View extensive logging and tracing information.
-- Collaborate with other data-scientists on the project specification in a shared workspace.
-- Cloud or on-premise deployment (cloud-native design, deploy to any Kubernetes cluster)
+- Secure deployment of server-side / control-plane on Kubernetes.
+- UI with dashboards for orchestrating experiments and visualizing results.
+- Team features - collaborate with other users in shared project workspaces.
+- Features for the trusted third party: Manage access to the FL network, FL clients and training progress.
+- REST API for handling experiments/jobs.
+- View and export logging and tracing information.
+- Public cloud, dedicated cloud and on-premise deployment options.
 
 Getting started
 ============================
 
-The best way to get started is to take the quickstart tutorial:
+Get started with FEDn in two steps:
 
-- `Quickstart `__
+1. Sign up for a `Free FEDn Studio account `__
+2. Take the `Quickstart tutorial `__
+
+FEDn Studio (SaaS) is free for academic use and personal development / small-scale testing and exploration. For users and teams requiring
+additional project resources, dedicated support or other hosting options, `explore our plans `__.
 
 Documentation
 =============
 
-More details about the architecture, deployment, and how to develop your own application and framework extensions (such as custom aggregators) are found in the documentation:
+More details about the architecture, deployment, and how to develop your own application and framework extensions are found in the documentation:
 
 - `Documentation `__
 
-Running your project in FEDn Studio (SaaS or on-premise)
-========================================================
-
-The FEDn Studio SaaS is free for development, testing and research (one project per user, backend compute resources sized for dev/test):
+FEDn Studio Deployment options
+==============================
 
-- `Register for a free account in FEDn Studio `__
-- `Take the tutorial to deploy your project on FEDn Studio `__
+Several hosting options are available to suit different project settings.
 
-Scaleout can also support users to scale up experiments and demonstrators on Studio, by granting custom resource quotas. Additonally, charts are available for self-managed deployment on-premise or in your cloud VPC (all major cloud providers). Contact the Scaleout team for more information.
+- `Public cloud (multi-tenant) `__: Managed multi-tenant deployment in public cloud.
+- Dedicated cloud (single-tenant): Managed, dedicated deployment in a cloud region of your choice (AWS, GCP, Azure, managed Kubernetes).
+- Self-managed: Set up a self-managed deployment in your VPC or on-premise Kubernetes cluster using Helm Chart and container images provided by Scaleout.
+Contact the Scaleout team for information.
 
 Support
 =================
 
-Community support in available in our `Discord
+Community support is available in our `Discord
 server `__.
 
-Options are available for `Enterprise support `__.
+Options are available for `Dedicated/custom support `__.
Making contributions ==================== @@ -114,4 +119,4 @@ License FEDn is licensed under Apache-2.0 (see `LICENSE `__ file for full information). -Use of FEDn Studio (SaaS) is subject to the `Terms of Use `__. +Use of FEDn Studio is subject to the `Terms of Use `__. diff --git a/config/combiner-settings.override.yaml b/config/combiner-settings.override.yaml index fc3f86e12..bca431b7a 100644 --- a/config/combiner-settings.override.yaml +++ b/config/combiner-settings.override.yaml @@ -5,5 +5,5 @@ version: '3.3' services: combiner: volumes: - - ${HOST_REPO_DIR:-.}/fedn:/app/fedn + - ${HOST_REPO_DIR:-.}:/app - ${HOST_REPO_DIR:-.}/config/settings-combiner.yaml:/app/config/settings-combiner.yaml diff --git a/config/reducer-settings.override.yaml b/config/reducer-settings.override.yaml index af5ee5126..18e499f73 100644 --- a/config/reducer-settings.override.yaml +++ b/config/reducer-settings.override.yaml @@ -5,5 +5,5 @@ version: '3.3' services: reducer: volumes: - - ${HOST_REPO_DIR:-.}/fedn:/app/fedn + - ${HOST_REPO_DIR:-.}:/app - ${HOST_REPO_DIR:-.}/config/settings-reducer.yaml:/app/config/settings-reducer.yaml diff --git a/docker-compose.yaml b/docker-compose.yaml index 3f2f3c2be..b6563ace6 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -78,7 +78,7 @@ services: - mongo entrypoint: [ "sh", "-c" ] command: - - "/venv/bin/pip install --no-cache-dir -e /app/fedn && /venv/bin/python fedn/fedn/network/api/server.py" + - "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/python fedn/network/api/server.py" ports: - 8092:8092 @@ -97,7 +97,7 @@ services: - ${HOST_REPO_DIR:-.}/fedn:/app/fedn entrypoint: [ "sh", "-c" ] command: - - "/venv/bin/pip install --no-cache-dir -e /app/fedn && /venv/bin/fedn run combiner --init config/settings-combiner.yaml" + - "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn combiner start --init config/settings-combiner.yaml" ports: - 12080:12080 healthcheck: @@ -127,7 +127,7 @@ services: - ${HOST_REPO_DIR:-.}/fedn:/app/fedn entrypoint: [ "sh", "-c" ] command: - - "/venv/bin/pip install --no-cache-dir -e /app/fedn && /venv/bin/fedn run client --init config/settings-client.yaml" + - "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn client start --init config/settings-client.yaml" deploy: replicas: 0 depends_on: diff --git a/docs/aggregators.rst b/docs/aggregators.rst index 756b858c7..7075b1287 100644 --- a/docs/aggregators.rst +++ b/docs/aggregators.rst @@ -3,7 +3,9 @@ Aggregators =========== -Aggregators handle combinations of model updates received by the combiner into a combiner-level global model. +Overview +--------- +Aggregators are responsible for combining client model updates into a combiner-level global model. During a training session, the combiners will instantiate an Aggregator and use it to process the incoming model updates from clients. .. image:: img/aggregators.png @@ -11,35 +13,104 @@ During a training session, the combiners will instantiate an Aggregator and use :width: 100% :align: center -The above figure illustrates the overall flow. When a client completes a model update, the model parameters are streamed to the combiner, and a model update message is sent. The model parameters are written to file on disk, and the model update message is passed to a callback function, on_model_update. The callback function validates the model update, and if successful, puts the update message on an aggregation queue. The model parameters are written to disk at a configurable storage location at the combiner. 
This is done to avoid exhausting RAM memory at the combiner. As multiple clients send updates, the aggregation queue builds up, and when a certain criteria is met, another method, combine_models, starts processing the queue, aggregating models according to the specifics of the scheme (FedAvg, FedAdam, etc).
+The figure above illustrates the overall workflow. When a client completes a model update, the model parameters are streamed to the combiner,
+and a model update message is sent. The parameters are saved to a file on disk, and the update message is passed to a callback function, ``on_model_update``.
+This function validates the model update and, if successful, places the update message in an aggregation queue.
+The model parameters are saved to disk at a configurable storage location within the combiner to prevent exhausting RAM.
+As multiple clients submit updates, the aggregation queue accumulates. Once specific criteria are met, another method, ``combine_models``,
+begins processing the queue, aggregating models according to the specifics of the scheme (e.g., FedAvg, FedAdam).
 
-The user can configure several parameters that guide general behavior of the aggregation flow:
+
+Using built-in Aggregators
+--------------------------
+
+FEDn supports the following aggregation algorithms:
+
+- FedAvg (default)
+- FedAdam (FedOpt)
+- FedYogi (FedOpt)
+- FedAdaGrad (FedOpt)
+
+The implementation of the FedOpt family of methods follows this paper: https://arxiv.org/pdf/2003.00295.pdf
+
+Training sessions can be configured to use a given aggregator. For example, to use FedAdam:
+
+.. code:: python
+
+    session_config = {
+        "helper": "numpyhelper",
+        "id": "experiment_fedadam",
+        "aggregator": "fedopt",
+        "aggregator_kwargs": {
+            "serveropt": "adam",
+            "learning_rate": 1e-2,
+            "beta1": 0.9,
+            "beta2": 0.99,
+            "tau": 1e-4
+        },
+        "rounds": 10
+    }
+
+    result_fedadam = client.start_session(**session_config)
+
+.. note::
+
+    The FedOpt family of methods uses server-side momentum. FEDn resets the aggregator for each new session.
+    This means that the history will also be reset, i.e. the momentum terms will be forgotten.
+    When using FedAdam, FedYogi and FedAdaGrad, the user needs to choose the number
+    of rounds in the session with both convergence and utility in mind.
+
+.. note::
+
+    The parameters in ``aggregator_kwargs`` are hyperparameters for the FedOpt family of aggregators. The
+    data types for these parameters (str, float) are validated by the aggregator and an error
+    will be issued if parameter values of an incompatible type are passed. All hyperparameters are
+    given above for completeness. It is primarily the ``learning_rate`` that will require tuning.
+
+Several additional parameters that guide the general behavior of the aggregation flow can be configured:
 
 - Round timeout: The maximal time the combiner waits before processing the update queue.
 - Buffer size: The maximal allowed length of the queue before processing it.
 - Whether to retain or delete model update files after they have been processed (default is to delete them)
 
+Extending FEDn with new Aggregators
+-----------------------------------
 A developer can extend FEDn with his/her own Aggregator(s) by implementing the interface specified in
-:py:mod:`fedn.network.combiner.aggregators.aggregatorbase.AggregatorBase`. The developer implements two following methods:
+:py:mod:`fedn.network.combiner.aggregators.aggregatorbase.AggregatorBase`. 
This involves implementing the two methods:
+
+- ``on_model_update`` (perform model update validation before the update is placed on the queue; optional)
+- ``combine_models`` (process the queue and aggregate updates)
+
+**on_model_update**
+
+The ``on_model_update`` callback receives the model update messages from clients (including all metadata) and can be used to perform validation and
+potential transformation of the model update before it is placed on the aggregation queue (see image above).
+The base class implements a default callback that checks that all metadata assumed by the aggregation algorithms FedAvg and FedOpt is available. The callback could also be used to implement custom pre-processing and additional checks, including strategies
+to filter out updates that are suspected to be corrupted or malicious.
 
-- ``on_model_update`` (optional)
-- ``combine_models``
+**combine_models**
 
-on_model_update
-----------------
+When a certain criterion is met, e.g. all clients have sent updates, or the round has timed out, the ``combine_models`` method
+processes the model update queue, producing an aggregated model. This is the main extension point where the
+numerical details of the aggregation scheme are implemented. The best way to understand how to implement this method is to study the built-in aggregation algorithms:
 
-The on_model_update has access to the complete model update including the metadata passed on by the clients (as specified in the training entrypoint, see compute package). The base class implements a default callback that checks that all metadata assumed by the aggregation algorithms FedAvg and FedAdam is present in the metadata. However, the callback could also be used to implement custom preprocessing and additional checks including strategies to filter out updates that are suspected to be corrupted or malicious.
+- :py:mod:`fedn.network.combiner.aggregators.fedavg` (weighted average of parameters)
+- :py:mod:`fedn.network.combiner.aggregators.fedopt` (compute pseudo-gradients and apply a server-side optimizer)
 
-combine_models
---------------
+To add an aggregator plugin ``myaggregator``, the developer implements the interface and places a file called ‘myaggregator.py’ in the folder ‘fedn.network.combiner.aggregators’.
+This extension can then be used as follows:
 
-This method is responsible for processing the model update queue and in doing so produce an aggregated model. This is the main extension point where the numerical detail of the aggregation scheme is implemented. The best way to understand how to implement this methods is to study the already implemented algorithms:
+.. code:: python
 
-- :py:mod:`fedn.network.combiner.aggregators.fedavg`
-- :py:mod:`fedn.network.combiner.aggregators.fedopt`
+    session_config = {
+        "helper": "numpyhelper",
+        "id": "experiment_myaggregator",
+        "aggregator": "myaggregator",
+        "rounds": 10
+    }
 
-To add an aggregator plugin “myaggregator”, the developer implements the interface and places a file called ‘myaggregator.py’ in the folder ‘fedn.network.combiner.aggregators’. 
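+    # Assumes the plugin module described above is in place, i.e. a file
+    # fedn/network/combiner/aggregators/myaggregator.py implementing the interface;
+    # the "aggregator" value must match that module name.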
+    result_myaggregator = client.start_session(**session_config)
diff --git a/docs/conf.py b/docs/conf.py
index acd52819f..913c35d9c 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -12,7 +12,7 @@ author = 'Scaleout Systems AB'
 
 # The full version, including alpha/beta/rc tags
-release = '0.9.1'
+release = '0.9.2'
 
 # Add any Sphinx extension module names here, as strings
 extensions = [
diff --git a/docs/faq.rst b/docs/faq.rst
index f817a0dee..223aa2e49 100644
--- a/docs/faq.rst
+++ b/docs/faq.rst
@@ -19,6 +19,17 @@ However, during development of a new model it will be necessary to reinitialize.
 
 2. Restart the clients.
 
+Q: Can I skip fetching the remote package and instead use a local folder when developing the compute package?
+------------------------------------------------------------------------------------------------------------
+
+Yes, to facilitate interactive development of the compute package you can start a client that uses a local folder 'client' in your current working directory by:
+
+.. code-block:: bash
+
+   fedn client start --remote=False -in client.yaml
+
+
+Note that in production federations this option should in most cases be disallowed.
 
 Q: How can other aggregation algorithms can be defined?
 -------------------------------------------------------
@@ -45,7 +56,7 @@ Yes! You can toggle which message streams a client subscibes to when starting the
 
 .. code-block:: bash
 
-   fedn run client --trainer=False -in client.yaml
+   fedn client start --trainer=False -in client.yaml
 
 
 Q: How do you approach the question of output privacy?
diff --git a/docs/quickstart.rst b/docs/quickstart.rst
index aa273ba98..fa5b83eaa 100644
--- a/docs/quickstart.rst
+++ b/docs/quickstart.rst
@@ -34,8 +34,8 @@ Clone the FEDn repository and install the package:
 .. code-block:: bash
 
    git clone https://github.com/scaleoutsystems/fedn.git
-   cd fedn/fedn
-   pip install -e .
+   cd fedn
+   pip install .
 
 It is recommended to use a virtual environment when installing FEDn.
@@ -114,15 +114,15 @@ For example, to split the data in 10 parts and start a client using the 8th part
 
       export FEDN_PACKAGE_EXTRACT_DIR=package
      export FEDN_NUM_DATA_SPLITS=10
-      export FEDN_DATA_PATH=package/data/clients/8/mnist.pt
-      fedn run client -in client.yaml --secure=True --force-ssl
+      export FEDN_DATA_PATH=./data/clients/8/mnist.pt
+      fedn client start -in client.yaml --secure=True --force-ssl
 
   .. code-tab:: bash
      :caption: Windows (Powershell)
 
      $env:FEDN_PACKAGE_EXTRACT_DIR="package"
      $env:FEDN_NUM_DATA_SPLITS=10
-     $env:FEDN_DATA_PATH="package/data/clients/8/mnist.pt"
+     $env:FEDN_DATA_PATH="./data/clients/8/mnist.pt"
-     fedn run client -in client.yaml --secure=True --force-ssl
+     fedn client start -in client.yaml --secure=True --force-ssl
 
diff --git a/examples/async-clients/client/entrypoint.py b/examples/async-clients/client/entrypoint.py
index 4ddddd956..220b5299b 100644
--- a/examples/async-clients/client/entrypoint.py
+++ b/examples/async-clients/client/entrypoint.py
@@ -8,7 +8,7 @@
 
 from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics
 
-HELPER_MODULE = 'numpyhelper'
+HELPER_MODULE = "numpyhelper"
 ARRAY_SIZE = 10000
 
@@ -22,7 +22,7 @@ def compile_model(max_iter=1):
 
 def save_parameters(model, out_path):
-    """ Save model to disk.
+    """Save model to disk.
 
     :param model: The model to save.
     :type model: torch.nn.Module
@@ -36,7 +36,7 @@
 
 def load_parameters(model_path):
-    """ Load model from disk.
+    """Load model from disk.
 
     param model_path: The path to load from.
:type model_path: str @@ -49,8 +49,8 @@ def load_parameters(model_path): return parameters -def init_seed(out_path='seed.npz'): - """ Initialize seed model. +def init_seed(out_path="seed.npz"): + """Initialize seed model. :param out_path: The path to save the seed model to. :type out_path: str @@ -61,7 +61,7 @@ def init_seed(out_path='seed.npz'): def make_data(n_min=50, n_max=100): - """ Generate / simulate a random number n data points. + """Generate / simulate a random number n data points. n will fall in the interval (n_min, n_max) @@ -78,14 +78,12 @@ def make_data(n_min=50, n_max=100): def train(in_model_path, out_model_path): - """ Train model. - - """ + """Train model.""" # Load model parameters = load_parameters(in_model_path) model = compile_model() - n = len(parameters)//2 + n = len(parameters) // 2 model.coefs_ = parameters[:n] model.intercepts_ = parameters[n:] @@ -97,7 +95,7 @@ def train(in_model_path, out_model_path): # Metadata needed for aggregation server side metadata = { - 'num_examples': len(X_train), + "num_examples": len(X_train), } # Save JSON metadata file @@ -108,7 +106,7 @@ def train(in_model_path, out_model_path): def validate(in_model_path, out_json_path): - """ Validate model. + """Validate model. :param in_model_path: The path to the input model. :type in_model_path: str @@ -119,7 +117,7 @@ def validate(in_model_path, out_json_path): """ parameters = load_parameters(in_model_path) model = compile_model() - n = len(parameters)//2 + n = len(parameters) // 2 model.coefs_ = parameters[:n] model.intercepts_ = parameters[n:] @@ -134,9 +132,5 @@ def validate(in_model_path, out_json_path): save_metrics(report, out_json_path) -if __name__ == '__main__': - fire.Fire({ - 'init_seed': init_seed, - 'train': train, - 'validate': validate - }) +if __name__ == "__main__": + fire.Fire({"init_seed": init_seed, "train": train, "validate": validate}) diff --git a/examples/async-clients/init_fedn.py b/examples/async-clients/init_fedn.py index 2aa298602..8677c472d 100644 --- a/examples/async-clients/init_fedn.py +++ b/examples/async-clients/init_fedn.py @@ -1,8 +1,8 @@ from fedn import APIClient -DISCOVER_HOST = '127.0.0.1' +DISCOVER_HOST = "127.0.0.1" DISCOVER_PORT = 8092 client = APIClient(DISCOVER_HOST, DISCOVER_PORT) -client.set_active_package('package.tgz', 'numpyhelper') -client.set_active_model('seed.npz') +client.set_active_package("package.tgz", "numpyhelper") +client.set_active_model("seed.npz") diff --git a/examples/async-clients/run_clients.py b/examples/async-clients/run_clients.py index 5b56c53c9..82da30ad9 100644 --- a/examples/async-clients/run_clients.py +++ b/examples/async-clients/run_clients.py @@ -26,24 +26,39 @@ # Use with a local deployment settings = { - 'DISCOVER_HOST': '127.0.0.1', - 'DISCOVER_PORT': 8092, - 'TOKEN': None, - 'N_CLIENTS': 10, - 'N_CYCLES': 100, - 'CLIENTS_MAX_DELAY': 10, - 'CLIENTS_ONLINE_FOR_SECONDS': 120 + "DISCOVER_HOST": "127.0.0.1", + "DISCOVER_PORT": 8092, + "TOKEN": None, + "N_CLIENTS": 10, + "N_CYCLES": 100, + "CLIENTS_MAX_DELAY": 10, + "CLIENTS_ONLINE_FOR_SECONDS": 120, } -client_config = {'discover_host': settings['DISCOVER_HOST'], 'discover_port': settings['DISCOVER_PORT'], 'token': settings['TOKEN'], 'name': 'testclient', - 'client_id': 1, 'remote_compute_context': True, 'force_ssl': False, 'dry_run': False, 'secure': False, - 'preshared_cert': False, 'verify': False, 'preferred_combiner': False, - 'validator': True, 'trainer': True, 'init': None, 'logfile': 'test.log', 'heartbeat_interval': 2, - 
'reconnect_after_missed_heartbeat': 30} +client_config = { + "discover_host": settings["DISCOVER_HOST"], + "discover_port": settings["DISCOVER_PORT"], + "token": settings["TOKEN"], + "name": "testclient", + "client_id": 1, + "remote_compute_context": True, + "force_ssl": False, + "dry_run": False, + "secure": False, + "preshared_cert": False, + "verify": False, + "preferred_combiner": False, + "validator": True, + "trainer": True, + "init": None, + "logfile": "test.log", + "heartbeat_interval": 2, + "reconnect_after_missed_heartbeat": 30, +} -def run_client(online_for=120, name='client'): - """ Simulates a client that starts and stops +def run_client(online_for=120, name="client"): + """Simulates a client that starts and stops at random intervals. The client will start after a radom time 'mean_delay', @@ -55,23 +70,28 @@ def run_client(online_for=120, name='client'): """ conf = copy.deepcopy(client_config) - conf['name'] = name + conf["name"] = name - for i in range(settings['N_CYCLES']): + for i in range(settings["N_CYCLES"]): # Sample a delay until the client starts - t_start = np.random.randint(0, settings['CLIENTS_MAX_DELAY']) + t_start = np.random.randint(0, settings["CLIENTS_MAX_DELAY"]) time.sleep(t_start) fl_client = Client(conf) time.sleep(online_for) fl_client.disconnect() -if __name__ == '__main__': - +if __name__ == "__main__": # We start N_CLIENTS independent client processes processes = [] - for i in range(settings['N_CLIENTS']): - p = Process(target=run_client, args=(settings['CLIENTS_ONLINE_FOR_SECONDS'], 'client{}'.format(i),)) + for i in range(settings["N_CLIENTS"]): + p = Process( + target=run_client, + args=( + settings["CLIENTS_ONLINE_FOR_SECONDS"], + "client{}".format(i), + ), + ) processes.append(p) p.start() diff --git a/examples/async-clients/run_experiment.py b/examples/async-clients/run_experiment.py index d8d12dca2..f6d5e1f6d 100644 --- a/examples/async-clients/run_experiment.py +++ b/examples/async-clients/run_experiment.py @@ -3,16 +3,14 @@ from fedn import APIClient -DISCOVER_HOST = '127.0.0.1' +DISCOVER_HOST = "127.0.0.1" DISCOVER_PORT = 8092 client = APIClient(DISCOVER_HOST, DISCOVER_PORT) -if __name__ == '__main__': - +if __name__ == "__main__": # Run six sessions, each with 100 rounds. num_sessions = 6 for s in range(num_sessions): - session_config = { "helper": "numpyhelper", "id": str(uuid.uuid4()), @@ -23,12 +21,12 @@ } session = client.start_session(**session_config) - if session['success'] is False: - print(session['message']) + if session["success"] is False: + print(session["message"]) exit(0) print("Started session: {}".format(session)) # Wait for session to finish - while not client.session_is_finished(session_config['id']): + while not client.session_is_finished(session_config["id"]): time.sleep(2) diff --git a/examples/flower-client/README.rst b/examples/flower-client/README.rst index 9cfa617dc..fff8e20b3 100644 --- a/examples/flower-client/README.rst +++ b/examples/flower-client/README.rst @@ -62,7 +62,7 @@ On your local machine / client, start the FEDn client: .. 
code-block:: - fedn run client -in client.yaml --secure=True --force-ssl + fedn client start -in client.yaml --secure=True --force-ssl Or, if you prefer to use Docker (this might take a long time): diff --git a/examples/flower-client/client/entrypoint.py b/examples/flower-client/client/entrypoint.py index e80e7d4b5..1a9a8b8cf 100755 --- a/examples/flower-client/client/entrypoint.py +++ b/examples/flower-client/client/entrypoint.py @@ -56,9 +56,7 @@ def train(in_model_path, out_model_path): parameters_np = helper.load(in_model_path) # Train on flower client - params, num_examples = flwr_adapter.train( - parameters=parameters_np, partition_id=_get_node_id(), config={} - ) + params, num_examples = flwr_adapter.train(parameters=parameters_np, partition_id=_get_node_id(), config={}) # Metadata needed for aggregation server side metadata = { diff --git a/examples/flower-client/client/flwr_client.py b/examples/flower-client/client/flwr_client.py index 297df3ca7..843e9d31d 100644 --- a/examples/flower-client/client/flwr_client.py +++ b/examples/flower-client/client/flwr_client.py @@ -3,8 +3,7 @@ """ from flwr.client import ClientApp, NumPyClient -from flwr_task import (DEVICE, Net, get_weights, load_data, set_weights, test, - train) +from flwr_task import DEVICE, Net, get_weights, load_data, set_weights, test, train # Define FlowerClient and client_fn @@ -12,9 +11,7 @@ class FlowerClient(NumPyClient): def __init__(self, cid) -> None: super().__init__() self.net = Net().to(DEVICE) - self.trainloader, self.testloader = load_data( - partition_id=int(cid), num_clients=10 - ) + self.trainloader, self.testloader = load_data(partition_id=int(cid), num_clients=10) def get_parameters(self, config): return [val.cpu().numpy() for _, val in self.net.state_dict().items()] diff --git a/examples/flower-client/client/flwr_task.py b/examples/flower-client/client/flwr_task.py index dea53d9cc..6cd908177 100644 --- a/examples/flower-client/client/flwr_task.py +++ b/examples/flower-client/client/flwr_task.py @@ -5,9 +5,9 @@ from collections import OrderedDict import torch -import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as F # noqa: N812 from flwr_datasets import FederatedDataset +from torch import nn from torch.utils.data import DataLoader from torchvision.transforms import Compose, Normalize, ToTensor from tqdm import tqdm @@ -42,9 +42,7 @@ def load_data(partition_id, num_clients): partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) - pytorch_transforms = Compose( - [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - ) + pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) def apply_transforms(batch): """Apply transforms to the partition from FederatedDataset.""" diff --git a/examples/flower-client/client/python_env.yaml b/examples/flower-client/client/python_env.yaml index 4c5fc4668..984a2e96d 100644 --- a/examples/flower-client/client/python_env.yaml +++ b/examples/flower-client/client/python_env.yaml @@ -8,4 +8,4 @@ dependencies: - torchvision==0.17.1 - fire==0.3.1 - fedn[flower]==0.9.0 - - flwr-datasets[vision]==0.1.0 + - flwr-datasets[vision]==0.1.0 \ No newline at end of file diff --git a/examples/flower-client/init_fedn.py b/examples/flower-client/init_fedn.py index 23078fcd9..81864293b 100644 --- a/examples/flower-client/init_fedn.py +++ b/examples/flower-client/init_fedn.py @@ -1,8 +1,8 @@ from fedn import APIClient 
-DISCOVER_HOST = '127.0.0.1' +DISCOVER_HOST = "127.0.0.1" DISCOVER_PORT = 8092 client = APIClient(DISCOVER_HOST, DISCOVER_PORT) -client.set_package('package.tgz', 'numpyhelper') -client.set_initial_model('seed.npz') +client.set_package("package.tgz", "numpyhelper") +client.set_initial_model("seed.npz") diff --git a/examples/mnist-keras/client/entrypoint.py b/examples/mnist-keras/client/entrypoint.py index 15c0693f2..5420a78bb 100755 --- a/examples/mnist-keras/client/entrypoint.py +++ b/examples/mnist-keras/client/entrypoint.py @@ -7,7 +7,7 @@ from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics -HELPER_MODULE = 'numpyhelper' +HELPER_MODULE = "numpyhelper" helper = get_helper(HELPER_MODULE) NUM_CLASSES = 10 @@ -17,13 +17,13 @@ def _get_data_path(): - data_path = os.environ.get('FEDN_DATA_PATH', abs_path + '/data/clients/1/mnist.npz') + data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.npz") return data_path def compile_model(img_rows=28, img_cols=28): - """ Compile the TF model. + """Compile the TF model. param: img_rows: The number of rows in the image type: img_rows: int @@ -38,13 +38,11 @@ def compile_model(img_rows=28, img_cols=28): # Define model model = tf.keras.models.Sequential() model.add(tf.keras.layers.Flatten(input_shape=input_shape)) - model.add(tf.keras.layers.Dense(64, activation='relu')) + model.add(tf.keras.layers.Dense(64, activation="relu")) model.add(tf.keras.layers.Dropout(0.5)) - model.add(tf.keras.layers.Dense(32, activation='relu')) - model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')) - model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=tf.keras.optimizers.Adam(), - metrics=['accuracy']) + model.add(tf.keras.layers.Dense(32, activation="relu")) + model.add(tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")) + model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=tf.keras.optimizers.Adam(), metrics=["accuracy"]) return model @@ -57,14 +55,14 @@ def load_data(data_path, is_train=True): data = np.load(data_path) if is_train: - X = data['x_train'] - y = data['y_train'] + X = data["x_train"] + y = data["y_train"] else: - X = data['x_test'] - y = data['y_test'] + X = data["x_test"] + y = data["y_test"] # Normalize - X = X.astype('float32') + X = X.astype("float32") X = np.expand_dims(X, -1) X = X / 255 y = tf.keras.utils.to_categorical(y, NUM_CLASSES) @@ -72,8 +70,8 @@ def load_data(data_path, is_train=True): return X, y -def init_seed(out_path='../seed.npz'): - """ Initialize seed model and save it to file. +def init_seed(out_path="../seed.npz"): + """Initialize seed model and save it to file. :param out_path: The path to save the seed model to. :type out_path: str @@ -83,7 +81,7 @@ def init_seed(out_path='../seed.npz'): def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1): - """ Complete a model update. + """Complete a model update. 
Load model paramters from in_model_path (managed by the FEDn client), perform a model update, and write updated paramters @@ -114,9 +112,9 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1 # Metadata needed for aggregation server side metadata = { # num_examples are mandatory - 'num_examples': len(x_train), - 'batch_size': batch_size, - 'epochs': epochs, + "num_examples": len(x_train), + "batch_size": batch_size, + "epochs": epochs, } # Save JSON metadata file (mandatory) @@ -128,7 +126,7 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1 def validate(in_model_path, out_json_path, data_path=None): - """ Validate model. + """Validate model. :param in_model_path: The path to the input model. :type in_model_path: str @@ -182,14 +180,16 @@ def predict(in_model_path, out_json_path, data_path=None): # Save JSON with open(out_json_path, "w") as fh: - fh.write(json.dumps({'predictions': y_pred.tolist()})) - - -if __name__ == '__main__': - fire.Fire({ - 'init_seed': init_seed, - 'train': train, - 'validate': validate, - 'predict': predict, - '_get_data_path': _get_data_path, # for testing - }) + fh.write(json.dumps({"predictions": y_pred.tolist()})) + + +if __name__ == "__main__": + fire.Fire( + { + "init_seed": init_seed, + "train": train, + "validate": validate, + "predict": predict, + "_get_data_path": _get_data_path, # for testing + } + ) diff --git a/examples/mnist-keras/client/get_data.py b/examples/mnist-keras/client/get_data.py index 28a12bd20..ed123a4a3 100755 --- a/examples/mnist-keras/client/get_data.py +++ b/examples/mnist-keras/client/get_data.py @@ -7,14 +7,14 @@ def splitset(dataset, parts): n = dataset.shape[0] - local_n = floor(n/parts) + local_n = floor(n / parts) result = [] for i in range(parts): - result.append(dataset[i*local_n: (i+1)*local_n]) + result.append(dataset[i * local_n : (i + 1) * local_n]) return np.array(result) -def split(dataset='data/mnist.npz', outdir='data', n_splits=2): +def split(dataset="data/mnist.npz", outdir="data", n_splits=2): # Load and convert to dict package = np.load(dataset) data = {} @@ -22,32 +22,27 @@ def split(dataset='data/mnist.npz', outdir='data', n_splits=2): data[key] = splitset(val, n_splits) # Make dir if necessary - if not os.path.exists(f'{outdir}/clients'): - os.mkdir(f'{outdir}/clients') + if not os.path.exists(f"{outdir}/clients"): + os.mkdir(f"{outdir}/clients") # Make splits for i in range(n_splits): - subdir = f'{outdir}/clients/{str(i+1)}' + subdir = f"{outdir}/clients/{str(i+1)}" if not os.path.exists(subdir): os.mkdir(subdir) - np.savez(f'{subdir}/mnist.npz', - x_train=data['x_train'][i], - y_train=data['y_train'][i], - x_test=data['x_test'][i], - y_test=data['y_test'][i]) + np.savez(f"{subdir}/mnist.npz", x_train=data["x_train"][i], y_train=data["y_train"][i], x_test=data["x_test"][i], y_test=data["y_test"][i]) -def get_data(out_dir='data'): +def get_data(out_dir="data"): # Make dir if necessary if not os.path.exists(out_dir): os.mkdir(out_dir) # Download data (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() - np.savez(f'{out_dir}/mnist.npz', x_train=x_train, - y_train=y_train, x_test=x_test, y_test=y_test) + np.savez(f"{out_dir}/mnist.npz", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test) -if __name__ == '__main__': +if __name__ == "__main__": get_data() split() diff --git a/examples/mnist-pytorch/README.rst b/examples/mnist-pytorch/README.rst index 4fa022bfd..71e5de2d1 100644 --- a/examples/mnist-pytorch/README.rst 
+++ b/examples/mnist-pytorch/README.rst @@ -72,8 +72,8 @@ For example, to split the data in 10 parts and start a client using the 8th part export FEDN_PACKAGE_EXTRACT_DIR=package export FEDN_NUM_DATA_SPLITS=10 - export FEDN_DATA_PATH=package/data/clients/8/mnist.pt - fedn run client -in client.yaml --secure=True --force-ssl + export FEDN_DATA_PATH=./data/clients/8/mnist.pt + fedn client start -in client.yaml --secure=True --force-ssl The default is to split the data into 2 partitions and use the first partition. diff --git a/examples/mnist-pytorch/client/data.py b/examples/mnist-pytorch/client/data.py index d67274548..b921f3132 100644 --- a/examples/mnist-pytorch/client/data.py +++ b/examples/mnist-pytorch/client/data.py @@ -8,22 +8,20 @@ abs_path = os.path.abspath(dir_path) -def get_data(out_dir='data'): +def get_data(out_dir="data"): # Make dir if necessary if not os.path.exists(out_dir): os.mkdir(out_dir) # Only download if not already downloaded - if not os.path.exists(f'{out_dir}/train'): - torchvision.datasets.MNIST( - root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True, download=True) - if not os.path.exists(f'{out_dir}/test'): - torchvision.datasets.MNIST( - root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False, download=True) + if not os.path.exists(f"{out_dir}/train"): + torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True, download=True) + if not os.path.exists(f"{out_dir}/test"): + torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False, download=True) def load_data(data_path, is_train=True): - """ Load data from disk. + """Load data from disk. :param data_path: Path to data file. :type data_path: str @@ -33,16 +31,16 @@ def load_data(data_path, is_train=True): :rtype: tuple """ if data_path is None: - data_path = os.environ.get("FEDN_DATA_PATH", abs_path+'/data/clients/1/mnist.pt') + data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt") data = torch.load(data_path) if is_train: - X = data['x_train'] - y = data['y_train'] + X = data["x_train"] + y = data["y_train"] else: - X = data['x_test'] - y = data['y_test'] + X = data["x_test"] + y = data["y_test"] # Normalize X = X / 255 @@ -52,49 +50,48 @@ def load_data(data_path, is_train=True): def splitset(dataset, parts): n = dataset.shape[0] - local_n = floor(n/parts) + local_n = floor(n / parts) result = [] for i in range(parts): - result.append(dataset[i*local_n: (i+1)*local_n]) + result.append(dataset[i * local_n : (i + 1) * local_n]) return result -def split(out_dir='data'): - +def split(out_dir="data"): n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2)) # Make dir - if not os.path.exists(f'{out_dir}/clients'): - os.mkdir(f'{out_dir}/clients') + if not os.path.exists(f"{out_dir}/clients"): + os.mkdir(f"{out_dir}/clients") # Load and convert to dict - train_data = torchvision.datasets.MNIST( - root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True) - test_data = torchvision.datasets.MNIST( - root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False) + train_data = torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True) + test_data = torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False) data = { - 'x_train': splitset(train_data.data, n_splits), - 'y_train': splitset(train_data.targets, n_splits), 
-        'x_test': splitset(test_data.data, n_splits),
-        'y_test': splitset(test_data.targets, n_splits),
+        "x_train": splitset(train_data.data, n_splits),
+        "y_train": splitset(train_data.targets, n_splits),
+        "x_test": splitset(test_data.data, n_splits),
+        "y_test": splitset(test_data.targets, n_splits),
     }
 
     # Make splits
     for i in range(n_splits):
-        subdir = f'{out_dir}/clients/{str(i+1)}'
+        subdir = f"{out_dir}/clients/{str(i+1)}"
         if not os.path.exists(subdir):
             os.mkdir(subdir)
-        torch.save({
-            'x_train': data['x_train'][i],
-            'y_train': data['y_train'][i],
-            'x_test': data['x_test'][i],
-            'y_test': data['y_test'][i],
-        },
-            f'{subdir}/mnist.pt')
-
-
-if __name__ == '__main__':
+        torch.save(
+            {
+                "x_train": data["x_train"][i],
+                "y_train": data["y_train"][i],
+                "x_test": data["x_test"][i],
+                "y_test": data["y_test"][i],
+            },
+            f"{subdir}/mnist.pt",
+        )
+
+
+if __name__ == "__main__":
     # Prepare data if not already done
-    if not os.path.exists(abs_path+'/data/clients/1'):
+    if not os.path.exists(abs_path + "/data/clients/1"):
         get_data()
         split()
diff --git a/examples/mnist-pytorch/client/model.py b/examples/mnist-pytorch/client/model.py
index cb5b7afc2..6ad344770 100644
--- a/examples/mnist-pytorch/client/model.py
+++ b/examples/mnist-pytorch/client/model.py
@@ -4,16 +4,17 @@
 
 from fedn.utils.helpers.helpers import get_helper
 
-HELPER_MODULE = 'numpyhelper'
+HELPER_MODULE = "numpyhelper"
 helper = get_helper(HELPER_MODULE)
 
 
 def compile_model():
-    """ Compile the pytorch model.
+    """Compile the pytorch model.
 
     :return: The compiled model.
     :rtype: torch.nn.Module
     """
+
     class Net(torch.nn.Module):
         def __init__(self):
             super(Net, self).__init__()
@@ -32,7 +33,7 @@ def forward(self, x):
 
 
 def save_parameters(model, out_path):
-    """ Save model paramters to file.
+    """Save model parameters to file.
 
     :param model: The model to serialize.
     :type model: torch.nn.Module
@@ -44,7 +45,7 @@ def save_parameters(model, out_path):
 
 def load_parameters(model_path):
-    """ Load model parameters from file and populate model.
+    """Load model parameters from file and populate model.
 
     param model_path: The path to load from.
     :type model_path: str
@@ -60,8 +61,8 @@ def load_parameters(model_path):
     return model
 
 
-def init_seed(out_path='seed.npz'):
-    """ Initialize seed model and save it to file.
+def init_seed(out_path="seed.npz"):
+    """Initialize seed model and save it to file.
 
     :param out_path: The path to save the seed model to.
     :type out_path: str
@@ -72,4 +73,4 @@ def init_seed(out_path='seed.npz'):
 
 
 if __name__ == "__main__":
-    init_seed('../seed.npz')
+    init_seed("../seed.npz")
diff --git a/examples/mnist-pytorch/client/train.py b/examples/mnist-pytorch/client/train.py
index fdf2480ae..9ac9cce61 100644
--- a/examples/mnist-pytorch/client/train.py
+++ b/examples/mnist-pytorch/client/train.py
@@ -3,9 +3,9 @@
 import sys
 
 import torch
-from data import load_data
 from model import load_parameters, save_parameters
 
+from data import load_data
 from fedn.utils.helpers.helpers import save_metadata
 
 dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -13,7 +13,7 @@
 
 
 def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01):
-    """ Complete a model update.
+    """Complete a model update.
Load model paramters from in_model_path (managed by the FEDn client), perform a model update, and write updated paramters @@ -45,8 +45,8 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1 for e in range(epochs): # epoch loop for b in range(n_batches): # batch loop # Retrieve current batch - batch_x = x_train[b * batch_size:(b + 1) * batch_size] - batch_y = y_train[b * batch_size:(b + 1) * batch_size] + batch_x = x_train[b * batch_size : (b + 1) * batch_size] + batch_y = y_train[b * batch_size : (b + 1) * batch_size] # Train on batch optimizer.zero_grad() outputs = model(batch_x) @@ -55,16 +55,15 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1 optimizer.step() # Log if b % 100 == 0: - print( - f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}") + print(f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}") # Metadata needed for aggregation server side metadata = { # num_examples are mandatory - 'num_examples': len(x_train), - 'batch_size': batch_size, - 'epochs': epochs, - 'lr': lr + "num_examples": len(x_train), + "batch_size": batch_size, + "epochs": epochs, + "lr": lr, } # Save JSON metadata file (mandatory) diff --git a/examples/mnist-pytorch/client/validate.py b/examples/mnist-pytorch/client/validate.py index 0a5592368..09328181f 100644 --- a/examples/mnist-pytorch/client/validate.py +++ b/examples/mnist-pytorch/client/validate.py @@ -2,9 +2,9 @@ import sys import torch -from data import load_data from model import load_parameters +from data import load_data from fedn.utils.helpers.helpers import save_metrics dir_path = os.path.dirname(os.path.realpath(__file__)) @@ -12,7 +12,7 @@ def validate(in_model_path, out_json_path, data_path=None): - """ Validate model. + """Validate model. :param in_model_path: The path to the input model. 
:type in_model_path: str @@ -34,12 +34,10 @@ def validate(in_model_path, out_json_path, data_path=None): with torch.no_grad(): train_out = model(x_train) training_loss = criterion(train_out, y_train) - training_accuracy = torch.sum(torch.argmax( - train_out, dim=1) == y_train) / len(train_out) + training_accuracy = torch.sum(torch.argmax(train_out, dim=1) == y_train) / len(train_out) test_out = model(x_test) test_loss = criterion(test_out, y_test) - test_accuracy = torch.sum(torch.argmax( - test_out, dim=1) == y_test) / len(test_out) + test_accuracy = torch.sum(torch.argmax(test_out, dim=1) == y_test) / len(test_out) # JSON schema report = { diff --git a/examples/mnist-pytorch/docker-compose.override.yaml b/examples/mnist-pytorch/docker-compose.override.yaml index eb7aedc47..822a696dc 100644 --- a/examples/mnist-pytorch/docker-compose.override.yaml +++ b/examples/mnist-pytorch/docker-compose.override.yaml @@ -3,8 +3,7 @@ version: '3.4' # Overriding requirements -x-env: - &defaults +x-env: &defaults GET_HOSTS_FROM: dns FEDN_PACKAGE_EXTRACT_DIR: package FEDN_NUM_DATA_SPLITS: 2 @@ -17,7 +16,7 @@ services: service: client environment: <<: *defaults - FEDN_DATA_PATH: /app/package/data/clients/1/mnist.pt + FEDN_DATA_PATH: /app/package/client/data/clients/1/mnist.pt deploy: replicas: 1 volumes: @@ -29,7 +28,7 @@ services: service: client environment: <<: *defaults - FEDN_DATA_PATH: /app/package/data/clients/2/mnist.pt + FEDN_DATA_PATH: /app/package/client/data/clients/2/mnist.pt deploy: replicas: 1 volumes: diff --git a/examples/mnist-pytorch/API_Example.ipynb b/examples/notebooks/API_Example.ipynb similarity index 99% rename from examples/mnist-pytorch/API_Example.ipynb rename to examples/notebooks/API_Example.ipynb index 59fced6ee..d5d0e616d 100644 --- a/examples/mnist-pytorch/API_Example.ipynb +++ b/examples/notebooks/API_Example.ipynb @@ -5,7 +5,7 @@ "id": "622f7047", "metadata": {}, "source": [ - "## FEDn Quickstart (PyTorch)\n", + "## API Example\n", "\n", "This notebook provides an example of how to use the FEDn API to organize experiments and to analyze validation results. We will here run one training session (a collection of global rounds) using FedAvg, then retrive and visualize the results.\n", "\n", @@ -71,8 +71,8 @@ } ], "source": [ - "client.set_active_package('package.tgz', 'numpyhelper')\n", - "client.set_active_model('seed.npz')\n", + "client.set_active_package('../mnist-pytorch/package.tgz', 'numpyhelper')\n", + "client.set_active_model('../mnist-pytorch/seed.npz')\n", "seed_model = client.get_active_model()\n", "print(seed_model)" ] @@ -124,7 +124,6 @@ "metadata": {}, "outputs": [], "source": [ - "session_id = \"experiment1\"\n", "models = client.get_model_trail()\n", "\n", "acc = []\n", diff --git a/examples/notebooks/Aggregators.ipynb b/examples/notebooks/Aggregators.ipynb new file mode 100644 index 000000000..27b0a22f7 --- /dev/null +++ b/examples/notebooks/Aggregators.ipynb @@ -0,0 +1,318 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "622f7047", + "metadata": {}, + "source": [ + "## Use of different Aggregators\n", + "\n", + "This notebook shows how to use different Aggregators (FedAvg, FedAdam, FedYogi, FedAdaGrad). \n", + "\n", + "When you start this tutorial you should either have an account and project in FEDn Studio, or have deployed a FEDn in pseudo-distributed mode. 
You should also have created the compute package and the initial model, see README.md for instructions.\n",
+    "\n",
+    " \n",
+    "Note that this notebook is intended to showcase the aggregator API. Fine-tuning of the server-side hyperparameters would be necessary for optimal performance and will need to be done on a use-case basis."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "743dfe47",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fedn import APIClient\n",
+    "import time\n",
+    "import uuid\n",
+    "import json\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import collections"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1046a4e5",
+   "metadata": {},
+   "source": [
+    "We make a client connection to the FEDn API service. Here we assume that FEDn is deployed locally in pseudo-distributed mode with default ports. To connect to Studio, generate an API token from the UI, and retrieve the controller host URI from the Dashboard. "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8a5e4583-a6f4-456c-96e5-22f38c2c0ba8",
+   "metadata": {},
+   "source": [
+    "#### If using a local development deployment"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "1061722d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "CONTROLLER_HOST = \"127.0.0.1\"\n",
+    "CONTROLLER_PORT = 8092\n",
+    "client = APIClient(CONTROLLER_HOST, CONTROLLER_PORT)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "da10cf9e-a4fd-41b5-98ef-80fbbfa7b56d",
+   "metadata": {},
+   "source": [
+    "#### If using a FEDn Studio project"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "2e5948f2-178c-4b66-93f1-ccb72b897133",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Get the controller host for your project from the Dashboard page in Studio\n",
+    "#CONTROLLER_HOST = \"fedn.scaleoutsystems.com/aggtestproject-etq-fedn-reducer\"\n",
+    "# Generate an API token from Settings->Generate token (the one below is just an example, it will not work)\n",
+    "#TOKEN = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.bl90eXBlIjoiYWNjZXNzIiwiZXhwIjoxNzE3MDE5MjIwLCJpYXQiOjE3MTQ0MIjQ0OWYyODQwMDQ3NzQxMzll9pZCI6MzcsImNyZWF0b3IiOiJhbmRyZWFzaCIsInJvbGUiOiJhZG1pbiIsInByb2plY3Rfc2x1ZyI6ImFnZ3Rlc3Rwcm9qZWN0LWV0cSJ9.P1RwQElLy3kx3h2o9uE-TUICT4CLlgrrM9YuRasCrBM\"\n",
+    "#client = APIClient(CONTROLLER_HOST, token=TOKEN, secure=True,verify=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "07f69f5f",
+   "metadata": {},
+   "source": [
+    "Initialize FEDn with the compute package and seed model. Note that these files need to be created separately by following instructions in the README."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5107f6f9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "client.set_active_package('../mnist-pytorch/package.tgz', 'numpyhelper')\n",
+    "client.set_active_model('../mnist-pytorch/seed.npz')\n",
+    "seed_model = client.get_active_model()\n",
+    "print(seed_model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4e26c50b",
+   "metadata": {},
+   "source": [
+    "### FedAvg"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f0380d35",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "session_id = \"experiment_fedavg\"\n",
+    "\n",
+    "session_config = {\n",
+    "    \"helper\": \"numpyhelper\",\n",
+    "    \"id\": session_id,\n",
+    "    \"model_id\": seed_model['model'],\n",
+    "    \"rounds\": 10\n",
+    "  }\n",
+    "\n",
+    "result_fedavg = client.start_session(**session_config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6cd9b26b-9ea7-4d65-9f23-7f14c54c4ef0",
+   "metadata": {},
+   "source": [
+    "### FedAdam\n",
+    "\n",
+    "Here we use the FedOpt family of aggregation algorithms. FEDn supports adam, yogi and adagrad as server-side optimizers. In the session_config below we illustrate how to set hyperparameters (will be valid for this session). The values below are actually the default values and are passed here for illustrative purposes.\n",
+    "\n",
+    "**Note that the server-side momentum terms are only retained within one session - each new session re-initializes the optimizer to its default values.** "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dfd69029-4ec5-4ebc-9121-b8ecde421afe",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "session_id = \"experiment_fedadam\"\n",
+    "\n",
+    "session_config = {\n",
+    "    \"helper\": \"numpyhelper\",\n",
+    "    \"id\": session_id,\n",
+    "    \"aggregator\": \"fedopt\",\n",
+    "    \"aggregator_kwargs\": {\n",
+    "        \"serveropt\": \"adam\",\n",
+    "        \"learning_rate\": 1e-2,\n",
+    "        \"beta1\": 0.9,\n",
+    "        \"beta2\": 0.99,\n",
+    "        \"tau\": 1e-4\n",
+    "    },\n",
+    "    \"model_id\": seed_model['model'],\n",
+    "    \"rounds\": 10\n",
+    "  }\n",
+    "\n",
+    "result_fedadam = client.start_session(**session_config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a6d2de1f-4b28-402f-a6c3-1b91beb0b889",
+   "metadata": {},
+   "source": [
+    "### FedYogi"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "018b4a7a-96af-49ca-8c52-6bac7a9de357",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "session_id = \"experiment_fedyogi\"\n",
+    "\n",
+    "session_config = {\n",
+    "    \"helper\": \"numpyhelper\",\n",
+    "    \"id\": session_id,\n",
+    "    \"aggregator\": \"fedopt\",\n",
+    "    \"aggregator_kwargs\": {\n",
+    "        \"serveropt\": \"yogi\",\n",
+    "        \"learning_rate\": 1e-2,\n",
+    "    },\n",
+    "    \"model_id\": seed_model['model'],\n",
+    "    \"rounds\": 10\n",
+    "  }\n",
+    "\n",
+    "result_fedyogi = client.start_session(**session_config)\n",
+    "while not client.session_is_finished(session_config['id']):\n",
+    "    time.sleep(2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3432d6d5-84cb-4825-9ae9-5bc932ceea77",
+   "metadata": {},
+   "source": [
+    "### FedAdaGrad"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "10827f50-ebf4-458c-9bbe-114e7412ada1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "session_id = \"experiment_fedadagrad\"\n",
+    "\n",
+    "session_config = {\n",
+    "    \"helper\": \"numpyhelper\",\n",
+    "    \"id\": session_id,\n",
+    "    \"aggregator\": \"fedopt\",\n",
+    "    \"aggregator_kwargs\": {\n",
+    "        \"serveropt\": \"adagrad\",\n",
+    "        \"learning_rate\": 
1e-1,\n", + " }, \n", + " \"model_id\": seed_model['model'],\n", + " \"rounds\": 10\n", + " }\n", + "\n", + "result_fedadagrad = client.start_session(**session_config)" + ] + }, + { + "cell_type": "markdown", + "id": "16874cec", + "metadata": {}, + "source": [ + "Next, we get the model trail, retrieve all model validations from all clients, extract the training accuracy metric, and compute its mean value accross all clients." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "6f68c692-d299-430e-a253-a4f21f16789d", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "def get_validation_acc(session_id):\n", + " validations = client.get_validations(session_id)\n", + " acc = OrderedDict()\n", + " for validation in validations['result']:\n", + " try: \n", + " acc[validation['model_id']].append(json.loads(validation['data'])['training_accuracy'])\n", + " except:\n", + " acc[validation['model_id']] = [json.loads(validation['data'])['training_accuracy']]\n", + "\n", + " accuracy_score = []\n", + " for key, value in acc.items():\n", + " accuracy_score.append(np.mean(value))\n", + " accuracy_score.reverse()\n", + " return(accuracy_score)\n", + " \n", + "score = get_validation_acc(\"experiment_fedadagrad\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "4e8044b7", + "metadata": {}, + "outputs": [], + "source": [ + "mean_acc_fedavg = get_validation_acc(\"experiment_fedavg\")\n", + "mean_acc_fedadam = get_validation_acc(\"experiment_fedadam\")\n", + "mean_acc_yogi = get_validation_acc(\"experiment_fedyogi\")\n", + "mean_acc_adagrad = get_validation_acc(\"experiment_fedadagrad\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42425c43", + "metadata": {}, + "outputs": [], + "source": [ + "x = range(1,len(mean_acc_fedavg)+1)\n", + "plt.plot(x,mean_acc_fedavg, x, mean_acc_fedadam, x, mean_acc_yogi, x, mean_acc_adagrad)\n", + "plt.legend(['FedAvg','FedAdam', 'FedYogi', 'FedAdaGrad'])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/fedn/__init__.py b/fedn/__init__.py index 31be09d81..703eab7b2 100644 --- a/fedn/__init__.py +++ b/fedn/__init__.py @@ -1,3 +1,23 @@ -"""The fedn package.""" +import glob +import os +from os.path import basename, dirname, isfile + +from fedn.network.api.client import APIClient # flake8: noqa + + +modules = glob.glob(dirname(__file__) + "/fedn" + "/*.py") +__all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")] + + +_ROOT = os.path.abspath(os.path.dirname(__file__)) + + +def get_data(path): + """ + + :param path: + :return: + """ + return os.path.join(_ROOT, "data", path) diff --git a/fedn/cli/__init__.py b/fedn/cli/__init__.py index 840d4252b..137fc9b9c 100644 --- a/fedn/cli/__init__.py +++ b/fedn/cli/__init__.py @@ -1,2 +1,11 @@ +from .client_cmd import client_cmd # noqa: F401 +from .combiner_cmd import combiner_cmd # noqa: F401 +from .config_cmd import config_cmd # noqa: F401 from .main import main # noqa: F401 +from .model_cmd import model_cmd # noqa: F401 +from .package_cmd import package_cmd # noqa: F401 +from .round_cmd import 
round_cmd # noqa: F401 from .run_cmd import run_cmd # noqa: F401 +from .session_cmd import session_cmd # noqa: F401 +from .status_cmd import status_cmd # noqa: F401 +from .validation_cmd import validation_cmd # noqa: F401 diff --git a/fedn/cli/client_cmd.py b/fedn/cli/client_cmd.py new file mode 100644 index 000000000..e72f29569 --- /dev/null +++ b/fedn/cli/client_cmd.py @@ -0,0 +1,172 @@ +import uuid + +import click +import requests + +from fedn.common.exceptions import InvalidClientConfig +from fedn.network.clients.client import Client + +from .main import main +from .shared import CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token, print_response + + +def validate_client_config(config): + """Validate client configuration. + + :param config: Client config (dict). + """ + + try: + if config["discover_host"] is None or config["discover_host"] == "": + raise InvalidClientConfig("Missing required configuration: discover_host") + if "discover_port" not in config.keys(): + config["discover_port"] = None + except Exception: + raise InvalidClientConfig("Could not load config from file. Check config") + + +@main.group("client") +@click.pass_context +def client_cmd(ctx): + """ + + :param ctx: + """ + pass + + +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@client_cmd.command("list") +@click.pass_context +def list_clients(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): + """ + Return: + ------ + - count: number of clients + - result: list of clients + + """ + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="clients") + headers = {} + + if n_max: + headers["X-Limit"] = n_max + + _token = get_token(token) + + if _token: + headers["Authorization"] = _token + + click.echo(f"\nListing clients: {url}\n") + click.echo(f"Headers: {headers}") + + try: + response = requests.get(url, headers=headers) + print_response(response, "clients") + except requests.exceptions.ConnectionError: + click.echo(f"Error: Could not connect to {url}") + + +@client_cmd.command("start") +@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services(reducer).") +@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).") +@click.option("--token", required=False, help="Set token provided by reducer if enabled") +@click.option("-n", "--name", required=False, default="client" + str(uuid.uuid4())[:8]) +@click.option("-i", "--client_id", required=False) +@click.option("--local-package", is_flag=True, help="Enable local compute package") +@click.option("--force-ssl", is_flag=True, help="Force SSL/TLS for REST service") +@click.option("-u", "--dry-run", required=False, default=False) +@click.option("-s", "--secure", required=False, default=False) +@click.option("-pc", "--preshared-cert", required=False, default=False) +@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST service") +@click.option("-c", "--preferred-combiner", required=False, default=False) +@click.option("-va", "--validator", required=False, 
default=True) +@click.option("-tr", "--trainer", required=False, default=True) +@click.option("-in", "--init", required=False, default=None, help="Set to a filename to (re)init client from file state.") +@click.option("-l", "--logfile", required=False, default=None, help="Set logfile for client log to file.") +@click.option("--heartbeat-interval", required=False, default=2) +@click.option("--reconnect-after-missed-heartbeat", required=False, default=30) +@click.option("--verbosity", required=False, default="INFO", type=click.Choice(["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], case_sensitive=False)) +@click.pass_context +def client_cmd( + ctx, + discoverhost, + discoverport, + token, + name, + client_id, + local_package, + force_ssl, + dry_run, + secure, + preshared_cert, + verify, + preferred_combiner, + validator, + trainer, + init, + logfile, + heartbeat_interval, + reconnect_after_missed_heartbeat, + verbosity, +): + """ + + :param ctx: + :param discoverhost: + :param discoverport: + :param token: + :param name: + :param client_id: + :param remote: + :param dry_run: + :param secure: + :param preshared_cert: + :param verify_cert: + :param preferred_combiner: + :param init: + :param logfile: + :param hearbeat_interval + :param reconnect_after_missed_heartbeat + :param verbosity + :return: + """ + remote = False if local_package else True + config = { + "discover_host": discoverhost, + "discover_port": discoverport, + "token": token, + "name": name, + "client_id": client_id, + "remote_compute_context": remote, + "force_ssl": force_ssl, + "dry_run": dry_run, + "secure": secure, + "preshared_cert": preshared_cert, + "verify": verify, + "preferred_combiner": preferred_combiner, + "validator": validator, + "trainer": trainer, + "logfile": logfile, + "heartbeat_interval": heartbeat_interval, + "reconnect_after_missed_heartbeat": reconnect_after_missed_heartbeat, + "verbosity": verbosity, + } + + if init: + apply_config(init, config) + click.echo(f"\nClient configuration loaded from file: {init}") + click.echo("Values set in file override defaults and command line arguments...\n") + + try: + validate_client_config(config) + except InvalidClientConfig as e: + click.echo(f"Error: {e}") + return + + client = Client(config) + client.run() diff --git a/fedn/cli/combiner_cmd.py b/fedn/cli/combiner_cmd.py new file mode 100644 index 000000000..2b4447437 --- /dev/null +++ b/fedn/cli/combiner_cmd.py @@ -0,0 +1,104 @@ +import uuid + +import click +import requests + +from fedn.network.combiner.combiner import Combiner + +from .main import main +from .shared import CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token, print_response + + +@main.group("combiner") +@click.pass_context +def combiner_cmd(ctx): + """ + + :param ctx: + """ + pass + + +@combiner_cmd.command("start") +@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services (reducer).") +@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).") +@click.option("-t", "--token", required=False, help="Set token provided by reducer if enabled") +@click.option("-n", "--name", required=False, default="combiner" + str(uuid.uuid4())[:8], help="Set name for combiner.") +@click.option("-h", "--host", required=False, default="combiner", help="Set hostname.") +@click.option("-i", "--port", required=False, default=12080, help="Set port.") +@click.option("-f", "--fqdn", required=False, default=None, help="Set fully qualified domain name") +@click.option("-s", "--secure", 
is_flag=True, help="Enable SSL/TLS encrypted gRPC channels.") +@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST discovery service (reducer)") +@click.option("-c", "--max_clients", required=False, default=30, help="The maximal number of client connections allowed.") +@click.option("-in", "--init", required=False, default=None, help="Path to configuration file to (re)init combiner.") +@click.pass_context +def start_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init): + """ + + :param ctx: + :param discoverhost: + :param discoverport: + :param token: + :param name: + :param hostname: + :param port: + :param secure: + :param max_clients: + :param init: + """ + config = { + "discover_host": discoverhost, + "discover_port": discoverport, + "token": token, + "host": host, + "port": port, + "fqdn": fqdn, + "name": name, + "secure": secure, + "verify": verify, + "max_clients": max_clients, + } + + if init: + apply_config(init, config) + click.echo(f"\nCombiner configuration loaded from file: {init}") + click.echo("Values set in file override defaults and command line arguments...\n") + + combiner = Combiner(config) + combiner.run() + + +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@combiner_cmd.command("list") +@click.pass_context +def list_combiners(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): + """ + Return: + ------ + - count: number of combiners + - result: list of combiners + + """ + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="combiners") + headers = {} + + if n_max: + headers["X-Limit"] = n_max + + _token = get_token(token) + + if _token: + headers["Authorization"] = _token + + click.echo(f"\nListing combiners: {url}\n") + click.echo(f"Headers: {headers}") + + try: + response = requests.get(url, headers=headers) + print_response(response, "combiners") + except requests.exceptions.ConnectionError: + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/config_cmd.py b/fedn/cli/config_cmd.py new file mode 100644 index 000000000..d5286997f --- /dev/null +++ b/fedn/cli/config_cmd.py @@ -0,0 +1,36 @@ +import os + +import click + +from .main import main + +envs = [ + {"name": "FEDN_CONTROLLER_PROTOCOL", "description": "The protocol to use for communication with the controller."}, + {"name": "FEDN_CONTROLLER_HOST", "description": "The host to use for communication with the controller."}, + {"name": "FEDN_CONTROLLER_PORT", "description": "The port to use for communication with the controller."}, + {"name": "FEDN_AUTH_TOKEN", "description": "The authentication token to use for communication with the controller and combiner."}, + {"name": "FEDN_AUTH_SCHEME", "description": "The authentication scheme to use for communication with the controller and combiner."}, + { + "name": "FEDN_CONTROLLER_URL", + "description": "The URL of the controller. 
Overrides FEDN_CONTROLLER_PROTOCOL, FEDN_CONTROLLER_HOST and FEDN_CONTROLLER_PORT.", + }, + {"name": "FEDN_PACKAGE_EXTRACT_DIR", "description": "The directory to extract packages to."}, +] + + +@main.group("config", invoke_without_command=True) +@click.pass_context +def config_cmd(ctx): + """ + - Configuration commands for the FEDn CLI. + """ + if ctx.invoked_subcommand is None: + click.echo("\n--- FEDn Cli Configuration ---\n") + click.echo("Current configuration:\n") + + for env in envs: + name = env["name"] + value = os.environ.get(name) + click.echo(f'{name}: {value or "Not set"}') + click.echo(f'{env["description"]}\n') + click.echo("\n") diff --git a/fedn/cli/main.py b/fedn/cli/main.py index d7b63ba13..32aa6e3e4 100644 --- a/fedn/cli/main.py +++ b/fedn/cli/main.py @@ -2,7 +2,7 @@ CONTEXT_SETTINGS = dict( # Support -h as a shortcut for --help - help_option_names=['-h', '--help'], + help_option_names=["-h", "--help"], ) diff --git a/fedn/cli/model_cmd.py b/fedn/cli/model_cmd.py new file mode 100644 index 000000000..e44793a9f --- /dev/null +++ b/fedn/cli/model_cmd.py @@ -0,0 +1,51 @@ +import click +import requests + +from .main import main +from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response + + +@main.group("model") +@click.pass_context +def model_cmd(ctx): + """ + + :param ctx: + """ + pass + + +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@model_cmd.command("list") +@click.pass_context +def list_models(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): + """ + Return: + ------ + - count: number of models + - result: list of models + + """ + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="models") + headers = {} + + if n_max: + headers["X-Limit"] = n_max + + _token = get_token(token) + + if _token: + headers["Authorization"] = _token + + click.echo(f"\nListing models: {url}\n") + click.echo(f"Headers: {headers}") + + try: + response = requests.get(url, headers=headers) + print_response(response, "models") + except requests.exceptions.ConnectionError: + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/package_cmd.py b/fedn/cli/package_cmd.py new file mode 100644 index 000000000..6d503d414 --- /dev/null +++ b/fedn/cli/package_cmd.py @@ -0,0 +1,79 @@ +import os +import tarfile + +import click +import requests + +from fedn.common.log_config import logger + +from .main import main +from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response + + +@main.group("package") +@click.pass_context +def package_cmd(ctx): + """ + + :param ctx: + """ + pass + + +@package_cmd.command("create") +@click.option("-p", "--path", required=True, help="Path to package directory containing fedn.yaml") +@click.option("-n", "--name", required=False, default="package.tgz", help="Name of package tarball") +@click.pass_context +def create_cmd(ctx, path, name): + """Create compute package. 
+ + Make a tar.gz archive of folder given by --path + + :param ctx: + :param path: + """ + path = os.path.abspath(path) + yaml_file = os.path.join(path, "fedn.yaml") + if not os.path.exists(yaml_file): + logger.error(f"Could not find fedn.yaml in {path}") + exit(-1) + + with tarfile.open(name, "w:gz") as tar: + tar.add(path, arcname=os.path.basename(path)) + logger.info(f"Created package {name}") + + +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@package_cmd.command("list") +@click.pass_context +def list_packages(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): + """ + Return: + ------ + - count: number of packages + - result: list of packages + + """ + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="packages") + headers = {} + + if n_max: + headers["X-Limit"] = n_max + + _token = get_token(token) + + if _token: + headers["Authorization"] = _token + + click.echo(f"\nListing packages: {url}\n") + click.echo(f"Headers: {headers}") + + try: + response = requests.get(url, headers=headers) + print_response(response, "packages") + except requests.exceptions.ConnectionError: + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/round_cmd.py b/fedn/cli/round_cmd.py new file mode 100644 index 000000000..ca23cafe7 --- /dev/null +++ b/fedn/cli/round_cmd.py @@ -0,0 +1,51 @@ +import click +import requests + +from .main import main +from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response + + +@main.group("round") +@click.pass_context +def round_cmd(ctx): + """ + + :param ctx: + """ + pass + + +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@round_cmd.command("list") +@click.pass_context +def list_rounds(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): + """ + Return: + ------ + - count: number of rounds + - result: list of rounds + + """ + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="rounds") + headers = {} + + if n_max: + headers["X-Limit"] = n_max + + _token = get_token(token) + + if _token: + headers["Authorization"] = _token + + click.echo(f"\nListing rounds: {url}\n") + click.echo(f"Headers: {headers}") + + try: + response = requests.get(url, headers=headers) + print_response(response, "rounds") + except requests.exceptions.ConnectionError: + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/run_cmd.py b/fedn/cli/run_cmd.py index 818abc2c0..b89a9b5fe 100644 --- a/fedn/cli/run_cmd.py +++ b/fedn/cli/run_cmd.py @@ -1,6 +1,5 @@ import os import shutil -import tarfile import uuid 
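All of the new `list` subcommands above (clients, combiners, models, packages, rounds) share one request pattern: resolve the controller URL from flags or environment variables, attach optional paging and auth headers, and issue a GET. A minimal standalone sketch of that pattern, assuming a controller on the default local address; the token handling mirrors `get_token` in `fedn/cli/shared.py`:

```python
import os

import requests

# Sketch of the request the `fedn <entity> list` commands make.
base_url = "http://localhost:8092/api/v1"  # mirrors CONTROLLER_DEFAULTS
token = os.environ.get("FEDN_AUTH_TOKEN")

headers = {"X-Limit": "10"}  # cap the number of returned items
if token:
    scheme = os.environ.get("FEDN_AUTH_SCHEME", "Bearer")
    headers["Authorization"] = f"{scheme} {token}"

response = requests.get(f"{base_url}/rounds/", headers=headers)
data = response.json()  # expected shape: {"count": <int>, "result": [...]}
print(f"Found {data['count']} rounds")
```
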
import click @@ -13,7 +12,9 @@ from fedn.network.combiner.combiner import Combiner from fedn.utils.dispatcher import Dispatcher, _read_yaml_file +from .client_cmd import validate_client_config from .main import main +from .shared import apply_config def get_statestore_config_from_file(init): @@ -22,7 +23,7 @@ def get_statestore_config_from_file(init): :param init: :return: """ - with open(init, 'r') as file: + with open(init, "r") as file: try: settings = dict(yaml.safe_load(file)) return settings @@ -31,7 +32,7 @@ def get_statestore_config_from_file(init): def check_helper_config_file(config): - control = config['control'] + control = config["control"] try: helper = control["helper"] except KeyError: @@ -40,40 +41,6 @@ def check_helper_config_file(config): return helper -def apply_config(config): - """Parse client config from file. - - Override configs from the CLI with settings in config file. - - :param config: Client config (dict). - """ - with open(config['init'], 'r') as file: - try: - settings = dict(yaml.safe_load(file)) - except Exception: - logger.error('Failed to read config from settings file, exiting.') - return - - for key, val in settings.items(): - config[key] = val - - -def validate_client_config(config): - """Validate client configuration. - - :param config: Client config (dict). - """ - - try: - if config['discover_host'] is None or \ - config['discover_host'] == '': - raise InvalidClientConfig("Missing required configuration: discover_host") - if 'discover_port' not in config.keys(): - config['discover_port'] = None - except Exception: - raise InvalidClientConfig("Could not load config from file. Check config") - - def sanitize_config(config): # List of keys to sanitize (remove or mask) sensitive_keys = ["discover_host", @@ -100,32 +67,81 @@ def run_cmd(ctx): pass -@run_cmd.command('client') -@click.option('-d', '--discoverhost', required=False, help='Hostname for discovery services(reducer).') -@click.option('-p', '--discoverport', required=False, help='Port for discovery services (reducer).') -@click.option('--token', required=False, help='Set token provided by reducer if enabled') -@click.option('-n', '--name', required=False, default="client" + str(uuid.uuid4())[:8]) -@click.option('-i', '--client_id', required=False) -@click.option('--local-package', is_flag=True, help='Enable local compute package') -@click.option('--force-ssl', is_flag=True, help='Force SSL/TLS for REST service') -@click.option('-u', '--dry-run', required=False, default=False) -@click.option('-s', '--secure', required=False, default=False) -@click.option('-pc', '--preshared-cert', required=False, default=False) -@click.option('-v', '--verify', is_flag=True, help='Verify SSL/TLS for REST service') -@click.option('-c', '--preferred-combiner', required=False, default=False) -@click.option('-va', '--validator', required=False, default=True) -@click.option('-tr', '--trainer', required=False, default=True) -@click.option('-in', '--init', required=False, default=None, - help='Set to a filename to (re)init client from file state.') -@click.option('-l', '--logfile', required=False, default=None, - help='Set logfile for client log to file.') -@click.option('--heartbeat-interval', required=False, default=2) -@click.option('--reconnect-after-missed-heartbeat', required=False, default=30) -@click.option('--verbosity', required=False, default='INFO', type=click.Choice(['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], case_sensitive=False)) +@run_cmd.command("build") +@click.option("-p", "--path", 
required=True, help="Path to package directory containing fedn.yaml") +@click.pass_context +def build_cmd(ctx, path): + """Execute 'build' entrypoint in fedn.yaml. + + :param ctx: + :param path: Path to folder containing fedn.yaml + :type path: str + """ + path = os.path.abspath(path) + yaml_file = os.path.join(path, "fedn.yaml") + if not os.path.exists(yaml_file): + logger.error(f"Could not find fedn.yaml in {path}") + exit(-1) + + config = _read_yaml_file(yaml_file) + # Check that build is defined in fedn.yaml under entry_points + if "build" not in config["entry_points"]: + logger.error("No build command defined in fedn.yaml") + exit(-1) + + dispatcher = Dispatcher(config, path) + _ = dispatcher._get_or_create_python_env() + dispatcher.run_cmd("build") + + # delete the virtualenv + if dispatcher.python_env_path: + logger.info(f"Removing virtualenv {dispatcher.python_env_path}") + shutil.rmtree(dispatcher.python_env_path) + + +@run_cmd.command("client") +@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services(reducer).") +@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).") +@click.option("--token", required=False, help="Set token provided by reducer if enabled") +@click.option("-n", "--name", required=False, default="client" + str(uuid.uuid4())[:8]) +@click.option("-i", "--client_id", required=False) +@click.option("--local-package", is_flag=True, help="Enable local compute package") +@click.option("--force-ssl", is_flag=True, help="Force SSL/TLS for REST service") +@click.option("-u", "--dry-run", required=False, default=False) +@click.option("-s", "--secure", required=False, default=False) +@click.option("-pc", "--preshared-cert", required=False, default=False) +@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST service") +@click.option("-c", "--preferred-combiner", required=False, default=False) +@click.option("-va", "--validator", required=False, default=True) +@click.option("-tr", "--trainer", required=False, default=True) +@click.option("-in", "--init", required=False, default=None, help="Set to a filename to (re)init client from file state.") +@click.option("-l", "--logfile", required=False, default=None, help="Set logfile for client log to file.") +@click.option("--heartbeat-interval", required=False, default=2) +@click.option("--reconnect-after-missed-heartbeat", required=False, default=30) +@click.option("--verbosity", required=False, default="INFO", type=click.Choice(["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], case_sensitive=False)) @click.pass_context -def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_package, force_ssl, dry_run, secure, preshared_cert, - verify, preferred_combiner, validator, trainer, init, logfile, heartbeat_interval, reconnect_after_missed_heartbeat, - verbosity): +def client_cmd( + ctx, + discoverhost, + discoverport, + token, + name, + client_id, + local_package, + force_ssl, + dry_run, + secure, + preshared_cert, + verify, + preferred_combiner, + validator, + trainer, + init, + logfile, + heartbeat_interval, + reconnect_after_missed_heartbeat, + verbosity, +): """ :param ctx: @@ -149,36 +165,58 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa """ with tracer.start_as_current_span("client_cmd") as span: remote = False if local_package else True - config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'name': name, - 'client_id': client_id, 
'remote_compute_context': remote, 'force_ssl': force_ssl, 'dry_run': dry_run, 'secure': secure, - 'preshared_cert': preshared_cert, 'verify': verify, 'preferred_combiner': preferred_combiner, - 'validator': validator, 'trainer': trainer, 'init': init, 'logfile': logfile, 'heartbeat_interval': heartbeat_interval, - 'reconnect_after_missed_heartbeat': reconnect_after_missed_heartbeat, 'verbosity': verbosity} - span.set_attribute("client_config", str(sanitize_config(config))) - context = get_context() - span.set_attribute("context", str(context)) + config = { + "discover_host": discoverhost, + "discover_port": discoverport, + "token": token, + "name": name, + "client_id": client_id, + "remote_compute_context": remote, + "force_ssl": force_ssl, + "dry_run": dry_run, + "secure": secure, + "preshared_cert": preshared_cert, + "verify": verify, + "preferred_combiner": preferred_combiner, + "validator": validator, + "trainer": trainer, + "logfile": logfile, + "heartbeat_interval": heartbeat_interval, + "reconnect_after_missed_heartbeat": reconnect_after_missed_heartbeat, + "verbosity": verbosity, + } + + click.echo( + click.style("\n*** fedn run client is deprecated and will be removed. Please use fedn client start instead. ***\n", blink=True, bold=True, fg="red") + ) + if init: - apply_config(config) + apply_config(init, config) + click.echo(f"\nClient configuration loaded from file: {init}") + click.echo("Values set in file override defaults and command line arguments...\n") - validate_client_config(config) + try: + validate_client_config(config) + except InvalidClientConfig as e: + click.echo(f"Error: {e}") + return client = Client(config) client.run() -@run_cmd.command('combiner') -@click.option('-d', '--discoverhost', required=False, help='Hostname for discovery services (reducer).') -@click.option('-p', '--discoverport', required=False, help='Port for discovery services (reducer).') -@click.option('-t', '--token', required=False, help='Set token provided by reducer if enabled') -@click.option('-n', '--name', required=False, default="combiner" + str(uuid.uuid4())[:8], help='Set name for combiner.') -@click.option('-h', '--host', required=False, default="combiner", help='Set hostname.') -@click.option('-i', '--port', required=False, default=12080, help='Set port.') -@click.option('-f', '--fqdn', required=False, default=None, help='Set fully qualified domain name') -@click.option('-s', '--secure', is_flag=True, help='Enable SSL/TLS encrypted gRPC channels.') -@click.option('-v', '--verify', is_flag=True, help='Verify SSL/TLS for REST discovery service (reducer)') -@click.option('-c', '--max_clients', required=False, default=30, help='The maximal number of client connections allowed.') -@click.option('-in', '--init', required=False, default=None, - help='Path to configuration file to (re)init combiner.') +@run_cmd.command("combiner") +@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services (reducer).") +@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).") +@click.option("-t", "--token", required=False, help="Set token provided by reducer if enabled") +@click.option("-n", "--name", required=False, default="combiner" + str(uuid.uuid4())[:8], help="Set name for combiner.") +@click.option("-h", "--host", required=False, default="combiner", help="Set hostname.") +@click.option("-i", "--port", required=False, default=12080, help="Set port.") +@click.option("-f", "--fqdn", required=False, default=None, help="Set fully 
qualified domain name") +@click.option("-s", "--secure", is_flag=True, help="Enable SSL/TLS encrypted gRPC channels.") +@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST discovery service (reducer)") +@click.option("-c", "--max_clients", required=False, default=30, help="The maximal number of client connections allowed.") +@click.option("-in", "--init", required=False, default=None, help="Path to configuration file to (re)init combiner.") @click.pass_context def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init): """ @@ -195,79 +233,27 @@ def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, :param init: """ with tracer.start_as_current_span("combiner_cmd"): - config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'host': host, - 'port': port, 'fqdn': fqdn, 'name': name, 'secure': secure, 'verify': verify, 'max_clients': max_clients, - 'init': init} + config = { + "discover_host": discoverhost, + "discover_port": discoverport, + "token": token, + "host": host, + "port": port, + "fqdn": fqdn, + "name": name, + "secure": secure, + "verify": verify, + "max_clients": max_clients, + } + + click.echo( + click.style("\n*** fedn run combiner is deprecated and will be removed. Please use fedn combiner start instead. ***\n", blink=True, bold=True, fg="red") + ) - if config['init']: - apply_config(config) + if init: + apply_config(init, config) + click.echo(f"\nCombiner configuration loaded from file: {init}") + click.echo("Values set in file override defaults and command line arguments...\n") combiner = Combiner(config) combiner.run() - - -@run_cmd.command('build') -@click.option('-p', '--path', required=True, help='Path to package directory containing fedn.yaml') -@click.pass_context -def build_cmd(ctx, path): - """ Execute 'build' entrypoint in fedn.yaml. - - :param ctx: - :param path: Path to folder containing fedn.yaml - :type path: str - """ - with tracer.start_as_current_span("build_cmd"): - path = os.path.abspath(path) - yaml_file = os.path.join(path, 'fedn.yaml') - if not os.path.exists(yaml_file): - logger.error(f"Could not find fedn.yaml in {path}") - exit(-1) - - config = _read_yaml_file(yaml_file) - # Check that build is defined in fedn.yaml under entry_points - if 'build' not in config['entry_points']: - logger.error("No build command defined in fedn.yaml") - exit(-1) - - dispatcher = Dispatcher(config, path) - _ = dispatcher._get_or_create_python_env() - dispatcher.run_cmd("build") - - # delete the virtualenv - if dispatcher.python_env_path: - logger.info(f"Removing virtualenv {dispatcher.python_env_path}") - shutil.rmtree(dispatcher.python_env_path) - - -@main.group('package') -@click.pass_context -def package_cmd(ctx): - """ - - :param ctx: - """ - pass - - -@package_cmd.command('create') -@click.option('-p', '--path', required=True, help='Path to package directory containing fedn.yaml') -@click.option('-n', '--name', required=False, default='package.tgz', help='Name of package tarball') -@click.pass_context -def create_cmd(ctx, path, name): - """ Create compute package. 
- - Make a tar.gz archive of folder given by --path - - :param ctx: - :param path: - """ - with tracer.start_as_current_span("create_cmd"): - path = os.path.abspath(path) - yaml_file = os.path.join(path, 'fedn.yaml') - if not os.path.exists(yaml_file): - logger.error(f"Could not find fedn.yaml in {path}") - exit(-1) - - with tarfile.open(name, "w:gz") as tar: - tar.add(path, arcname=os.path.basename(path)) - logger.info(f"Created package {name}") diff --git a/fedn/cli/session_cmd.py b/fedn/cli/session_cmd.py new file mode 100644 index 000000000..55597b5b3 --- /dev/null +++ b/fedn/cli/session_cmd.py @@ -0,0 +1,51 @@ +import click +import requests + +from .main import main +from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response + + +@main.group("session") +@click.pass_context +def session_cmd(ctx): + """ + + :param ctx: + """ + pass + + +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@session_cmd.command("list") +@click.pass_context +def list_sessions(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): + """ + Return: + ------ + - count: number of sessions + - result: list of sessions + + """ + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="sessions") + headers = {} + + if n_max: + headers["X-Limit"] = n_max + + _token = get_token(token) + + if _token: + headers["Authorization"] = _token + + click.echo(f"\nListing sessions: {url}\n") + click.echo(f"Headers: {headers}") + + try: + response = requests.get(url, headers=headers) + print_response(response, "sessions") + except requests.exceptions.ConnectionError: + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/shared.py b/fedn/cli/shared.py new file mode 100644 index 000000000..2500d9e2b --- /dev/null +++ b/fedn/cli/shared.py @@ -0,0 +1,92 @@ +import os + +import click +import yaml + +from fedn.common.log_config import logger + +CONTROLLER_DEFAULTS = {"protocol": "http", "host": "localhost", "port": 8092, "debug": False} + +COMBINER_DEFAULTS = {"discover_host": "localhost", "discover_port": 8092, "host": "localhost", "port": 12080, "name": "combiner", "max_clients": 30} + +CLIENT_DEFAULTS = { + "discover_host": "localhost", + "discover_port": 8092, +} + +API_VERSION = "v1" + + +def apply_config(path: str, config: dict): + """Parse client config from file. + + Override configs from the CLI with settings in config file. + + :param config: Client config (dict). 
+ """ + with open(path, "r") as file: + try: + settings = dict(yaml.safe_load(file)) + except Exception: + logger.error("Failed to read config from settings file, exiting.") + return + + for key, val in settings.items(): + config[key] = val + + +def get_api_url(protocol: str, host: str, port: str, endpoint: str) -> str: + _url = os.environ.get("FEDN_CONTROLLER_URL") + + if _url: + return f"{_url}/api/{API_VERSION}/{endpoint}/" + + _protocol = protocol or os.environ.get("FEDN_CONTROLLER_PROTOCOL") or CONTROLLER_DEFAULTS["protocol"] + _host = host or os.environ.get("FEDN_CONTROLLER_HOST") or CONTROLLER_DEFAULTS["host"] + _port = port or os.environ.get("FEDN_CONTROLLER_PORT") or CONTROLLER_DEFAULTS["port"] + + return f"{_protocol}://{_host}:{_port}/api/{API_VERSION}/{endpoint}/" + + +def get_token(token: str) -> str: + _token = token or os.environ.get("FEDN_AUTH_TOKEN", None) + + if _token is None: + return None + + scheme = os.environ.get("FEDN_AUTH_SCHEME", "Bearer") + + return f"{scheme} {_token}" + + +def get_client_package_dir(path: str) -> str: + return path or os.environ.get("FEDN_PACKAGE_DIR", None) + + +# Print response from api (list of entities) +def print_response(response, entity_name: str): + """ + Prints the api response to the cli. + :param response: + type: array + description: list of entities + :param entity_name: + type: string + description: name of entity + return: None + """ + if response.status_code == 200: + json_data = response.json() + count, result = json_data.values() + click.echo(f"Found {count} {entity_name}") + click.echo("\n---------------------------------\n") + for obj in result: + click.echo("{") + for k, v in obj.items(): + click.echo(f"\t{k}: {v}") + click.echo("}") + elif response.status_code == 500: + json_data = response.json() + click.echo(f'Error: {json_data["message"]}') + else: + click.echo(f"Error: {response.status_code}") diff --git a/fedn/cli/status_cmd.py b/fedn/cli/status_cmd.py new file mode 100644 index 000000000..a4f17e349 --- /dev/null +++ b/fedn/cli/status_cmd.py @@ -0,0 +1,51 @@ +import click +import requests + +from .main import main +from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response + + +@main.group("status") +@click.pass_context +def status_cmd(ctx): + """ + + :param ctx: + """ + pass + + +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@status_cmd.command("list") +@click.pass_context +def list_statuses(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): + """ + Return: + ------ + - count: number of statuses + - result: list of statuses + + """ + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="statuses") + headers = {} + + if n_max: + headers["X-Limit"] = n_max + + _token = get_token(token) + + if _token: + headers["Authorization"] = _token + + click.echo(f"\nListing statuses: {url}\n") + click.echo(f"Headers: {headers}") + + try: + response = requests.get(url, headers=headers) + print_response(response, "statuses") + except requests.exceptions.ConnectionError: + 
click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/cli/validation_cmd.py b/fedn/cli/validation_cmd.py new file mode 100644 index 000000000..055be0c65 --- /dev/null +++ b/fedn/cli/validation_cmd.py @@ -0,0 +1,51 @@ +import click +import requests + +from .main import main +from .shared import CONTROLLER_DEFAULTS, get_api_url, get_token, print_response + + +@main.group("validation") +@click.pass_context +def validation_cmd(ctx): + """ + + :param ctx: + """ + pass + + +@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)") +@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)") +@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)") +@click.option("-t", "--token", required=False, help="Authentication token") +@click.option("--n_max", required=False, help="Number of items to list") +@validation_cmd.command("list") +@click.pass_context +def list_validations(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None): + """ + Return: + ------ + - count: number of validations + - result: list of validations + + """ + url = get_api_url(protocol=protocol, host=host, port=port, endpoint="validations") + headers = {} + + if n_max: + headers["X-Limit"] = n_max + + _token = get_token(token) + + if _token: + headers["Authorization"] = _token + + click.echo(f"\nListing validations: {url}\n") + click.echo(f"Headers: {headers}") + + try: + response = requests.get(url, headers=headers) + print_response(response, "validations") + except requests.exceptions.ConnectionError: + click.echo(f"Error: Could not connect to {url}") diff --git a/fedn/fedn/common/__init__.py b/fedn/common/__init__.py similarity index 100% rename from fedn/fedn/common/__init__.py rename to fedn/common/__init__.py diff --git a/fedn/fedn/common/certificate/__init__.py b/fedn/common/certificate/__init__.py similarity index 100% rename from fedn/fedn/common/certificate/__init__.py rename to fedn/common/certificate/__init__.py diff --git a/fedn/fedn/common/certificate/certificate.py b/fedn/common/certificate/certificate.py similarity index 82% rename from fedn/fedn/common/certificate/certificate.py rename to fedn/common/certificate/certificate.py index a2c059748..857a05e7c 100644 --- a/fedn/fedn/common/certificate/certificate.py +++ b/fedn/common/certificate/certificate.py @@ -13,19 +13,18 @@ class Certificate: Utility to generate unsigned certificates. """ + CERT_NAME = "cert.pem" KEY_NAME = "key.pem" BITS = 2048 def __init__(self, cwd, name=None, key_name="key.pem", cert_name="cert.pem", create_dirs=True): - try: os.makedirs(cwd) except OSError: logger.info("Directory exists, will store all cert and keys here.") else: - logger.info( - "Successfully created the directory to store cert and keys in {}".format(cwd)) + logger.info("Successfully created the directory to store cert and keys in {}".format(cwd)) self.key_path = os.path.join(cwd, key_name) self.cert_path = os.path.join(cwd, cert_name) @@ -35,7 +34,9 @@ def __init__(self, cwd, name=None, key_name="key.pem", cert_name="cert.pem", cre else: self.name = str(uuid.uuid4()) - def gen_keypair(self, ): + def gen_keypair( + self, + ): """ Generate keypair. 
@@ -70,21 +71,19 @@ def set_keypair_raw(self, certificate, privatekey): :param privatekey: """ with open(self.key_path, "wb") as keyfile: - keyfile.write(crypto.dump_privatekey( - crypto.FILETYPE_PEM, privatekey)) + keyfile.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, privatekey)) with open(self.cert_path, "wb") as certfile: - certfile.write(crypto.dump_certificate( - crypto.FILETYPE_PEM, certificate)) + certfile.write(crypto.dump_certificate(crypto.FILETYPE_PEM, certificate)) def get_keypair_raw(self): """ :return: """ - with open(self.key_path, 'rb') as keyfile: + with open(self.key_path, "rb") as keyfile: key_buf = keyfile.read() - with open(self.cert_path, 'rb') as certfile: + with open(self.cert_path, "rb") as certfile: cert_buf = certfile.read() return copy.deepcopy(cert_buf), copy.deepcopy(key_buf) @@ -93,7 +92,7 @@ def get_key(self): :return: """ - with open(self.key_path, 'rb') as keyfile: + with open(self.key_path, "rb") as keyfile: key_buf = keyfile.read() key = crypto.load_privatekey(crypto.FILETYPE_PEM, key_buf) return key @@ -103,7 +102,7 @@ def get_cert(self): :return: """ - with open(self.cert_path, 'rb') as certfile: + with open(self.cert_path, "rb") as certfile: cert_buf = certfile.read() cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_buf) return cert diff --git a/fedn/fedn/common/certificate/certificatemanager.py b/fedn/common/certificate/certificatemanager.py similarity index 80% rename from fedn/fedn/common/certificate/certificatemanager.py rename to fedn/common/certificate/certificatemanager.py index 3d34fa1ad..ce165d862 100644 --- a/fedn/fedn/common/certificate/certificatemanager.py +++ b/fedn/common/certificate/certificatemanager.py @@ -10,7 +10,6 @@ class CertificateManager: """ def __init__(self, directory): - self.directory = directory self.certificates = [] self.allowed = dict() @@ -28,8 +27,7 @@ def get_or_create(self, name): if search: return search else: - cert = Certificate(self.directory, name=name, - cert_name=name + '-cert.pem', key_name=name + '-key.pem') + cert = Certificate(self.directory, name=name, cert_name=name + "-cert.pem", key_name=name + "-key.pem") cert.gen_keypair() self.certificates.append(cert) return cert @@ -53,12 +51,11 @@ def load_all(self): """ for filename in sorted(os.listdir(self.directory)): - if filename.endswith('cert.pem'): - name = filename.split('-')[0] - key_name = name + '-key.pem' + if filename.endswith("cert.pem"): + name = filename.split("-")[0] + key_name = name + "-key.pem" - c = Certificate(self.directory, name=name, - cert_name=filename, key_name=key_name) + c = Certificate(self.directory, name=name, cert_name=filename, key_name=key_name) self.certificates.append(c) def find(self, name): diff --git a/fedn/fedn/common/config.py b/fedn/common/config.py similarity index 57% rename from fedn/fedn/common/config.py rename to fedn/common/config.py index b7edec319..4864ce1ef 100644 --- a/fedn/fedn/common/config.py +++ b/fedn/common/config.py @@ -5,34 +5,31 @@ global STATESTORE_CONFIG global MODELSTORAGE_CONFIG -SECRET_KEY = os.environ.get('FEDN_JWT_SECRET_KEY', False) -FEDN_JWT_CUSTOM_CLAIM_KEY = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_KEY', False) -FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_VALUE', False) +SECRET_KEY = os.environ.get("FEDN_JWT_SECRET_KEY", False) +FEDN_JWT_CUSTOM_CLAIM_KEY = os.environ.get("FEDN_JWT_CUSTOM_CLAIM_KEY", False) +FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get("FEDN_JWT_CUSTOM_CLAIM_VALUE", False) -FEDN_AUTH_WHITELIST_URL_PREFIX = 
os.environ.get('FEDN_AUTH_WHITELIST_URL_PREFIX', False) -FEDN_JWT_ALGORITHM = os.environ.get('FEDN_JWT_ALGORITHM', 'HS256') -FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') -FEDN_AUTH_REFRESH_TOKEN_URI = os.environ.get('FEDN_AUTH_REFRESH_TOKEN_URI', False) -FEDN_AUTH_REFRESH_TOKEN = os.environ.get('FEDN_AUTH_REFRESH_TOKEN', False) -FEDN_CUSTOM_URL_PREFIX = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') +FEDN_AUTH_WHITELIST_URL_PREFIX = os.environ.get("FEDN_AUTH_WHITELIST_URL_PREFIX", False) +FEDN_JWT_ALGORITHM = os.environ.get("FEDN_JWT_ALGORITHM", "HS256") +FEDN_AUTH_SCHEME = os.environ.get("FEDN_AUTH_SCHEME", "Bearer") +FEDN_AUTH_REFRESH_TOKEN_URI = os.environ.get("FEDN_AUTH_REFRESH_TOKEN_URI", False) +FEDN_AUTH_REFRESH_TOKEN = os.environ.get("FEDN_AUTH_REFRESH_TOKEN", False) +FEDN_CUSTOM_URL_PREFIX = os.environ.get("FEDN_CUSTOM_URL_PREFIX", "") -FEDN_PACKAGE_EXTRACT_DIR = os.environ.get('FEDN_PACKAGE_EXTRACT_DIR', '') +FEDN_PACKAGE_EXTRACT_DIR = os.environ.get("FEDN_PACKAGE_EXTRACT_DIR", "package") def get_environment_config(): - """ Get the configuration from environment variables. - """ + """Get the configuration from environment variables.""" global STATESTORE_CONFIG global MODELSTORAGE_CONFIG - STATESTORE_CONFIG = os.environ.get('STATESTORE_CONFIG', - '/workspaces/fedn/config/settings-reducer.yaml.template') - MODELSTORAGE_CONFIG = os.environ.get('MODELSTORAGE_CONFIG', - '/workspaces/fedn/config/settings-reducer.yaml.template') + STATESTORE_CONFIG = os.environ.get("STATESTORE_CONFIG", "/workspaces/fedn/config/settings-reducer.yaml.template") + MODELSTORAGE_CONFIG = os.environ.get("MODELSTORAGE_CONFIG", "/workspaces/fedn/config/settings-reducer.yaml.template") def get_statestore_config(file=None): - """ Get the statestore configuration from file. + """Get the statestore configuration from file. :param file: The statestore configuration file (yaml) path (optional). :type file: str @@ -42,7 +39,7 @@ def get_statestore_config(file=None): if file is None: get_environment_config() file = STATESTORE_CONFIG - with open(file, 'r') as config_file: + with open(file, "r") as config_file: try: settings = dict(yaml.safe_load(config_file)) except yaml.YAMLError as e: @@ -51,7 +48,7 @@ def get_statestore_config(file=None): def get_modelstorage_config(file=None): - """ Get the model storage configuration from file. + """Get the model storage configuration from file. :param file: The model storage configuration file (yaml) path (optional). :type file: str @@ -61,7 +58,7 @@ def get_modelstorage_config(file=None): if file is None: get_environment_config() file = MODELSTORAGE_CONFIG - with open(file, 'r') as config_file: + with open(file, "r") as config_file: try: settings = dict(yaml.safe_load(config_file)) except yaml.YAMLError as e: @@ -70,7 +67,7 @@ def get_modelstorage_config(file=None): def get_network_config(file=None): - """ Get the network configuration from file. + """Get the network configuration from file. :param file: The network configuration file (yaml) path (optional). :type file: str @@ -80,7 +77,7 @@ def get_network_config(file=None): if file is None: get_environment_config() file = STATESTORE_CONFIG - with open(file, 'r') as config_file: + with open(file, "r") as config_file: try: settings = dict(yaml.safe_load(config_file)) except yaml.YAMLError as e: @@ -89,7 +86,7 @@ def get_network_config(file=None): def get_controller_config(file=None): - """ Get the controller configuration from file. + """Get the controller configuration from file. 
:param file: The controller configuration file (yaml) path (optional). :type file: str @@ -99,7 +96,7 @@ def get_controller_config(file=None): if file is None: get_environment_config() file = STATESTORE_CONFIG - with open(file, 'r') as config_file: + with open(file, "r") as config_file: try: settings = dict(yaml.safe_load(config_file)) except yaml.YAMLError as e: diff --git a/fedn/fedn/common/exceptions.py b/fedn/common/exceptions.py similarity index 63% rename from fedn/fedn/common/exceptions.py rename to fedn/common/exceptions.py index 8d970e4d4..db8f36f86 100644 --- a/fedn/fedn/common/exceptions.py +++ b/fedn/common/exceptions.py @@ -4,3 +4,7 @@ class ModelError(BaseException): class InvalidClientConfig(BaseException): pass + + +class InvalidParameterError(BaseException): + pass diff --git a/fedn/fedn/common/log_config.py b/fedn/common/log_config.py similarity index 56% rename from fedn/fedn/common/log_config.py rename to fedn/common/log_config.py index 0e61a6a83..b8aa1218b 100644 --- a/fedn/fedn/common/log_config.py +++ b/fedn/common/log_config.py @@ -5,13 +5,7 @@ import requests import urllib3 -log_levels = { - "DEBUG": logging.DEBUG, - "INFO": logging.INFO, - "WARNING": logging.WARNING, - "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL -} +log_levels = {"DEBUG": logging.DEBUG, "INFO": logging.INFO, "WARNING": logging.WARNING, "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL} urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @@ -21,12 +15,12 @@ logger = logging.getLogger("fedn") logger.addHandler(handler) logger.setLevel(logging.DEBUG) -formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') +formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") handler.setFormatter(formatter) class StudioHTTPHandler(logging.handlers.HTTPHandler): - def __init__(self, host, url, method='POST', token=None): + def __init__(self, host, url, method="POST", token=None): super().__init__(host, url, method) self.token = token @@ -34,41 +28,35 @@ def emit(self, record): log_entry = self.mapLogRecord(record) log_entry = { - "msg": log_entry['msg'], - "levelname": log_entry['levelname'], + "msg": log_entry["msg"], + "levelname": log_entry["levelname"], "project": os.environ.get("PROJECT_ID"), - "appinstance": os.environ.get("APP_ID") - + "appinstance": os.environ.get("APP_ID"), } # Setup headers headers = { - 'Content-type': 'application/json', + "Content-type": "application/json", } if self.token: - remote_token_protocol = os.environ.get('FEDN_REMOTE_LOG_TOKEN_PROTOCOL', "Token") - headers['Authorization'] = f"{remote_token_protocol} {self.token}" - if self.method.lower() == 'post': - requests.post(self.host+self.url, json=log_entry, headers=headers) + remote_token_protocol = os.environ.get("FEDN_REMOTE_LOG_TOKEN_PROTOCOL", "Token") + headers["Authorization"] = f"{remote_token_protocol} {self.token}" + if self.method.lower() == "post": + requests.post(self.host + self.url, json=log_entry, headers=headers) else: # No other methods implemented. return # Remote logging can only be configured via environment variables for now. 
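As the comment above notes, remote log forwarding is configured entirely through environment variables. A minimal sketch of enabling it, with placeholder endpoint values; the variables must be set before `fedn.common.log_config` is imported, since the module reads them once at import time:

```python
import os

# Placeholder values -- substitute your own log collector endpoint.
os.environ["FEDN_REMOTE_LOG_SERVER"] = "http://logs.example.com"
os.environ["FEDN_REMOTE_LOG_PATH"] = "/api/logs"
os.environ["FEDN_REMOTE_LOG_LEVEL"] = "WARNING"
os.environ["FEDN_REMOTE_LOG_TOKEN"] = "secret-token"

from fedn.common.log_config import logger

# Records at WARNING and above are now also POSTed to the remote endpoint.
logger.warning("combiner unreachable, retrying")
```
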
-REMOTE_LOG_SERVER = os.environ.get('FEDN_REMOTE_LOG_SERVER', False) -REMOTE_LOG_PATH = os.environ.get('FEDN_REMOTE_LOG_PATH', False) -REMOTE_LOG_LEVEL = os.environ.get('FEDN_REMOTE_LOG_LEVEL', 'INFO') +REMOTE_LOG_SERVER = os.environ.get("FEDN_REMOTE_LOG_SERVER", False) +REMOTE_LOG_PATH = os.environ.get("FEDN_REMOTE_LOG_PATH", False) +REMOTE_LOG_LEVEL = os.environ.get("FEDN_REMOTE_LOG_LEVEL", "INFO") if REMOTE_LOG_SERVER: rloglevel = log_levels.get(REMOTE_LOG_LEVEL, logging.INFO) - remote_token = os.environ.get('FEDN_REMOTE_LOG_TOKEN', None) - - http_handler = StudioHTTPHandler( - host=REMOTE_LOG_SERVER, - url=REMOTE_LOG_PATH, - method='POST', - token=remote_token - ) + remote_token = os.environ.get("FEDN_REMOTE_LOG_TOKEN", None) + + http_handler = StudioHTTPHandler(host=REMOTE_LOG_SERVER, url=REMOTE_LOG_PATH, method="POST", token=remote_token) http_handler.setLevel(rloglevel) logger.addHandler(http_handler) @@ -79,11 +67,11 @@ def set_log_level_from_string(level_str): """ # Mapping of string representation to logging constants level_mapping = { - 'CRITICAL': logging.CRITICAL, - 'ERROR': logging.ERROR, - 'WARNING': logging.WARNING, - 'INFO': logging.INFO, - 'DEBUG': logging.DEBUG, + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, } # Get the logging level from the mapping diff --git a/fedn/fedn/common/net/__init__.py b/fedn/common/net/__init__.py similarity index 100% rename from fedn/fedn/common/net/__init__.py rename to fedn/common/net/__init__.py diff --git a/fedn/fedn/common/telemetry.py b/fedn/common/telemetry.py similarity index 100% rename from fedn/fedn/common/telemetry.py rename to fedn/common/telemetry.py diff --git a/fedn/fedn/__init__.py b/fedn/fedn/__init__.py deleted file mode 100644 index 7e3df239e..000000000 --- a/fedn/fedn/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -import glob -import os -from os.path import basename, dirname, isfile - -from fedn.network.api.client import APIClient - -# flake8: noqa - - -modules = glob.glob(dirname(__file__) + "/*.py") -__all__ = [basename(f)[:-3] for f in modules if isfile(f) - and not f.endswith('__init__.py')] - - -_ROOT = os.path.abspath(os.path.dirname(__file__)) - - -def get_data(path): - """ - - :param path: - :return: - """ - return os.path.join(_ROOT, 'data', path) diff --git a/fedn/fedn/network/combiner/aggregators/aggregator.py b/fedn/fedn/network/combiner/aggregators/aggregator.py deleted file mode 100644 index 32458b46e..000000000 --- a/fedn/fedn/network/combiner/aggregators/aggregator.py +++ /dev/null @@ -1,121 +0,0 @@ -import json -import queue -from abc import ABC, abstractmethod - -import fedn.common.net.grpc.fedn_pb2 as fedn -from fedn.common.telemetry import trace_all_methods - - -@trace_all_methods -class Aggregator(ABC): - """ Abstract class defining an aggregator. """ - - @abstractmethod - def __init__(self, id, storage, server, modelservice, control): - """ Initialize the aggregator. 
- - :param id: A reference to id of :class: `fedn.network.combiner.Combiner` - :type id: str - :param storage: Model repository for :class: `fedn.network.combiner.Combiner` - :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository` - :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner` - :type server: class: `fedn.network.combiner.Combiner` - :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService` - :type modelservice: class: `fedn.network.combiner.modelservice.ModelService` - :param control: A handle to the :class: `fedn.network.combiner.round.RoundController` - :type control: class: `fedn.network.combiner.round.RoundController` - """ - self.name = self.__class__.__name__ - self.storage = storage - self.id = id - self.server = server - self.modelservice = modelservice - self.control = control - self.model_updates = queue.Queue() - - @abstractmethod - def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=None, timeout=180): - """Routine for combining model updates. Implemented in subclass. - - :param nr_expected_models: Number of expected models. If None, wait for all models. - :type nr_expected_models: int - :param nr_required_models: Number of required models to combine. - :type nr_required_models: int - :param helper: A helper object. - :type helper: :class: `fedn.utils.plugins.helperbase.HelperBase` - :param timeout: Timeout in seconds to wait for models to be combined. - :type timeout: int - :return: A combined model. - """ - pass - - def on_model_update(self, model_update): - """Callback when a new client model update is recieved. - Performs (optional) pre-processing and then puts the update id - on the aggregation queue. Override in subclass as needed. - - :param model_update: A ModelUpdate message. - :type model_id: str - """ - try: - self.server.report_status("AGGREGATOR({}): callback received model update {}".format(self.name, model_update.model_update_id), - log_level=fedn.Status.INFO) - - # Validate the update and metadata - valid_update = self._validate_model_update(model_update) - if valid_update: - # Push the model update to the processing queue - self.model_updates.put(model_update) - else: - self.server.report_status("AGGREGATOR({}): Invalid model update, skipping.".format(self.name)) - except Exception as e: - self.server.report_status("AGGREGATOR({}): Failed to receive candidate model! {}".format(self.name, e), - log_level=fedn.Status.WARNING) - pass - - def on_model_validation(self, model_validation): - """ Callback when a new client model validation is recieved. - Performs (optional) pre-processing and then writes the validation - to the database. Override in subclass as needed. - - :param validation: Dict containing validation data sent by client. - Must be valid JSON. - :type validation: dict - """ - - # self.report_validation(validation) - self.server.report_status("AGGREGATOR({}): callback processed validation {}".format(self.name, model_validation.model_id), - log_level=fedn.Status.INFO) - - def _validate_model_update(self, model_update): - """ Validate the model update. - - :param model_update: A ModelUpdate message. - :type model_update: object - :return: True if the model update is valid, False otherwise. - :rtype: bool - """ - # TODO: Validate the metadata to check that it contains all variables assumed by the aggregator. 
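The deleted base class above follows a plain producer/consumer pattern: the callback validates each incoming ModelUpdate and pushes it onto a queue that the aggregation routine later drains. A stripped-down illustration of that pattern, with a plain dict standing in for the protobuf message:

```python
import json
import queue

model_updates = queue.Queue()

def on_model_update(model_update: dict) -> None:
    # Validate that the training metadata carries num_examples, then enqueue;
    # this mirrors the validate-then-queue flow of the deleted callback.
    meta = json.loads(model_update["meta"])["training_metadata"]
    if "num_examples" not in meta:
        print("Invalid model update, skipping.")
        return
    model_updates.put(model_update)

on_model_update({"meta": json.dumps({"training_metadata": {"num_examples": 128}})})
assert model_updates.qsize() == 1
```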
- data = json.loads(model_update.meta)['training_metadata'] - if 'num_examples' not in data.keys(): - self.server.report_status("AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name)) - return False - return True - - def next_model_update(self, helper): - """ Get the next model update from the queue. - - :param helper: A helper object. - :type helper: object - :return: A tuple containing the model update, metadata and model id. - :rtype: tuple - """ - model_update = self.model_updates.get(block=False) - model_id = model_update.model_update_id - model_next = self.control.load_model_update(helper, model_id) - # Get relevant metadata - data = json.loads(model_update.meta)['training_metadata'] - config = json.loads(json.loads(model_update.meta)['config']) - data['round_id'] = config['round_id'] - - return model_next, data, model_id diff --git a/fedn/fedn/network/combiner/aggregators/fedopt.py b/fedn/fedn/network/combiner/aggregators/fedopt.py deleted file mode 100644 index 737fd916b..000000000 --- a/fedn/fedn/network/combiner/aggregators/fedopt.py +++ /dev/null @@ -1,139 +0,0 @@ -import math - -from fedn.common.log_config import logger -from fedn.common.telemetry import trace_all_methods -from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase - - -@trace_all_methods -class Aggregator(AggregatorBase): - """ Federated Optimization (FedOpt) aggregator. - - Implmentation following: https://arxiv.org/pdf/2003.00295.pdf - - Aggregate pseudo gradients computed by subtracting the model - update from the global model weights from the previous round. - - :param id: A reference to id of :class: `fedn.network.combiner.Combiner` - :type id: str - :param storage: Model repository for :class: `fedn.network.combiner.Combiner` - :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository` - :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner` - :type server: class: `fedn.network.combiner.Combiner` - :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService` - :type modelservice: class: `fedn.network.combiner.modelservice.ModelService` - :param control: A handle to the :class: `fedn.network.combiner.roundhandler.RoundHandler` - :type control: class: `fedn.network.combiner.roundhandler.RoundHandler` - - """ - - def __init__(self, storage, server, modelservice, round_handler): - - super().__init__(storage, server, modelservice, round_handler) - - self.name = "fedopt" - self.v = None - self.m = None - - # Server side hyperparameters. Note that these may need extensive fine tuning. - self.eta = 1e-2 - self.beta1 = 0.9 - self.beta2 = 0.99 - self.tau = 1e-4 - - def combine_models(self, helper=None, delete_models=True): - """Compute pseudo gradients using model updates in the queue. 
- - :param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase`, ML framework specific helper, defaults to None - :type helper: class: `fedn.utils.helpers.helpers.HelperBase`, optional - :param time_window: The time window for model aggregation, defaults to 180 - :type time_window: int, optional - :param max_nr_models: The maximum number of updates aggregated, defaults to 100 - :type max_nr_models: int, optional - :param delete_models: Delete models from storage after aggregation, defaults to True - :type delete_models: bool, optional - :return: The global model and metadata - :rtype: tuple - """ - - data = {} - data['time_model_load'] = 0.0 - data['time_model_aggregation'] = 0.0 - - model = None - nr_aggregated_models = 0 - total_examples = 0 - - logger.info( - "AGGREGATOR({}): Aggregating model updates... ".format(self.name)) - - while not self.model_updates.empty(): - try: - # Get next model from queue - model_update = self.next_model_update() - - # Load model paratmeters and metadata - model_next, metadata = self.load_model_update(model_update, helper) - - logger.info( - "AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata)) - logger.info("***** {}".format(model_update)) - - # Increment total number of examples - total_examples += metadata['num_examples'] - - if nr_aggregated_models == 0: - model_old = self.round_handler.load_model_update(helper, model_update.model_id) - pseudo_gradient = helper.subtract(model_next, model_old) - else: - pseudo_gradient_next = helper.subtract(model_next, model_old) - pseudo_gradient = helper.increment_average( - pseudo_gradient, pseudo_gradient_next, metadata['num_examples'], total_examples) - - logger.info("NORM PSEUDOGRADIENT: {}".format(helper.norm(pseudo_gradient))) - - nr_aggregated_models += 1 - # Delete model from storage - if delete_models: - self.modelservice.temp_model_storage.delete(model_update.model_update_id) - logger.info( - "AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id)) - self.model_updates.task_done() - except Exception as e: - logger.error( - "AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e)) - self.model_updates.task_done() - - model = self.serveropt_adam(helper, pseudo_gradient, model_old) - - data['nr_aggregated_models'] = nr_aggregated_models - - logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models)) - return model, data - - def serveropt_adam(self, helper, pseudo_gradient, model_old): - """ Server side optimization, FedAdam. - - :param helper: instance of helper class. - :type helper: Helper - :param pseudo_gradient: The pseudo gradient. - :type pseudo_gradient: As defined by helper. - :return: new model weights. - :rtype: as defined by helper. 
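The body that follows implements the FedAdam server update from Reddi et al. (arXiv:2003.00295) via the helper's element-wise operations. The same step on flat NumPy vectors, as a reference sketch rather than FEDn's helper API:

```python
import numpy as np

def fedadam_step(model_old, pseudo_gradient, m, v, eta=1e-2, beta1=0.9, beta2=0.99, tau=1e-4):
    """One FedAdam server step on flat parameter vectors.

    To match the initialization in the code below, call with m = 0 and
    v = tau**2 on the first round.
    """
    m = beta1 * m + (1.0 - beta1) * pseudo_gradient
    v = beta2 * v + (1.0 - beta2) * np.square(pseudo_gradient)
    model_new = model_old + eta * m / (np.sqrt(v) + tau)
    return model_new, m, v
```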
- """ - - if not self.v: - self.v = helper.ones(pseudo_gradient, math.pow(self.tau, 2)) - - if not self.m: - self.m = helper.multiply(pseudo_gradient, [(1.0-self.beta1)]*len(pseudo_gradient)) - else: - self.m = helper.add(self.m, pseudo_gradient, self.beta1, (1.0-self.beta1)) - - p = helper.power(pseudo_gradient, 2) - self.v = helper.add(self.v, p, self.beta2, (1.0-self.beta2)) - sv = helper.add(helper.sqrt(self.v), helper.ones(self.v, self.tau)) - t = helper.divide(self.m, sv) - - model = helper.add(model_old, t, 1.0, self.eta) - return model diff --git a/fedn/fedn/utils/plots.py b/fedn/fedn/utils/plots.py deleted file mode 100644 index 605404e4e..000000000 --- a/fedn/fedn/utils/plots.py +++ /dev/null @@ -1,478 +0,0 @@ -import json -from datetime import datetime - -import numpy -import plotly -import plotly.graph_objs as go -from plotly.subplots import make_subplots - -from fedn.common.log_config import logger -from fedn.common.telemetry import trace_all_methods -from fedn.network.storage.statestore.mongostatestore import MongoStateStore - - -@trace_all_methods -class Plot: - """ - - """ - - def __init__(self, statestore): - try: - statestore_config = statestore.get_config() - statestore = MongoStateStore( - statestore_config['network_id'], statestore_config['mongo_config']) - self.mdb = statestore.connect() - self.status = self.mdb['control.status'] - self.round_time = self.mdb["control.round_time"] - self.combiner_round_time = self.mdb["control.combiner_round_time"] - self.psutil_usage = self.mdb["control.psutil_monitoring"] - self.network_clients = self.mdb["network.clients"] - - except Exception as e: - logger.error("FAILED TO CONNECT TO MONGO, {}".format(e)) - self.collection = None - raise - - # plot metrics from DB - def _scalar_metrics(self, metrics): - """ Extract all scalar valued metrics from a MODEL_VALIDATON. 
""" - - data = json.loads(metrics['data']) - data = json.loads(data['data']) - - valid_metrics = [] - for metric, val in data.items(): - # If it can be converted to a float it is a valid, scalar metric - try: - val = float(val) - valid_metrics.append(metric) - except Exception: - pass - - return valid_metrics - - def create_table_plot(self): - """ - - :return: - """ - metrics = self.status.find_one({'type': 'MODEL_VALIDATION'}) - if metrics is None: - fig = go.Figure(data=[]) - fig.update_layout( - title_text='No data currently available for table mean metrics') - table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return False - - valid_metrics = self._scalar_metrics(metrics) - if valid_metrics == []: - fig = go.Figure(data=[]) - fig.update_layout(title_text='No scalar metrics found') - table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return False - - all_vals = [] - models = [] - for metric in valid_metrics: - validations = {} - for post in self.status.find({'type': 'MODEL_VALIDATION'}): - e = json.loads(post['data']) - try: - validations[e['modelId']].append( - float(json.loads(e['data'])[metric])) - except KeyError: - validations[e['modelId']] = [ - float(json.loads(e['data'])[metric])] - - vals = [] - models = [] - for model, data in validations.items(): - vals.append(numpy.mean(data)) - models.append(model) - all_vals.append(vals) - - header_vals = valid_metrics - models.reverse() - values = [models] - - for vals in all_vals: - vals.reverse() - values.append(vals) - - fig = go.Figure(data=[go.Table( - header=dict(values=['Model ID'] + header_vals, - line_color='darkslategray', - fill_color='lightskyblue', - align='left'), - - cells=dict(values=values, # 2nd column - line_color='darkslategray', - fill_color='lightcyan', - align='left')) - ]) - - fig.update_layout(title_text='Summary: mean metrics') - table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return table - - def create_timeline_plot(self): - """ - - :return: - """ - trace_data = [] - x = [] - y = [] - base = [] - for p in self.status.find({'type': 'MODEL_UPDATE_REQUEST'}): - e = json.loads(p['data']) - cid = e['correlationId'] - for cc in self.status.find({'sender': p['sender'], 'type': 'MODEL_UPDATE'}): - da = json.loads(cc['data']) - if da['correlationId'] == cid: - cp = cc - - cd = json.loads(cp['data']) - tr = datetime.strptime(e['timestamp'], '%Y-%m-%d %H:%M:%S.%f') - tu = datetime.strptime(cd['timestamp'], '%Y-%m-%d %H:%M:%S.%f') - ts = tu - tr - base.append(tr.timestamp()) - x.append(ts.total_seconds() / 60.0) - y.append(p['sender']['name']) - - trace_data.append(go.Bar( - x=y, - y=x, - marker=dict(color='royalblue'), - name="Training", - )) - - x = [] - y = [] - base = [] - for p in self.status.find({'type': 'MODEL_VALIDATION_REQUEST'}): - e = json.loads(p['data']) - cid = e['correlationId'] - for cc in self.status.find({'sender': p['sender'], 'type': 'MODEL_VALIDATION'}): - da = json.loads(cc['data']) - if da['correlationId'] == cid: - cp = cc - cd = json.loads(cp['data']) - tr = datetime.strptime(e['timestamp'], '%Y-%m-%d %H:%M:%S.%f') - tu = datetime.strptime(cd['timestamp'], '%Y-%m-%d %H:%M:%S.%f') - ts = tu - tr - base.append(tr.timestamp()) - x.append(ts.total_seconds() / 60.0) - y.append(p['sender']['name']) - - trace_data.append(go.Bar( - x=y, - y=x, - marker=dict(color='lightskyblue'), - name="Validation", - )) - - layout = go.Layout( - barmode='stack', - showlegend=True, - ) - - fig = go.Figure(data=trace_data, layout=layout) - fig.update_xaxes(title_text='Alliance/client') 
- fig.update_yaxes(title_text='Time (Min)') - fig.update_layout(title_text='Alliance timeline') - timeline = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return timeline - - def create_client_training_distribution(self): - """ - - :return: - """ - training = [] - for p in self.status.find({'type': 'MODEL_UPDATE'}): - e = json.loads(p['data']) - meta = json.loads(e['meta']) - training.append(meta['exec_training']) - - if not training: - return False - fig = go.Figure(data=go.Histogram(x=training)) - fig.update_layout( - title_text='Client model training time, mean: {}'.format(numpy.mean(training))) - histogram = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return histogram - - def create_client_histogram_plot(self): - """ - - :return: - """ - training = [] - for p in self.status.find({'type': 'MODEL_UPDATE'}): - e = json.loads(p['data']) - meta = json.loads(e['meta']) - training.append(meta['exec_training']) - - fig = go.Figure() - - fig.update_layout( - template="simple_white", - xaxis=dict(title_text="Time (s)"), - yaxis=dict(title_text='Number of updates'), - title="Mean client training time: {}".format( - numpy.mean(training)), - # showlegend=True - ) - if not training: - return False - - fig.add_trace(go.Histogram(x=training)) - - histogram_plot = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return histogram_plot - - def create_client_plot(self): - """ - - :return: - """ - processing = [] - upload = [] - download = [] - training = [] - for p in self.status.find({'type': 'MODEL_UPDATE'}): - e = json.loads(p['data']) - meta = json.loads(e['meta']) - upload.append(meta['upload_model']) - download.append(meta['fetch_model']) - training.append(meta['exec_training']) - processing.append(meta['processing_time']) - - fig = go.Figure() - fig.update_layout( - template="simple_white", - title="Mean client processing time: {}".format( - numpy.mean(processing)), - showlegend=True - ) - if not processing: - return False - data = [numpy.mean(training), numpy.mean(upload), numpy.mean(download)] - labels = ["Training execution", "Model upload (to combiner)", "Model download (from combiner)"] - fig.add_trace(go.Pie(labels=labels, values=data)) - - client_plot = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return client_plot - - def create_combiner_plot(self): - """ - - :return: - """ - waiting = [] - aggregation = [] - model_load = [] - combination = [] - for round in self.mdb['control.round'].find(): - try: - for combiner in round['combiners']: - data = combiner - stats = data['local_round']['1'] - ml = stats['aggregation_time']['time_model_load'] - ag = stats['aggregation_time']['time_model_aggregation'] - combination.append(stats['time_combination']) - waiting.append(stats['time_combination'] - ml - ag) - model_load.append(ml) - aggregation.append(ag) - except Exception: - pass - - labels = ['Waiting for client updates', - 'Aggregation', 'Loading model updates from disk'] - val = [numpy.mean(waiting), numpy.mean( - aggregation), numpy.mean(model_load)] - fig = go.Figure() - - fig.update_layout( - template="simple_white", - title="Mean combiner round time: {}".format( - numpy.mean(combination)), - showlegend=True - ) - if not combination: - return False - fig.add_trace(go.Pie(labels=labels, values=val)) - combiner_plot = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return combiner_plot - - def fetch_valid_metrics(self): - """ - - :return: - """ - metrics = self.status.find_one({'type': 'MODEL_VALIDATION'}) - valid_metrics = self._scalar_metrics(metrics) - 
return valid_metrics - - def create_box_plot(self, metric): - """ - - :param metric: - :return: - """ - metrics = self.status.find_one({'type': 'MODEL_VALIDATION'}) - if metrics is None: - fig = go.Figure(data=[]) - fig.update_layout(title_text='No data currently available for metric distribution over ' - 'participants') - box = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return box - - valid_metrics = self._scalar_metrics(metrics) - if valid_metrics == []: - fig = go.Figure(data=[]) - fig.update_layout(title_text='No scalar metrics found') - box = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return box - - validations = {} - for post in self.status.find({'type': 'MODEL_VALIDATION'}): - e = json.loads(post['data']) - try: - validations[e['modelId']].append( - float(json.loads(e['data'])[metric])) - except KeyError: - validations[e['modelId']] = [ - float(json.loads(e['data'])[metric])] - - # Make sure validations are plotted in chronological order - model_trail = self.mdb.control.model.find_one({'key': 'model_trail'}) - model_trail_ids = model_trail['model'] - validations_sorted = [] - for model_id in model_trail_ids: - try: - validations_sorted.append(validations[model_id]) - except Exception: - pass - - validations = validations_sorted - - box = go.Figure() - - y = [] - for j, acc in enumerate(validations): - # x.append(j) - y.append(numpy.mean([float(i) for i in acc])) - if len(acc) >= 2: - box.add_trace(go.Box(y=acc, name=str(j), marker_color="royalblue", showlegend=False, - boxpoints=False)) - else: - box.add_trace(go.Scatter( - x=[str(j)], y=[y[j]], showlegend=False)) - - rounds = list(range(len(y))) - box.add_trace(go.Scatter( - x=rounds, - y=y, - name='Mean' - )) - - box.update_xaxes(title_text='Rounds') - box.update_yaxes( - tickvals=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) - box.update_layout(title_text='Metric distribution over clients: {}'.format(metric), - margin=dict(l=20, r=20, t=45, b=20)) - box = json.dumps(box, cls=plotly.utils.PlotlyJSONEncoder) - return box - - def create_round_plot(self): - """ - - :return: - """ - trace_data = [] - metrics = self.round_time.find_one({'key': 'round_time'}) - if metrics is None: - fig = go.Figure(data=[]) - fig.update_layout( - title_text='No data currently available for round time') - return False - - for post in self.round_time.find({'key': 'round_time'}): - rounds = post['round'] - traces_data = post['round_time'] - - trace_data.append(go.Scatter( - x=rounds, - y=traces_data, - mode='lines+markers', - name='Reducer' - )) - - for rec in self.combiner_round_time.find({'key': 'combiner_round_time'}): - c_traces_data = rec['round_time'] - - trace_data.append(go.Scatter( - x=rounds, - y=c_traces_data, - mode='lines+markers', - name='Combiner' - )) - - fig = go.Figure(data=trace_data) - fig.update_xaxes(title_text='Round') - fig.update_yaxes(title_text='Time (s)') - fig.update_layout(title_text='Round time') - round_t = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return round_t - - def create_cpu_plot(self): - """ - - :return: - """ - metrics = self.psutil_usage.find_one({'key': 'cpu_mem_usage'}) - if metrics is None: - fig = go.Figure(data=[]) - fig.update_layout( - title_text='No data currently available for MEM and CPU usage') - cpu = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return False - - for post in self.psutil_usage.find({'key': 'cpu_mem_usage'}): - cpu = post['cpu'] - mem = post['mem'] - ps_time = post['time'] - round = post['round'] - - # Create figure with secondary 
y-axis - fig = make_subplots(specs=[[{"secondary_y": True}]]) - fig.add_trace(go.Scatter( - x=ps_time, - y=cpu, - mode='lines+markers', - name='CPU (%)' - )) - - fig.add_trace(go.Scatter( - x=ps_time, - y=mem, - mode='lines+markers', - name='MEM (%)' - )) - - fig.add_trace(go.Scatter( - x=ps_time, - y=round, - mode='lines+markers', - name='Round', - ), secondary_y=True) - - fig.update_xaxes(title_text='Date Time') - fig.update_yaxes(title_text='Percentage (%)') - fig.update_yaxes(title_text="Round", secondary_y=True) - fig.update_layout(title_text='CPU loads and memory usage') - cpu = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) - return cpu diff --git a/fedn/fedn/network/__init__.py b/fedn/network/__init__.py similarity index 100% rename from fedn/fedn/network/__init__.py rename to fedn/network/__init__.py diff --git a/fedn/fedn/network/api/__init__.py b/fedn/network/api/__init__.py similarity index 100% rename from fedn/fedn/network/api/__init__.py rename to fedn/network/api/__init__.py diff --git a/fedn/fedn/network/api/auth.py b/fedn/network/api/auth.py similarity index 64% rename from fedn/fedn/network/api/auth.py rename to fedn/network/api/auth.py index 4ba958356..4780614d7 100644 --- a/fedn/fedn/network/api/auth.py +++ b/fedn/network/api/auth.py @@ -3,18 +3,22 @@ import jwt from flask import jsonify, request -from fedn.common.config import (FEDN_AUTH_SCHEME, - FEDN_AUTH_WHITELIST_URL_PREFIX, - FEDN_JWT_ALGORITHM, FEDN_JWT_CUSTOM_CLAIM_KEY, - FEDN_JWT_CUSTOM_CLAIM_VALUE, SECRET_KEY) from fedn.common.telemetry import tracer +from fedn.common.config import ( + FEDN_AUTH_SCHEME, + FEDN_AUTH_WHITELIST_URL_PREFIX, + FEDN_JWT_ALGORITHM, + FEDN_JWT_CUSTOM_CLAIM_KEY, + FEDN_JWT_CUSTOM_CLAIM_VALUE, + SECRET_KEY, +) @tracer.start_as_current_span(name="check_role_claims") def check_role_claims(payload, role): - if 'role' not in payload: + if "role" not in payload: return False - if payload['role'] != role: + if payload["role"] != role: return False return True @@ -46,30 +50,29 @@ def actual_decorator(func): def decorated(*args, **kwargs): if if_whitelisted_url_prefix(request.path): return func(*args, **kwargs) - token = request.headers.get('Authorization') + token = request.headers.get("Authorization") if not token: - return jsonify({'message': 'Missing token'}), 401 + return jsonify({"message": "Missing token"}), 401 # Get token from the header Bearer if token.startswith(FEDN_AUTH_SCHEME): - token = token.split(' ')[1] + token = token.split(" ")[1] else: - return jsonify({'message': - f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}' - }), 401 + return jsonify({"message": f"Invalid token scheme, expected {FEDN_AUTH_SCHEME}"}), 401 try: payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) if not check_role_claims(payload, role): - return jsonify({'message': 'Invalid token'}), 401 + return jsonify({"message": "Invalid token"}), 401 if not check_custom_claims(payload): - return jsonify({'message': 'Invalid token'}), 401 + return jsonify({"message": "Invalid token"}), 401 except jwt.ExpiredSignatureError: - return jsonify({'message': 'Token expired'}), 401 + return jsonify({"message": "Token expired"}), 401 except jwt.InvalidTokenError: - return jsonify({'message': 'Invalid token'}), 401 + return jsonify({"message": "Invalid token"}), 401 return func(*args, **kwargs) return decorated + return actual_decorator diff --git a/fedn/fedn/network/api/client.py b/fedn/network/api/client.py similarity index 65% rename from fedn/fedn/network/api/client.py rename to 
fedn/network/api/client.py index 4901cfda6..b2e748656 100644 --- a/fedn/fedn/network/api/client.py +++ b/fedn/network/api/client.py @@ -4,12 +4,12 @@ from fedn.common.telemetry import trace_all_methods -__all__ = ['APIClient'] +__all__ = ["APIClient"] @trace_all_methods class APIClient: - """ An API client for interacting with the statestore and controller. + """An API client for interacting with the statestore and controller. :param host: The host of the api server. :type host: str @@ -40,34 +40,34 @@ def __init__(self, host, port=None, secure=False, verify=False, token=None, auth def _get_url(self, endpoint): if self.secure: - protocol = 'https' + protocol = "https" else: - protocol = 'http' + protocol = "http" if self.port: - return f'{protocol}://{self.host}:{self.port}/{endpoint}' - return f'{protocol}://{self.host}/{endpoint}' + return f"{protocol}://{self.host}:{self.port}/{endpoint}" + return f"{protocol}://{self.host}/{endpoint}" def _get_url_api_v1(self, endpoint): - return self._get_url(f'api/v1/{endpoint}') + return self._get_url(f"api/v1/{endpoint}") # --- Clients --- # def get_client(self, id: str): - """ Get a client from the statestore. + """Get a client from the statestore. :param id: The client id to get. :type id: str :return: Client. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'clients/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"clients/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_clients(self, n_max: int = None): - """ Get clients from the statestore. + """Get clients from the statestore. :param n_max: The maximum number of clients to get (If none all will be fetched). :type n_max: int @@ -77,28 +77,28 @@ def get_clients(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('clients'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("clients/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_clients_count(self): - """ Get the number of clients in the statestore. + """Get the number of clients in the statestore. :return: The number of clients. :rtype: dict """ - response = requests.get(self._get_url_api_v1('clients/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("clients/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_client_config(self, checksum=True): - """ Get client config from controller. Optionally include the checksum. + """Get client config from controller. Optionally include the checksum. The config is used for clients to connect to the controller and ask for combiner assignment. :param checksum: Whether to include the checksum of the package. @@ -106,18 +106,16 @@ def get_client_config(self, checksum=True): :return: The client configuration. 
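Since this file defines the public Python API surface, a short usage sketch may help orient the reader; host, port, and token are hypothetical, and the calls shown are ones visible in this diff:

```python
from fedn.network.api.client import APIClient

# Hypothetical connection details for a FEDn REST API.
client = APIClient(host="api.example.com", port=8092, secure=True, verify=True, token="<api-token>")

config = client.get_client_config(checksum=True)  # legacy endpoint: /get_client_config
online = client.get_active_clients(n_max=10)      # v1 endpoint: /api/v1/clients/?status=online
print(config, online.get("result", []))
```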
:rtype: dict """ - _params = { - 'checksum': "true" if checksum else "false" - } + _params = {"checksum": "true" if checksum else "false"} - response = requests.get(self._get_url('get_client_config'), params=_params, verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("get_client_config"), params=_params, verify=self.verify, headers=self.headers) _json = response.json() return _json def get_active_clients(self, combiner_id: str = None, n_max: int = None): - """ Get active clients from the statestore. + """Get active clients from the statestore. :param combiner_id: The combiner id to get active clients for. :type combiner_id: str @@ -129,14 +127,14 @@ def get_active_clients(self, combiner_id: str = None, n_max: int = None): _params = {"status": "online"} if combiner_id: - _params['combiner'] = combiner_id + _params["combiner"] = combiner_id _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('clients/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("clients/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() @@ -145,21 +143,21 @@ def get_active_clients(self, combiner_id: str = None, n_max: int = None): # --- Combiners --- # def get_combiner(self, id: str): - """ Get a combiner from the statestore. + """Get a combiner from the statestore. :param id: The combiner id to get. :type id: str :return: Combiner. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'combiners/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"combiners/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_combiners(self, n_max: int = None): - """ Get combiners in the network. + """Get combiners in the network. :param n_max: The maximum number of combiners to get (If none all will be fetched). :type n_max: int @@ -169,21 +167,21 @@ def get_combiners(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('combiners/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("combiners/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_combiners_count(self): - """ Get the number of combiners in the statestore. + """Get the number of combiners in the statestore. :return: The number of combiners. :rtype: dict """ - response = requests.get(self._get_url_api_v1('combiners/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("combiners/count"), verify=self.verify, headers=self.headers) _json = response.json() @@ -192,12 +190,12 @@ def get_combiners_count(self): # --- Controllers --- # def get_controller_status(self): - """ Get the status of the controller. + """Get the status of the controller. :return: The status of the controller. :rtype: dict """ - response = requests.get(self._get_url('get_controller_status'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("get_controller_status"), verify=self.verify, headers=self.headers) _json = response.json() @@ -206,21 +204,21 @@ def get_controller_status(self): # --- Models --- # def get_model(self, id: str): - """ Get a model from the statestore. + """Get a model from the statestore. 
        :param id: The id (or model property) of the model to get.
        :type id: str
        :return: Model.
        :rtype: dict
        """
-        response = requests.get(self._get_url_api_v1(f'models/{id}'), verify=self.verify, headers=self.headers)
+        response = requests.get(self._get_url_api_v1(f"models/{id}"), verify=self.verify, headers=self.headers)
 
         _json = response.json()
 
         return _json
 
     def get_models(self, session_id: str = None, n_max: int = None):
-        """ Get models from the statestore.
+        """Get models from the statestore.
 
         :param session_id: The session id to get models for. (optional)
         :type session_id: str
@@ -232,41 +230,41 @@
         _params = {}
 
         if session_id:
-            _params['session_id'] = session_id
+            _params["session_id"] = session_id
 
         _headers = self.headers.copy()
 
         if n_max:
-            _headers['X-Limit'] = str(n_max)
+            _headers["X-Limit"] = str(n_max)
 
-        response = requests.get(self._get_url_api_v1('models/'), params=_params, verify=self.verify, headers=_headers)
+        response = requests.get(self._get_url_api_v1("models/"), params=_params, verify=self.verify, headers=_headers)
 
         _json = response.json()
 
         return _json
 
     def get_models_count(self):
-        """ Get the number of models in the statestore.
+        """Get the number of models in the statestore.
 
         :return: The number of models.
         :rtype: dict
         """
-        response = requests.get(self._get_url_api_v1('models/count'), verify=self.verify, headers=self.headers)
+        response = requests.get(self._get_url_api_v1("models/count"), verify=self.verify, headers=self.headers)
 
         _json = response.json()
 
         return _json
 
     def get_active_model(self):
-        """ Get the latest model from the statestore.
+        """Get the latest model from the statestore.
 
         :return: The latest model.
         :rtype: dict
         """
 
         _headers = self.headers.copy()
-        _headers['X-Limit'] = "1"
+        _headers["X-Limit"] = "1"
 
-        response = requests.get(self._get_url_api_v1('models/'), verify=self.verify, headers=_headers)
+        response = requests.get(self._get_url_api_v1("models/"), verify=self.verify, headers=_headers)
 
         _json = response.json()
 
         if "result" in _json and len(_json["result"]) > 0:
@@ -275,7 +273,7 @@
         return _json
 
     def get_model_trail(self, id: str = None, include_self: bool = True, reverse: bool = True, n_max: int = None):
-        """ Get the model trail.
+        """Get the model trail.
 
         :param id: The id (or model property) of the model to start the trail from. (optional)
         :type id: str
@@ -294,18 +292,18 @@
         _headers = self.headers.copy()
 
         _count: int = n_max if n_max else self.get_models_count()
-        _headers['X-Limit'] = str(_count)
-        _headers['X-Reverse'] = "true" if reverse else "false"
+        _headers["X-Limit"] = str(_count)
+        _headers["X-Reverse"] = "true" if reverse else "false"
 
         _include_self_str: str = "true" if include_self else "false"
 
-        response = requests.get(self._get_url_api_v1(f'models/{id}/ancestors?include_self={_include_self_str}'), verify=self.verify, headers=_headers)
+        response = requests.get(self._get_url_api_v1(f"models/{id}/ancestors?include_self={_include_self_str}"), verify=self.verify, headers=_headers)
 
         _json = response.json()
 
         return _json
 
     def download_model(self, id: str, path: str):
-        """ Download the model with id id.
+        """Download the model with the given id.
 
         :param id: The id (or model property) of the model to download.
         :type id: str
@@ -314,46 +312,45 @@ def download_model(self, id: str, path: str):
         :return: Message with success or failure.
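Continuing the sketch above: fetch the latest model and download it to disk. The field name on the response is an assumption, so it is accessed defensively:

```python
# Reuses the hypothetical `client` constructed earlier.
latest = client.get_active_model()
model_id = latest.get("model_id") or latest.get("model")  # response field name assumed
result = client.download_model(model_id, path="latest-model.npz")
print(result["message"])  # download_model returns {'success': ..., 'message': ...}
```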
:rtype: dict """ - response = requests.get(self._get_url_api_v1(f'models/{id}/download'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"models/{id}/download"), verify=self.verify, headers=self.headers) if response.status_code == 200: - - with open(path, 'wb') as file: + with open(path, "wb") as file: file.write(response.content) - return {'success': True, 'message': 'Model downloaded successfully.'} + return {"success": True, "message": "Model downloaded successfully."} else: - return {'success': False, 'message': 'Failed to download model.'} + return {"success": False, "message": "Failed to download model."} def set_active_model(self, path): - """ Set the initial model in the statestore and upload to model repository. + """Set the initial model in the statestore and upload to model repository. :param path: The file path of the initial model to set. :type path: str :return: A dict with success or failure message. :rtype: dict """ - with open(path, 'rb') as file: - response = requests.post(self._get_url('set_initial_model'), files={'file': file}, verify=self.verify, headers=self.headers) + with open(path, "rb") as file: + response = requests.post(self._get_url("set_initial_model"), files={"file": file}, verify=self.verify, headers=self.headers) return response.json() # --- Packages --- # def get_package(self, id: str): - """ Get a compute package from the statestore. + """Get a compute package from the statestore. :param id: The id of the compute package to get. :type id: str :return: Package. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'packages/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"packages/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_packages(self, n_max: int = None): - """ Get compute packages from the statestore. + """Get compute packages from the statestore. :param n_max: The maximum number of packages to get (If none all will be fetched). :type n_max: int @@ -363,68 +360,68 @@ def get_packages(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('packages/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("packages/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_packages_count(self): - """ Get the number of compute packages in the statestore. + """Get the number of compute packages in the statestore. :return: The number of packages. :rtype: dict """ - response = requests.get(self._get_url_api_v1('packages/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("packages/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_active_package(self): - """ Get the (active) compute package from the statestore. + """Get the (active) compute package from the statestore. :return: Package. :rtype: dict """ - response = requests.get(self._get_url_api_v1('packages/active'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("packages/active"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_package_checksum(self): - """ Get the checksum of the compute package. + """Get the checksum of the compute package. :return: The checksum. 
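The package endpoints pair naturally: upload a compute package, then let clients fetch and verify it. A sketch with hypothetical file names:

```python
# Hypothetical package workflow against the endpoints shown above.
client.set_active_package("package.tgz", helper="numpyhelper", name="mnist-example")
checksum = client.get_package_checksum()
client.download_package(path="package-copy.tgz")
print(checksum)
```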
:rtype: dict """ - response = requests.get(self._get_url('get_package_checksum'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("get_package_checksum"), verify=self.verify, headers=self.headers) _json = response.json() return _json def download_package(self, path: str): - """ Download the compute package. + """Download the compute package. :param path: The path to download the compute package to. :type path: str :return: Message with success or failure. :rtype: dict """ - response = requests.get(self._get_url('download_package'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url("download_package"), verify=self.verify, headers=self.headers) if response.status_code == 200: - with open(path, 'wb') as file: + with open(path, "wb") as file: file.write(response.content) - return {'success': True, 'message': 'Package downloaded successfully.'} + return {"success": True, "message": "Package downloaded successfully."} else: - return {'success': False, 'message': 'Failed to download package.'} + return {"success": False, "message": "Failed to download package."} def set_active_package(self, path: str, helper: str, name: str = None, description: str = None): - """ Set the compute package in the statestore. + """Set the compute package in the statestore. :param path: The file path of the compute package to set. :type path: str @@ -433,9 +430,14 @@ def set_active_package(self, path: str, helper: str, name: str = None, descripti :return: A dict with success or failure message. :rtype: dict """ - with open(path, 'rb') as file: - response = requests.post(self._get_url('set_package'), files={'file': file}, data={ - 'helper': helper, 'name': name, 'description': description}, verify=self.verify, headers=self.headers) + with open(path, "rb") as file: + response = requests.post( + self._get_url("set_package"), + files={"file": file}, + data={"helper": helper, "name": name, "description": description}, + verify=self.verify, + headers=self.headers, + ) _json = response.json() @@ -444,21 +446,21 @@ def set_active_package(self, path: str, helper: str, name: str = None, descripti # --- Rounds --- # def get_round(self, id: str): - """ Get a round from the statestore. + """Get a round from the statestore. :param round_id: The round id to get. :type round_id: str :return: Round (config and metrics). :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'rounds/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"rounds/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_rounds(self, n_max: int = None): - """ Get all rounds from the statestore. + """Get all rounds from the statestore. :param n_max: The maximum number of rounds to get (If none all will be fetched). :type n_max: int @@ -468,21 +470,21 @@ def get_rounds(self, n_max: int = None): _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('rounds/'), verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("rounds/"), verify=self.verify, headers=_headers) _json = response.json() return _json def get_rounds_count(self): - """ Get the number of rounds in the statestore. + """Get the number of rounds in the statestore. :return: The number of rounds. 
        :rtype: dict
        """
-        response = requests.get(self._get_url_api_v1('rounds/count'), verify=self.verify, headers=self.headers)
+        response = requests.get(self._get_url_api_v1("rounds/count"), verify=self.verify, headers=self.headers)
 
         _json = response.json()
 
@@ -491,7 +493,7 @@ def get_rounds_count(self):
     # --- Sessions --- #
 
     def get_session(self, id: str):
-        """ Get a session from the statestore.
+        """Get a session from the statestore.
 
         :param id: The session id to get.
         :type id: str
@@ -499,14 +501,14 @@ def get_session(self, id: str):
         :rtype: dict
         """
 
-        response = requests.get(self._get_url_api_v1(f'sessions/{id}'), self.verify, headers=self.headers)
+        response = requests.get(self._get_url_api_v1(f"sessions/{id}"), verify=self.verify, headers=self.headers)
 
         _json = response.json()
 
         return _json
 
     def get_sessions(self, n_max: int = None):
-        """ Get sessions from the statestore.
+        """Get sessions from the statestore.
 
         :param n_max: The maximum number of sessions to get (If none all will be fetched).
         :type n_max: int
@@ -516,28 +518,28 @@
         _headers = self.headers.copy()
 
         if n_max:
-            _headers['X-Limit'] = str(n_max)
+            _headers["X-Limit"] = str(n_max)
 
-        response = requests.get(self._get_url_api_v1('sessions/'), verify=self.verify, headers=_headers)
+        response = requests.get(self._get_url_api_v1("sessions/"), verify=self.verify, headers=_headers)
 
         _json = response.json()
 
         return _json
 
     def get_sessions_count(self):
-        """ Get the number of sessions in the statestore.
+        """Get the number of sessions in the statestore.
 
         :return: The number of sessions.
         :rtype: dict
         """
-        response = requests.get(self._get_url_api_v1('sessions/count'), verify=self.verify, headers=self.headers)
+        response = requests.get(self._get_url_api_v1("sessions/count"), verify=self.verify, headers=self.headers)
 
         _json = response.json()
 
         return _json
 
     def get_session_status(self, id: str):
-        """ Get the status of a session.
+        """Get the status of a session.
 
         :param id: The id of the session to get.
         :type id: str
@@ -552,7 +554,7 @@ def get_session_status(self, id: str):
         return "Could not retrieve session status."
 
     def session_is_finished(self, id: str):
-        """ Check if a session with id has finished.
+        """Check if a session with id has finished.
 
         :param id: The id of the session to get.
         :type id: str
@@ -563,17 +565,21 @@ def session_is_finished(self, id: str):
         return status and status.lower() == "finished"
 
     def start_session(
-        self,
-        id: str = None,
-        aggregator: str = 'fedavg',
-        model_id: str = None,
-        round_timeout: int = 180,
-        rounds: int = 5,
-        round_buffer_size: int = -1,
-        delete_models: bool = True,
-        validate: bool = True, helper: str = 'numpyhelper', min_clients: int = 1, requested_clients: int = 8
+        self,
+        id: str = None,
+        aggregator: str = "fedavg",
+        aggregator_kwargs: dict = None,
+        model_id: str = None,
+        round_timeout: int = 180,
+        rounds: int = 5,
+        round_buffer_size: int = -1,
+        delete_models: bool = True,
+        validate: bool = True,
+        helper: str = "numpyhelper",
+        min_clients: int = 1,
+        requested_clients: int = 8,
     ):
-        """ Start a new session.
+        """Start a new session.
 
         :param id: The session id to start.
         :type id: str
@@ -600,19 +606,25 @@ def start_session(
         :return: A dict with success or failure message and session config.
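The new aggregator_kwargs parameter lets callers pass server-side hyperparameters through to the aggregator. A hedged example; the key names are illustrative, not a documented schema:

```python
session = client.start_session(
    aggregator="fedopt",
    # Illustrative keys; consult the target aggregator for the schema it accepts.
    aggregator_kwargs={"serveropt": "adam", "learning_rate": 1e-2},
    rounds=10,
    helper="numpyhelper",
)
print(session)
```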
:rtype: dict """ - response = requests.post(self._get_url('start_session'), json={ - 'session_id': id, - 'aggregator': aggregator, - 'model_id': model_id, - 'round_timeout': round_timeout, - 'rounds': rounds, - 'round_buffer_size': round_buffer_size, - 'delete_models': delete_models, - 'validate': validate, - 'helper': helper, - 'min_clients': min_clients, - 'requested_clients': requested_clients - }, verify=self.verify, headers=self.headers) + response = requests.post( + self._get_url("start_session"), + json={ + "session_id": id, + "aggregator": aggregator, + "aggregator_kwargs": aggregator_kwargs, + "model_id": model_id, + "round_timeout": round_timeout, + "rounds": rounds, + "round_buffer_size": round_buffer_size, + "delete_models": delete_models, + "validate": validate, + "helper": helper, + "min_clients": min_clients, + "requested_clients": requested_clients, + }, + verify=self.verify, + headers=self.headers, + ) _json = response.json() @@ -621,21 +633,21 @@ def start_session( # --- Statuses --- # def get_status(self, id: str): - """ Get a status object (event) from the statestore. + """Get a status object (event) from the statestore. :param id: The id of the status to get. :type id: str :return: Status. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'statuses/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"statuses/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json def get_statuses(self, session_id: str = None, event_type: str = None, sender_name: str = None, sender_role: str = None, n_max: int = None): - """ Get statuses from the statestore. Filter by input parameters + """Get statuses from the statestore. Filter by input parameters :param session_id: The session id to get statuses for. :type session_id: str @@ -666,21 +678,21 @@ def get_statuses(self, session_id: str = None, event_type: str = None, sender_na _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('statuses/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("statuses/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json def get_statuses_count(self): - """ Get the number of statuses in the statestore. + """Get the number of statuses in the statestore. :return: The number of statuses. :rtype: dict """ - response = requests.get(self._get_url_api_v1('statuses/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("statuses/count"), verify=self.verify, headers=self.headers) _json = response.json() @@ -689,14 +701,14 @@ def get_statuses_count(self): # --- Validations --- # def get_validation(self, id: str): - """ Get a validation from the statestore. + """Get a validation from the statestore. :param id: The id of the validation to get. :type id: str :return: Validation. :rtype: dict """ - response = requests.get(self._get_url_api_v1(f'validations/{id}'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1(f"validations/{id}"), verify=self.verify, headers=self.headers) _json = response.json() @@ -711,9 +723,9 @@ def get_validations( sender_role: str = None, receiver_name: str = None, receiver_role: str = None, - n_max: int = None + n_max: int = None, ): - """ Get validations from the statestore. Filter by input parameters. 
+ """Get validations from the statestore. Filter by input parameters. :param session_id: The session id to get validations for. :type session_id: str @@ -760,21 +772,21 @@ def get_validations( _headers = self.headers.copy() if n_max: - _headers['X-Limit'] = str(n_max) + _headers["X-Limit"] = str(n_max) - response = requests.get(self._get_url_api_v1('validations/'), params=_params, verify=self.verify, headers=_headers) + response = requests.get(self._get_url_api_v1("validations/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json def get_validations_count(self): - """ Get the number of validations in the statestore. + """Get the number of validations in the statestore. :return: The number of validations. :rtype: dict """ - response = requests.get(self._get_url_api_v1('validations/count'), verify=self.verify, headers=self.headers) + response = requests.get(self._get_url_api_v1("validations/count"), verify=self.verify, headers=self.headers) _json = response.json() diff --git a/fedn/fedn/network/api/interface.py b/fedn/network/api/interface.py similarity index 92% rename from fedn/fedn/network/api/interface.py rename to fedn/network/api/interface.py index f65cb7ef2..cdf29752c 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/network/api/interface.py @@ -11,8 +11,7 @@ from fedn.common.config import get_controller_config, get_network_config from fedn.common.log_config import logger from fedn.common.telemetry import trace_all_methods -from fedn.network.combiner.interfaces import (CombinerInterface, - CombinerUnavailableError) +from fedn.network.combiner.interfaces import CombinerInterface, CombinerUnavailableError from fedn.network.state import ReducerState, ReducerStateToString from fedn.utils.checksum import sha from fedn.utils.plots import Plot @@ -38,9 +37,7 @@ def _to_dict(self): data = {"name": self.name} return data - def _allowed_file_extension( - self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"} - ): + def _allowed_file_extension(self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"}): """Check if file extension is allowed. :param filename: The filename to check. @@ -172,11 +169,10 @@ def get_session(self, session_id): info = session_object["session_config"][0] status = session_object["status"] payload[id] = info - payload['status'] = status + payload["status"] = status return jsonify(payload) def set_active_compute_package(self, id: str): - success = self.statestore.set_active_compute_package(id) if not success: @@ -201,16 +197,12 @@ def set_compute_package(self, file, helper_type: str, name: str = None, descript :rtype: :class:`flask.Response` """ - if ( - self.control.state() == ReducerState.instructing - or self.control.state() == ReducerState.monitoring - ): + if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: return ( jsonify( { "success": False, - "message": "Reducer is in instructing or monitoring state." - "Cannot set compute package.", + "message": "Reducer is in instructing or monitoring state." 
"Cannot set compute package.", } ), 400, @@ -290,9 +282,7 @@ def get_compute_package(self): result = self.statestore.get_compute_package() if result is None: return ( - jsonify( - {"success": False, "message": "No compute package found."} - ), + jsonify({"success": False, "message": "No compute package found."}), 404, ) @@ -329,9 +319,7 @@ def list_compute_packages(self, limit: str = None, skip: str = None, include_act result = self.statestore.list_compute_packages(limit, skip) if result is None: return ( - jsonify( - {"success": False, "message": "No compute packages found."} - ), + jsonify({"success": False, "message": "No compute packages found."}), 404, ) @@ -388,9 +376,7 @@ def download_compute_package(self, name): mutex = threading.Lock() mutex.acquire() # TODO: make configurable, perhaps in config.py or package.py - return send_from_directory( - "/app/client/package/", name, as_attachment=True - ) + return send_from_directory("/app/client/package/", name, as_attachment=True) except Exception: try: data = self.control.get_compute_package(name) @@ -399,9 +385,7 @@ def download_compute_package(self, name): with open(file_path, "wb") as fh: fh.write(data) # TODO: make configurable, perhaps in config.py or package.py - return send_from_directory( - "/app/client/package/", name, as_attachment=True - ) + return send_from_directory("/app/client/package/", name, as_attachment=True) except Exception: raise finally: @@ -420,9 +404,7 @@ def _create_checksum(self, name=None): name, message = self._get_compute_package_name() if name is None: return False, message, "" - file_path = os.path.join( - "/app/client/package/", name - ) # TODO: make configurable, perhaps in config.py or package.py + file_path = os.path.join("/app/client/package/", name) # TODO: make configurable, perhaps in config.py or package.py try: sum = str(sha(file_path)) except FileNotFoundError: @@ -507,9 +489,7 @@ def get_all_validations(self, **kwargs): payload[id] = info return jsonify(payload) - def add_combiner( - self, combiner_id, secure_grpc, address, remote_addr, fqdn, port - ): + def add_combiner(self, combiner_id, secure_grpc, address, remote_addr, fqdn, port): """Add a combiner to the network. :param combiner_id: The combiner id to add. 
@@ -542,9 +522,7 @@ def add_combiner( combiner = self.control.network.get_combiner(combiner_id) if not combiner: if secure_grpc == "True": - certificate, key = self.certificate_manager.get_or_create( - address - ).get_keypair_raw() + certificate, key = self.certificate_manager.get_or_create(address).get_keypair_raw() _ = base64.b64encode(certificate) _ = base64.b64encode(key) @@ -568,9 +546,7 @@ def add_combiner( # Check combiner now exists combiner = self.control.network.get_combiner(combiner_id) if not combiner: - return jsonify( - {"success": False, "message": "Combiner not added."} - ) + return jsonify({"success": False, "message": "Combiner not added."}) payload = { "success": True, @@ -625,9 +601,7 @@ def add_client(self, client_id, preferred_combiner, remote_addr): combiner = self.control.network.find_available_combiner() if combiner is None: return ( - jsonify( - {"success": False, "message": "No combiner available."} - ), + jsonify({"success": False, "message": "No combiner available."}), 400, ) @@ -693,9 +667,7 @@ def set_initial_model(self, file): logger.debug(e) return jsonify({"success": False, "message": e}) - return jsonify( - {"success": True, "message": "Initial model added successfully."} - ) + return jsonify({"success": True, "message": "Initial model added successfully."}) def get_latest_model(self): """Get the latest model from the statestore. @@ -708,9 +680,7 @@ def get_latest_model(self): payload = {"model_id": model_id} return jsonify(payload) else: - return jsonify( - {"success": False, "message": "No initial model set."} - ) + return jsonify({"success": False, "message": "No initial model set."}) def set_current_model(self, model_id: str): """Set the active model in the statestore. @@ -747,7 +717,6 @@ def get_models(self, session_id: str = None, limit: str = None, skip: str = None include_active: bool = include_active == "true" if include_active: - latest_model = self.statestore.get_latest_model() arr = [ @@ -803,9 +772,7 @@ def get_model_trail(self): if model_info: return jsonify(model_info) else: - return jsonify( - {"success": False, "message": "No model trail available."} - ) + return jsonify({"success": False, "message": "No model trail available."}) def get_model_ancestors(self, model_id: str, limit: str = None): """Get the model ancestors for a given model. @@ -818,15 +785,12 @@ def get_model_ancestors(self, model_id: str, limit: str = None): :rtype: :class:`flask.Response` """ if model_id is None: - return jsonify( - {"success": False, "message": "No model id provided."} - ) + return jsonify({"success": False, "message": "No model id provided."}) limit: int = int(limit) if limit is not None else 10 # if limit is None, default to 10 response = self.statestore.get_model_ancestors(model_id, limit) if response: - arr: list = [] for element in response: @@ -842,9 +806,7 @@ def get_model_ancestors(self, model_id: str, limit: str = None): return jsonify(result) else: - return jsonify( - {"success": False, "message": "No model ancestors available."} - ) + return jsonify({"success": False, "message": "No model ancestors available."}) def get_model_descendants(self, model_id: str, limit: str = None): """Get the model descendants for a given model. 
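These handlers back the model-lineage endpoints that APIClient.get_model_trail calls. A traversal sketch with a hypothetical model id and a hedged response shape:

```python
# Walk a model's lineage through the v1 API; the "result" key is assumed.
trail = client.get_model_trail(id="<model-id>", include_self=True, reverse=True)
for entry in trail.get("result", []):
    print(entry)
```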
@@ -858,16 +820,13 @@ def get_model_descendants(self, model_id: str, limit: str = None): """ if model_id is None: - return jsonify( - {"success": False, "message": "No model id provided."} - ) + return jsonify({"success": False, "message": "No model id provided."}) limit: int = int(limit) if limit is not None else 10 response: list = self.statestore.get_model_descendants(model_id, limit) if response: - arr: list = [] for element in response: @@ -883,9 +842,7 @@ def get_model_descendants(self, model_id: str, limit: str = None): return jsonify(result) else: - return jsonify( - {"success": False, "message": "No model descendants available."} - ) + return jsonify({"success": False, "message": "No model descendants available."}) def get_all_rounds(self): """Get all rounds. @@ -928,8 +885,8 @@ def get_round(self, round_id): if round_object is None: return jsonify({"success": False, "message": "Round not found."}) payload = { - 'round_id': round_object['round_id'], - 'combiners': round_object['combiners'], + "round_id": round_object["round_id"], + "combiners": round_object["combiners"], } return jsonify(payload) @@ -994,7 +951,6 @@ def list_combiners_data(self, combiners): # order list by combiner name for element in response: - obj = { "combiner": element["_id"], "count": element["count"], @@ -1009,7 +965,8 @@ def list_combiners_data(self, combiners): def start_session( self, session_id, - aggregator='fedavg', + aggregator="fedavg", + aggregator_kwargs=None, model_id=None, rounds=5, round_timeout=180, @@ -1048,15 +1005,11 @@ def start_session( # Check if session already exists session = self.statestore.get_session(session_id) if session: - return jsonify( - {"success": False, "message": "Session already exists."} - ) + return jsonify({"success": False, "message": "Session already exists."}) # Check if session is running if self.control.state() == ReducerState.monitoring: - return jsonify( - {"success": False, "message": "A session is already running."} - ) + return jsonify({"success": False, "message": "A session is already running."}) # Check if compute package is set if not self.statestore.get_compute_package(): @@ -1110,6 +1063,7 @@ def start_session( session_config = { "session_id": session_id if session_id else str(uuid.uuid4()), "aggregator": aggregator, + "aggregator_kwargs": aggregator_kwargs, "round_timeout": round_timeout, "buffer_size": round_buffer_size, "model_id": model_id, @@ -1123,9 +1077,7 @@ def start_session( } # Start session - threading.Thread( - target=self.control.session, args=(session_config,) - ).start() + threading.Thread(target=self.control.session, args=(session_config,)).start() # Return success response return jsonify( diff --git a/fedn/fedn/network/api/network.py b/fedn/network/api/network.py similarity index 75% rename from fedn/fedn/network/api/network.py rename to fedn/network/api/network.py index 88c7e4dc9..4cf8a446d 100644 --- a/fedn/fedn/network/api/network.py +++ b/fedn/network/api/network.py @@ -5,15 +5,15 @@ from fedn.network.combiner.interfaces import CombinerInterface from fedn.network.loadbalancer.leastpacked import LeastPacked -__all__ = 'Network', +__all__ = ("Network",) @trace_all_methods class Network: - """ FEDn network interface. This class is used to interact with the network. - Note: This class contain redundant code, which is not used in the current version of FEDn. - Some methods has been moved to :class:`fedn.network.api.interface.API`. - """ + """FEDn network interface. This class is used to interact with the network. 
+ Note: This class contains redundant code, which is not used in the current version of FEDn. + Some methods have been moved to :class:`fedn.network.api.interface.API`. + """ def __init__(self, control, statestore, load_balancer=None): """ """ @@ -27,7 +27,7 @@ def __init__(self, control, statestore, load_balancer=None): self.load_balancer = load_balancer def get_combiner(self, name): - """ Get combiner by name. + """Get combiner by name. :param name: name of combiner :type name: str @@ -41,7 +41,7 @@ def get_combiner(self, name): return None def get_combiners(self): - """ Get all combiners in the network. + """Get all combiners in the network. :return: list of combiners objects :rtype: list(:class:`fedn.network.combiner.interfaces.CombinerInterface`) @@ -49,21 +49,19 @@ def get_combiners(self): data = self.statestore.get_combiners() combiners = [] for c in data["result"]: - if c['certificate']: - cert = base64.b64decode(c['certificate']) - key = base64.b64decode(c['key']) + if c["certificate"]: + cert = base64.b64decode(c["certificate"]) + key = base64.b64decode(c["key"]) else: cert = None key = None - combiners.append( - CombinerInterface(c['parent'], c['name'], c['address'], c['fqdn'], c['port'], - certificate=cert, key=key, ip=c['ip'])) + combiners.append(CombinerInterface(c["parent"], c["name"], c["address"], c["fqdn"], c["port"], certificate=cert, key=key, ip=c["ip"])) return combiners def add_combiner(self, combiner): - """ Add a new combiner to the network. + """Add a new combiner to the network. :param combiner: The combiner instance object :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` @@ -80,7 +78,7 @@ def add_combiner(self, combiner): self.statestore.set_combiner(combiner.to_dict()) def remove_combiner(self, combiner): - """ Remove a combiner from the network. + """Remove a combiner from the network. :param combiner: The combiner instance object :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` @@ -92,7 +90,7 @@ def remove_combiner(self, combiner): self.statestore.delete_combiner(combiner.name) def find_available_combiner(self): - """ Find an available combiner in the network. + """Find an available combiner in the network. :return: The combiner instance object :rtype: :class:`fedn.network.combiner.interfaces.CombinerInterface` @@ -101,32 +99,31 @@ def find_available_combiner(self): return combiner def handle_unavailable_combiner(self, combiner): - """ This callback is triggered if a combiner is found to be unresponsive. + """This callback is triggered if a combiner is found to be unresponsive. :param combiner: The combiner instance object :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` :return: None """ # TODO: Implement strategy to handle an unavailable combiner. - logger.warning("REDUCER CONTROL: Combiner {} unavailable.".format( - combiner.name)) + logger.warning("REDUCER CONTROL: Combiner {} unavailable.".format(combiner.name)) def add_client(self, client): - """ Add a new client to the network. + """Add a new client to the network. :param client: The client instance object :type client: dict :return: None """ - if self.get_client(client['name']): + if self.get_client(client["name"]): return - logger.info("adding client {}".format(client['name'])) + logger.info("adding client {}".format(client["name"])) self.statestore.set_client(client) def get_client(self, name): - """ Get client by name. + """Get client by name. 
:param name: name of client :type name: str @@ -137,7 +134,7 @@ def update_client_data(self, client_data, status, role): - """ Update client status in statestore. + """Update client status in statestore. :param client_data: The client instance object :type client_data: dict @@ -150,7 +147,7 @@ def update_client_data(self, client_data, status, role): self.statestore.update_client_status(client_data, status, role) def get_client_info(self): - """ list available client in statestore. + """List available clients in the statestore. :return: list of client objects :rtype: list(ObjectId) diff --git a/fedn/fedn/network/api/server.py b/fedn/network/api/server.py similarity index 96% rename from fedn/fedn/network/api/server.py rename to fedn/network/api/server.py index 45cf410ce..5f645e4e2 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/network/api/server.py @@ -2,8 +2,7 @@ from flask import Flask, jsonify, request -from fedn.common.config import (get_controller_config, get_modelstorage_config, - get_network_config, get_statestore_config) +from fedn.common.config import get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config from fedn.network.api.auth import jwt_auth_required from fedn.network.api.interface import API from fedn.network.api.v1 import _routes @@ -23,14 +22,12 @@ for bp in _routes: app.register_blueprint(bp) if custom_url_prefix: - app.register_blueprint(bp, - name=f"{bp.name}_custom", - url_prefix=f"{custom_url_prefix}{bp.url_prefix}") + app.register_blueprint(bp, name=f"{bp.name}_custom", url_prefix=f"{custom_url_prefix}{bp.url_prefix}") -@app.route('/health', methods=["GET"]) +@app.route("/health", methods=["GET"]) def health_check(): - return 'OK', 200 + return "OK", 200 if custom_url_prefix: @@ -364,9 +361,7 @@ def set_package(): file = request.files["file"] except KeyError: return jsonify({"success": False, "message": "Missing file."}), 400 - return api.set_compute_package( - file=file, helper_type=helper_type, name=name, description=description - ) + return api.set_compute_package(file=file, helper_type=helper_type, name=name, description=description) if custom_url_prefix: @@ -399,9 +394,7 @@ def list_compute_packages(): skip = request.args.get("skip", None) include_active = request.args.get("include_active", None) - return api.list_compute_packages( - limit=limit, skip=skip, include_active=include_active - ) + return api.list_compute_packages(limit=limit, skip=skip, include_active=include_active) if custom_url_prefix: diff --git a/fedn/fedn/network/api/tests.py b/fedn/network/api/tests.py similarity index 100% rename from fedn/fedn/network/api/tests.py rename to fedn/network/api/tests.py diff --git a/fedn/fedn/network/api/v1/__init__.py b/fedn/network/api/v1/__init__.py similarity index 100% rename from fedn/fedn/network/api/v1/__init__.py rename to fedn/network/api/v1/__init__.py diff --git a/fedn/fedn/network/api/v1/client_routes.py b/fedn/network/api/v1/client_routes.py similarity index 96% rename from fedn/fedn/network/api/v1/client_routes.py rename to fedn/network/api/v1/client_routes.py index 30322a9b7..d5ccc58ee 100644 --- a/fedn/fedn/network/api/v1/client_routes.py +++ b/fedn/network/api/v1/client_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, 
get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.client_store import ClientStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -115,9 +114,7 @@ def get_clients(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - clients = client_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + clients = client_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = clients["result"] @@ -202,9 +199,7 @@ def list_clients(): try: limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - clients = client_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + clients = client_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = clients["result"] diff --git a/fedn/fedn/network/api/v1/combiner_routes.py b/fedn/network/api/v1/combiner_routes.py similarity index 96% rename from fedn/fedn/network/api/v1/combiner_routes.py rename to fedn/network/api/v1/combiner_routes.py index 7d1761bee..1f9360461 100644 --- a/fedn/fedn/network/api/v1/combiner_routes.py +++ b/fedn/network/api/v1/combiner_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.combiner_store import CombinerStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -107,9 +106,7 @@ def get_combiners(): kwargs = request.args.to_dict() - combiners = combiner_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + combiners = combiner_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = combiners["result"] @@ -192,9 +189,7 @@ def list_combiners(): kwargs = get_post_data_to_kwargs(request) - combiners = combiner_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + combiners = combiner_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = combiners["result"] diff --git a/fedn/fedn/network/api/v1/model_routes.py b/fedn/network/api/v1/model_routes.py similarity index 96% rename from fedn/fedn/network/api/v1/model_routes.py rename to fedn/network/api/v1/model_routes.py index 8e9308408..f9708a149 100644 --- a/fedn/fedn/network/api/v1/model_routes.py +++ b/fedn/network/api/v1/model_routes.py @@ -4,10 +4,7 @@ from flask import Blueprint, jsonify, request, send_file from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_limit, - get_post_data_to_kwargs, get_reverse, - get_typed_list_headers, mdb, - modelstorage_config) +from fedn.network.api.v1.shared import api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers, mdb, modelstorage_config from fedn.network.storage.s3.base import RepositoryBase from fedn.network.storage.s3.miniorepository import MINIORepository from fedn.network.storage.statestore.stores.model_store import ModelStore @@ -112,9 +109,7 @@ def get_models(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - models = model_store.list( 
- limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + models = model_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = models["result"] @@ -199,9 +194,7 @@ def list_models(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - models = model_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + models = model_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = models["result"] @@ -466,7 +459,7 @@ def get_ancestors(id: str): try: limit = get_limit(request.headers) reverse = get_reverse(request.headers) - include_self_param: str = request.args.get('include_self') + include_self_param: str = request.args.get("include_self") include_self: bool = include_self_param and include_self_param.lower() == "true" diff --git a/fedn/fedn/network/api/v1/package_routes.py b/fedn/network/api/v1/package_routes.py similarity index 96% rename from fedn/fedn/network/api/v1/package_routes.py rename to fedn/network/api/v1/package_routes.py index 30ac4d51e..65783f54b 100644 --- a/fedn/fedn/network/api/v1/package_routes.py +++ b/fedn/network/api/v1/package_routes.py @@ -1,9 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, get_use_typing, - mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb from fedn.network.storage.statestore.stores.package_store import PackageStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -120,9 +118,7 @@ def get_packages(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - packages = package_store.list( - limit, skip, sort_key, sort_order, use_typing=True, **kwargs - ) + packages = package_store.list(limit, skip, sort_key, sort_order, use_typing=True, **kwargs) result = [package.__dict__ for package in packages["result"]] @@ -210,9 +206,7 @@ def list_packages(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - packages = package_store.list( - limit, skip, sort_key, sort_order, use_typing=True, **kwargs - ) + packages = package_store.list(limit, skip, sort_key, sort_order, use_typing=True, **kwargs) result = [package.__dict__ for package in packages["result"]] diff --git a/fedn/fedn/network/api/v1/round_routes.py b/fedn/network/api/v1/round_routes.py similarity index 95% rename from fedn/fedn/network/api/v1/round_routes.py rename to fedn/network/api/v1/round_routes.py index 8890c510a..4c2eb0c44 100644 --- a/fedn/fedn/network/api/v1/round_routes.py +++ b/fedn/network/api/v1/round_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.round_store import RoundStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -95,9 +94,7 @@ def get_rounds(): kwargs = request.args.to_dict() - rounds = round_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + 
rounds = round_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = rounds["result"] @@ -176,9 +173,7 @@ def list_rounds(): kwargs = get_post_data_to_kwargs(request) - rounds = round_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + rounds = round_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = rounds["result"] diff --git a/fedn/fedn/network/api/v1/session_routes.py b/fedn/network/api/v1/session_routes.py similarity index 95% rename from fedn/fedn/network/api/v1/session_routes.py rename to fedn/network/api/v1/session_routes.py index 99c52d8db..ccfde590a 100644 --- a/fedn/fedn/network/api/v1/session_routes.py +++ b/fedn/network/api/v1/session_routes.py @@ -1,8 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.session_store import SessionStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -87,9 +86,7 @@ def get_sessions(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - sessions = session_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + sessions = session_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = sessions["result"] @@ -167,9 +164,7 @@ def list_sessions(): limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - sessions = session_store.list( - limit, skip, sort_key, sort_order, use_typing=False, **kwargs - ) + sessions = session_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) result = sessions["result"] diff --git a/fedn/fedn/network/api/v1/shared.py b/fedn/network/api/v1/shared.py similarity index 88% rename from fedn/fedn/network/api/v1/shared.py rename to fedn/network/api/v1/shared.py index 753414324..2fb6063c0 100644 --- a/fedn/fedn/network/api/v1/shared.py +++ b/fedn/network/api/v1/shared.py @@ -3,8 +3,7 @@ import pymongo from pymongo.database import Database -from fedn.common.config import (get_modelstorage_config, get_network_config, - get_statestore_config) +from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config api_version = "v1" @@ -58,9 +57,7 @@ def get_typed_list_headers( use_typing: bool = get_use_typing(headers) if sort_order is not None: - sort_order = ( - pymongo.ASCENDING if sort_order.lower() == "asc" else pymongo.DESCENDING - ) + sort_order = pymongo.ASCENDING if sort_order.lower() == "asc" else pymongo.DESCENDING else: sort_order = pymongo.DESCENDING diff --git a/fedn/fedn/network/api/v1/status_routes.py b/fedn/network/api/v1/status_routes.py similarity index 93% rename from fedn/fedn/network/api/v1/status_routes.py rename to fedn/network/api/v1/status_routes.py index e78c18533..b88772b01 100644 --- a/fedn/fedn/network/api/v1/status_routes.py +++ b/fedn/network/api/v1/status_routes.py @@ -1,9 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, get_use_typing, - mdb) +from fedn.network.api.v1.shared import 
api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb from fedn.network.storage.statestore.stores.shared import EntityNotFound from fedn.network.storage.statestore.stores.status_store import StatusStore @@ -123,20 +121,12 @@ def get_statuses(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - statuses = status_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [status.__dict__ for status in statuses["result"]] - if use_typing - else statuses["result"] - ) + result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"] response = {"count": statuses["count"], "result": result} @@ -226,20 +216,12 @@ def list_statuses(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - statuses = status_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [status.__dict__ for status in statuses["result"]] - if use_typing - else statuses["result"] - ) + result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"] response = {"count": statuses["count"], "result": result} diff --git a/fedn/fedn/network/api/v1/validation_routes.py b/fedn/network/api/v1/validation_routes.py similarity index 92% rename from fedn/fedn/network/api/v1/validation_routes.py rename to fedn/network/api/v1/validation_routes.py index 96fbac55c..59767e3e8 100644 --- a/fedn/fedn/network/api/v1/validation_routes.py +++ b/fedn/network/api/v1/validation_routes.py @@ -1,12 +1,9 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, get_use_typing, - mdb) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb from fedn.network.storage.statestore.stores.shared import EntityNotFound -from fedn.network.storage.statestore.stores.validation_store import \ - ValidationStore +from fedn.network.storage.statestore.stores.validation_store import ValidationStore bp = Blueprint("validation", __name__, url_prefix=f"/api/{api_version}/validations") @@ -131,20 +128,12 @@ def get_validations(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - validations = validation_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + validations = validation_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [validation.__dict__ for validation in validations["result"]] - if use_typing - else validations["result"] - ) + result = [validation.__dict__ for validation in validations["result"]] if use_typing else 
validations["result"] response = {"count": validations["count"], "result": result} @@ -237,20 +226,12 @@ def list_validations(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers( - request.headers - ) + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - validations = validation_store.list( - limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs - ) + validations = validation_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - result = ( - [validation.__dict__ for validation in validations["result"]] - if use_typing - else validations["result"] - ) + result = [validation.__dict__ for validation in validations["result"]] if use_typing else validations["result"] response = {"count": validations["count"], "result": result} diff --git a/fedn/fedn/network/clients/__init__.py b/fedn/network/clients/__init__.py similarity index 100% rename from fedn/fedn/network/clients/__init__.py rename to fedn/network/clients/__init__.py diff --git a/fedn/fedn/network/clients/client.py b/fedn/network/clients/client.py similarity index 77% rename from fedn/fedn/network/clients/client.py rename to fedn/network/clients/client.py index bdfb9fa4b..085f02b51 100644 --- a/fedn/fedn/network/clients/client.py +++ b/fedn/network/clients/client.py @@ -22,19 +22,17 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_PACKAGE_EXTRACT_DIR -from fedn.common.log_config import (logger, set_log_level_from_string, - set_log_stream) from fedn.common.telemetry import trace_all_methods, tracer +from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream from fedn.network.clients.connect import ConnectorClient, Status from fedn.network.clients.package import PackageRuntime from fedn.network.clients.state import ClientState, ClientStateToString -from fedn.network.combiner.modelservice import (get_tmp_path, - upload_request_generator) +from fedn.network.combiner.modelservice import get_tmp_path, upload_request_generator from fedn.utils.dispatcher import Dispatcher from fedn.utils.helpers.helpers import get_helper CHUNK_SIZE = 1024 * 1024 -VALID_NAME_REGEX = '^[a-zA-Z0-9_-]*$' +VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$" @trace_all_methods @@ -43,7 +41,7 @@ def __init__(self, key): self._key = key def __call__(self, context, callback): - callback((('authorization', f'{FEDN_AUTH_SCHEME} {self._key}'),), None) + callback((("authorization", f"{FEDN_AUTH_SCHEME} {self._key}"),), None) @trace_all_methods @@ -63,30 +61,32 @@ def __init__(self, config): self._missed_heartbeat = 0 self.config = config self.trace_attribs = False - set_log_level_from_string(config.get('verbosity', "INFO")) - set_log_stream(config.get('logfile', None)) - - self.connector = ConnectorClient(host=config['discover_host'], - port=config['discover_port'], - token=config['token'], - name=config['name'], - remote_package=config['remote_compute_context'], - force_ssl=config['force_ssl'], - verify=config['verify'], - combiner=config['preferred_combiner'], - id=config['client_id']) + set_log_level_from_string(config.get("verbosity", "INFO")) + set_log_stream(config.get("logfile", None)) + + self.connector = ConnectorClient( + host=config["discover_host"], + port=config["discover_port"], + token=config["token"], + name=config["name"], + remote_package=config["remote_compute_context"], + 
force_ssl=config["force_ssl"], + verify=config["verify"], + combiner=config["preferred_combiner"], + id=config["client_id"], + ) # Validate client name - match = re.search(VALID_NAME_REGEX, config['name']) + match = re.search(VALID_NAME_REGEX, config["name"]) if not match: - raise ValueError('Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.') + raise ValueError("Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.") # Folder where the client will store downloaded compute package and logs - self.name = config['name'] + self.name = config["name"] if FEDN_PACKAGE_EXTRACT_DIR: self.run_path = os.path.join(os.getcwd(), FEDN_PACKAGE_EXTRACT_DIR) else: - dirname = self.name+"-"+time.strftime("%Y%m%d-%H%M%S") + dirname = self.name + "-" + time.strftime("%Y%m%d-%H%M%S") self.run_path = os.path.join(os.getcwd(), dirname) if not os.path.exists(self.run_path): os.mkdir(self.run_path) @@ -104,8 +104,7 @@ def __init__(self, config): self._initialize_helper(combiner_config) if not self.helper: - logger.warning("Failed to retrieve helper class settings: {}".format( - combiner_config)) + logger.warning("Failed to retrieve helper class settings: {}".format(combiner_config)) self._subscribe_to_combiner(self.config) @@ -149,14 +148,14 @@ def _add_grpc_metadata(self, key, value): :type value: str """ # Check if metadata exists and add if not - if not hasattr(self, 'metadata'): + if not hasattr(self, "metadata"): self.metadata = () # Check if metadata key already exists and replace value if so for i, (k, v) in enumerate(self.metadata): if k == key: # Replace value - self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1:] + self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1 :] return # Set metadata using tuple concatenation @@ -188,39 +187,38 @@ def connect(self, combiner_config): return None # TODO use the combiner_config['certificate'] for setting up secure comms' - host = combiner_config['host'] + host = combiner_config["host"] # Add host to gRPC metadata - self._add_grpc_metadata('grpc-server', host) + self._add_grpc_metadata("grpc-server", host) logger.debug("Client using metadata: {}.".format(self.metadata)) - port = combiner_config['port'] + port = combiner_config["port"] secure = False - if combiner_config['fqdn'] is not None: - host = combiner_config['fqdn'] + if combiner_config["fqdn"] is not None: + host = combiner_config["fqdn"] # assuming https if fqdn is used port = 443 logger.info(f"Initiating connection to combiner host at: {host}:{port}") - if combiner_config['certificate']: + if combiner_config["certificate"]: logger.info("Utilizing CA certificate for GRPC channel authentication.") secure = True - cert = base64.b64decode( - combiner_config['certificate']) # .decode('utf-8') + cert = base64.b64decode(combiner_config["certificate"]) # .decode('utf-8') credentials = grpc.ssl_channel_credentials(root_certificates=cert) channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) elif os.getenv("FEDN_GRPC_ROOT_CERT_PATH"): secure = True logger.info("Using root certificate from environment variable for GRPC channel.") - with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], 'rb') as f: + with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], "rb") as f: credentials = grpc.ssl_channel_credentials(f.read()) channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) - elif self.config['secure']: + elif self.config["secure"]: secure = True logger.info("Using CA certificate for GRPC 
channel.") cert = self._get_ssl_certificate(host, port=port) - credentials = grpc.ssl_channel_credentials(cert.encode('utf-8')) - if self.config['token']: - token = self.config['token'] + credentials = grpc.ssl_channel_credentials(cert.encode("utf-8")) + if self.config["token"]: + token = self.config["token"] auth_creds = grpc.metadata_call_credentials(GrpcAuth(token)) channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) else: @@ -229,9 +227,7 @@ def connect(self, combiner_config): logger.info("Using insecure GRPC channel.") if port == 443: port = 80 - channel = grpc.insecure_channel("{}:{}".format( - host, - str(port))) + channel = grpc.insecure_channel("{}:{}".format(host, str(port))) self.channel = channel @@ -239,12 +235,9 @@ def connect(self, combiner_config): self.combinerStub = rpc.CombinerStub(channel) self.modelStub = rpc.ModelServiceStub(channel) - logger.info("Successfully established {} connection to {}:{}".format("secure" if secure else "insecure", - host, - port)) + logger.info("Successfully established {} connection to {}:{}".format("secure" if secure else "insecure", host, port)) - logger.info("Using {} compute package.".format( - combiner_config["package"])) + logger.info("Using {} compute package.".format(combiner_config["package"])) self._connected = True @@ -267,8 +260,8 @@ def _initialize_helper(self, combiner_config): :return: """ - if 'helper_type' in combiner_config.keys(): - self.helper = get_helper(combiner_config['helper_type']) + if "helper_type" in combiner_config.keys(): + self.helper = get_helper(combiner_config["helper_type"]) def _subscribe_to_combiner(self, config): """Listen to combiner message stream and start all processing threads. @@ -279,12 +272,10 @@ def _subscribe_to_combiner(self, config): """ # Start sending heartbeats to the combiner. - threading.Thread(target=self._send_heartbeat, kwargs={ - 'update_frequency': config['heartbeat_interval']}, daemon=True).start() + threading.Thread(target=self._send_heartbeat, kwargs={"update_frequency": config["heartbeat_interval"]}, daemon=True).start() # Start listening for combiner training and validation messages - threading.Thread( - target=self._listen_to_task_stream, daemon=True).start() + threading.Thread(target=self._listen_to_task_stream, daemon=True).start() self._connected = True # Start processing the client message inbox @@ -292,10 +283,11 @@ def _subscribe_to_combiner(self, config): @retry(stop=stop_after_attempt(3)) def untar_package(self, package_runtime): - package_runtime.unpack() + _, package_runpath = package_runtime.unpack() + return package_runpath def _initialize_dispatcher(self, config): - """ Initialize the dispatcher for the client. + """Initialize the dispatcher for the client. :param config: A configuration dictionary containing connection information for | the discovery service (controller) and settings governing e.g. 
@@ -303,19 +295,15 @@ def _initialize_dispatcher(self, config): :type config: dict :return: """ - if config['remote_compute_context']: - pr = PackageRuntime(os.getcwd(), os.getcwd()) + if config["remote_compute_context"]: + pr = PackageRuntime(self.run_path) retval = None tries = 10 while tries > 0: retval = pr.download( - host=config['discover_host'], - port=config['discover_port'], - token=config['token'], - force_ssl=config['force_ssl'], - secure=config['secure'] + host=config["discover_host"], port=config["discover_port"], token=config["token"], force_ssl=config["force_ssl"], secure=config["secure"] ) if retval: break @@ -324,19 +312,19 @@ def _initialize_dispatcher(self, config): tries -= 1 if retval: - if 'checksum' not in config: + if "checksum" not in config: logger.warning("Bypassing validation of package checksum. Ensure the package source is trusted.") else: - checks_out = pr.validate(config['checksum']) + checks_out = pr.validate(config["checksum"]) if not checks_out: logger.critical("Validation of local package failed. Client terminating.") self.error_state = True return - + package_runpath = "" if retval: - self.untar_package(pr) + package_runpath = self.untar_package(pr) - self.dispatcher = pr.dispatcher(self.run_path) + self.dispatcher = pr.dispatcher(package_runpath) try: logger.info("Initiating Dispatcher with entrypoint set to: startup") activate_cmd = self.dispatcher._get_or_create_python_env() @@ -349,11 +337,14 @@ def _initialize_dispatcher(self, config): else: # TODO: Deprecate - dispatch_config = {'entry_points': - {'predict': {'command': 'python3 predict.py'}, - 'train': {'command': 'python3 train.py'}, - 'validate': {'command': 'python3 validate.py'}}} - from_path = os.path.join(os.getcwd(), 'client') + dispatch_config = { + "entry_points": { + "predict": {"command": "python3 predict.py"}, + "train": {"command": "python3 train.py"}, + "validate": {"command": "python3 validate.py"}, + } + } + from_path = os.path.join(os.getcwd(), "client") copy_tree(from_path, self.run_path) self.dispatcher = Dispatcher(dispatch_config, self.run_path) @@ -379,7 +370,6 @@ def get_model_from_combiner(self, id, timeout=20): try: for part in self.modelStub.Download(request, metadata=self.metadata): - if part.status == fedn.ModelStatus.IN_PROGRESS: data.write(part.data) @@ -437,7 +427,7 @@ def _listen_to_task_stream(self): r.sender.name = self.name r.sender.role = fedn.WORKER # Add client to metadata - self._add_grpc_metadata('client', self.name) + self._add_grpc_metadata("client", self.name) while self._connected: try: @@ -446,14 +436,19 @@ def _listen_to_task_stream(self): logger.debug("Received model update request from combiner: {}.".format(request)) if request.sender.role == fedn.COMBINER: # Process training request - self.send_status("Received model update request.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request, sesssion_id=request.session_id) + self.send_status( + "Received model update request.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE_REQUEST, + request=request, + sesssion_id=request.session_id, + ) logger.info("Received model update request of type {} for model_id {}".format(request.type, request.model_id)) - if request.type == fedn.StatusType.MODEL_UPDATE and self.config['trainer']: - self.inbox.put(('train', request)) - elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config['validator']: - self.inbox.put(('validate', request)) + if request.type == fedn.StatusType.MODEL_UPDATE and 
self.config["trainer"]: + self.inbox.put(("train", request)) + elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]: + self.inbox.put(("validate", request)) else: logger.error("Unknown request type: {}".format(request.type)) @@ -466,7 +461,7 @@ def _listen_to_task_stream(self): time.sleep(5) if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() - if details == 'Token expired': + if details == "Token expired": logger.warning("GRPC TaskStream: Token expired. Reconnecting.") self.detach() @@ -497,8 +492,7 @@ def _process_training_request(self, model_id: str, session_id: str = None): :rtype: tuple """ - self.send_status( - "\t Starting processing of training request for model_id {}".format(model_id), sesssion_id=session_id) + self.send_status("\t Starting processing of training request for model_id {}".format(model_id), sesssion_id=session_id) self.state = ClientState.training try: @@ -508,10 +502,10 @@ def _process_training_request(self, model_id: str, session_id: str = None): if mdl is None: logger.error("Could not retrieve model from combiner. Aborting training request.") return None, None - meta['fetch_model'] = time.time() - tic + meta["fetch_model"] = time.time() - tic inpath = self.helper.get_tmp_path() - with open(inpath, 'wb') as fh: + with open(inpath, "wb") as fh: fh.write(mdl.getbuffer()) outpath = self.helper.get_tmp_path() @@ -520,7 +514,7 @@ def _process_training_request(self, model_id: str, session_id: str = None): self.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) - meta['exec_training'] = time.time() - tic + meta["exec_training"] = time.time() - tic tic = time.time() out_model = None @@ -531,21 +525,21 @@ def _process_training_request(self, model_id: str, session_id: str = None): # Stream model update to combiner server updated_model_id = uuid.uuid4() self.send_model_to_combiner(out_model, str(updated_model_id)) - meta['upload_model'] = time.time() - tic + meta["upload_model"] = time.time() - tic # Read the metadata file - with open(outpath+'-metadata', 'r') as fh: + with open(outpath + "-metadata", "r") as fh: training_metadata = json.loads(fh.read()) - meta['training_metadata'] = training_metadata + meta["training_metadata"] = training_metadata os.unlink(inpath) os.unlink(outpath) - os.unlink(outpath+'-metadata') + os.unlink(outpath + "-metadata") except Exception as e: logger.error("Could not process training request due to error: {}".format(e)) updated_model_id = None - meta = {'status': 'failed', 'error': str(e)} + meta = {"status": "failed", "error": str(e)} self.state = ClientState.idle @@ -565,12 +559,11 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session """ # Figure out cmd if is_inference: - cmd = 'infer' + cmd = "infer" else: - cmd = 'validate' + cmd = "validate" - self.send_status( - f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) + self.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) self.state = ClientState.validating try: model = self.get_model_from_combiner(str(model_id)) @@ -600,25 +593,22 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session return validation def process_request(self): - """Process training and validation tasks. 
""" + """Process training and validation tasks.""" while True: - if not self._connected: return try: (task_type, request) = self.inbox.get(timeout=1.0) - if task_type == 'train': - + if task_type == "train": tic = time.time() self.state = ClientState.training - model_id, meta = self._process_training_request( - request.model_id, session_id=request.session_id) + model_id, meta = self._process_training_request(request.model_id, session_id=request.session_id) if meta is not None: - processing_time = time.time()-tic - meta['processing_time'] = processing_time - meta['config'] = request.data + processing_time = time.time() - tic + meta["processing_time"] = processing_time + meta["config"] = request.data if model_id is not None: # Send model update to combiner @@ -635,28 +625,31 @@ def process_request(self): try: _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) - self.send_status("Model update completed.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE, request=update, sesssion_id=request.session_id) + self.send_status( + "Model update completed.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE, + request=update, + sesssion_id=request.session_id, + ) except grpc.RpcError as e: status_code = e.code() - logger.error("GRPC error, {}.".format( - status_code.name)) + logger.error("GRPC error, {}.".format(status_code.name)) logger.debug(e) except ValueError as e: logger.error("GRPC error, RPC channel closed. {}".format(e)) logger.debug(e) else: - self.send_status("Client {} failed to complete model update.", - log_level=fedn.Status.WARNING, - request=request, sesssion_id=request.session_id) + self.send_status( + "Client {} failed to complete model update.", log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id + ) self.state = ClientState.idle self.inbox.task_done() - elif task_type == 'validate': + elif task_type == "validate": self.state = ClientState.validating - metrics = self._process_validation_request( - request.model_id, False, request.session_id) + metrics = self._process_validation_request(request.model_id, False, request.session_id) if metrics is not None: # Send validation @@ -672,23 +665,26 @@ def process_request(self): validation.session_id = request.session_id try: - _ = self.combinerStub.SendModelValidation( - validation, metadata=self.metadata) + _ = self.combinerStub.SendModelValidation(validation, metadata=self.metadata) status_type = fedn.StatusType.MODEL_VALIDATION - self.send_status("Model validation completed.", log_level=fedn.Status.AUDIT, - type=status_type, request=validation, sesssion_id=request.session_id) + self.send_status( + "Model validation completed.", log_level=fedn.Status.AUDIT, type=status_type, request=validation, sesssion_id=request.session_id + ) except grpc.RpcError as e: status_code = e.code() - logger.error("GRPC error, {}.".format( - status_code.name)) + logger.error("GRPC error, {}.".format(status_code.name)) logger.debug(e) except ValueError as e: logger.error("GRPC error, RPC channel closed. 
{}".format(e)) logger.debug(e) else: - self.send_status("Client {} failed to complete model validation.".format(self.name), - log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id) + self.send_status( + "Client {} failed to complete model validation.".format(self.name), + log_level=fedn.Status.WARNING, + request=request, + sesssion_id=request.session_id, + ) self.state = ClientState.idle self.inbox.task_done() @@ -706,8 +702,7 @@ def _send_heartbeat(self, update_frequency=2.0): :rtype: None """ while True: - heartbeat = fedn.Heartbeat(sender=fedn.Client( - name=self.name, role=fedn.WORKER)) + heartbeat = fedn.Heartbeat(sender=fedn.Client(name=self.name, role=fedn.WORKER)) try: self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata) self._missed_heartbeat = 0 @@ -715,13 +710,16 @@ def _send_heartbeat(self, update_frequency=2.0): status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: self._missed_heartbeat += 1 - logger.error("GRPC hearbeat: combiner unavailable, retrying (attempt {}/{}).".format(self._missed_heartbeat, - self.config['reconnect_after_missed_heartbeat'])) - if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']: + logger.error( + "GRPC hearbeat: combiner unavailable, retrying (attempt {}/{}).".format( + self._missed_heartbeat, self.config["reconnect_after_missed_heartbeat"] + ) + ) + if self._missed_heartbeat > self.config["reconnect_after_missed_heartbeat"]: self.disconnect() if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() - if details == 'Token expired': + if details == "Token expired": logger.error("GRPC hearbeat: Token expired. Disconnecting.") self.disconnect() sys.exit("Unauthorized. Token expired. Please obtain a new token.") @@ -762,9 +760,7 @@ def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, if request is not None: status.data = MessageToJson(request) - self.logs.append( - "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, - status.status)) + self.logs.append("{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) try: _ = self.connectorStub.SendStatus(status, metadata=self.metadata) except grpc.RpcError as e: @@ -773,11 +769,11 @@ def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, logger.warning("GRPC SendStatus: server unavailable during send status.") if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() - if details == 'Token expired': + if details == "Token expired": logger.warning("GRPC SendStatus: Token expired.") def run(self): - """ Run the client. 
""" + """Run the client.""" try: cnt = 0 old_state = self.state diff --git a/fedn/fedn/network/clients/connect.py b/fedn/network/clients/connect.py similarity index 71% rename from fedn/fedn/network/clients/connect.py rename to fedn/network/clients/connect.py index 78be7a340..85941365e 100644 --- a/fedn/fedn/network/clients/connect.py +++ b/fedn/network/clients/connect.py @@ -9,15 +9,14 @@ import requests -from fedn.common.config import (FEDN_AUTH_REFRESH_TOKEN, - FEDN_AUTH_REFRESH_TOKEN_URI, FEDN_AUTH_SCHEME, - FEDN_CUSTOM_URL_PREFIX) +from fedn.common.config import FEDN_AUTH_REFRESH_TOKEN, FEDN_AUTH_REFRESH_TOKEN_URI, FEDN_AUTH_SCHEME, FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger from fedn.common.telemetry import trace_all_methods class Status(enum.Enum): - """ Enum for representing the status of a client assignment.""" + """Enum for representing the status of a client assignment.""" + Unassigned = 0 Assigned = 1 TryAgain = 2 @@ -27,7 +26,7 @@ class Status(enum.Enum): @trace_all_methods class ConnectorClient: - """ Connector for assigning client to a combiner in the FEDn network. + """Connector for assigning client to a combiner in the FEDn network. :param host: host of discovery service :type host: str @@ -49,7 +48,6 @@ class ConnectorClient: """ def __init__(self, host, port, token, name, remote_package, force_ssl=False, verify=False, combiner=None, id=None): - self.host = host self.port = port self.token = token @@ -57,7 +55,7 @@ def __init__(self, host, port, token, name, remote_package, force_ssl=False, ver self.verify = verify self.preferred_combiner = combiner self.id = id - self.package = 'remote' if remote_package else 'local' + self.package = "remote" if remote_package else "local" # for https we assume a an ingress handles permanent redirect (308) if force_ssl: @@ -65,11 +63,9 @@ def __init__(self, host, port, token, name, remote_package, force_ssl=False, ver else: self.prefix = "http://" if self.port: - self.connect_string = "{}{}:{}".format( - self.prefix, self.host, self.port) + self.connect_string = "{}{}:{}".format(self.prefix, self.host, self.port) else: - self.connect_string = "{}{}".format( - self.prefix, self.host) + self.connect_string = "{}{}".format(self.prefix, self.host) logger.info("Setting connection string to {}.".format(self.connect_string)) @@ -82,26 +78,28 @@ def assign(self): """ try: retval = None - payload = {'client_id': self.name, 'preferred_combiner': self.preferred_combiner} - retval = requests.post(self.connect_string + FEDN_CUSTOM_URL_PREFIX + '/add_client', - json=payload, - verify=self.verify, - allow_redirects=True, - headers={'Authorization': f"{FEDN_AUTH_SCHEME} {self.token}"}) + payload = {"client_id": self.name, "preferred_combiner": self.preferred_combiner} + retval = requests.post( + self.connect_string + FEDN_CUSTOM_URL_PREFIX + "/add_client", + json=payload, + verify=self.verify, + allow_redirects=True, + headers={"Authorization": f"{FEDN_AUTH_SCHEME} {self.token}"}, + ) except Exception as e: - logger.debug('***** {}'.format(e)) + logger.debug("***** {}".format(e)) return Status.Unassigned, {} if retval.status_code == 400: # Get error messange from response - reason = retval.json()['message'] + reason = retval.json()["message"] return Status.UnMatchedConfig, reason if retval.status_code == 401: - if 'message' in retval.json(): - reason = retval.json()['message'] + if "message" in retval.json(): + reason = retval.json()["message"] logger.warning(reason) - if reason == 'Token expired': + if reason == "Token expired": 
status_code = self.refresh_token() if status_code >= 200 and status_code < 204: logger.info("Token refreshed.") @@ -128,11 +126,11 @@ def assign(self): logger.error("Error: {}".format(e)) sys.exit(1) - reducer_package = retval.json()['package'] + reducer_package = retval.json()["package"] if reducer_package != self.package: - reason = "Unmatched config of compute package between client and reducer.\n" +\ - "Reducer uses {} package and client uses {}.".format( - reducer_package, self.package) + reason = "Unmatched config of compute package between client and reducer.\n" + "Reducer uses {} package and client uses {}.".format( + reducer_package, self.package + ) return Status.UnMatchedConfig, reason return Status.Assigned, retval.json() @@ -150,9 +148,6 @@ def refresh_token(self): logger.error("No refresh token URI/Token set, cannot refresh token.") return 401 - payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, - verify=self.verify, - allow_redirects=True, - json={'refresh': FEDN_AUTH_REFRESH_TOKEN}) - self.token = payload.json()['access'] + payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, verify=self.verify, allow_redirects=True, json={"refresh": FEDN_AUTH_REFRESH_TOKEN}) + self.token = payload.json()["access"] return payload.status_code diff --git a/fedn/fedn/network/clients/package.py b/fedn/network/clients/package.py similarity index 61% rename from fedn/fedn/network/clients/package.py rename to fedn/network/clients/package.py index 9cdaf61f6..a01c3ebd5 100644 --- a/fedn/fedn/network/clients/package.py +++ b/fedn/network/clients/package.py @@ -4,7 +4,6 @@ import cgi import os import tarfile -from distutils.dir_util import copy_tree import requests @@ -17,7 +16,7 @@ @trace_all_methods class PackageRuntime: - """ PackageRuntime is used to download, validate and unpack compute packages. + """PackageRuntime is used to download, validate and unpack compute packages. 
:param package_path: path to compute package :type package_path: str @@ -25,21 +24,22 @@ class PackageRuntime: :type package_dir: str """ - def __init__(self, package_path, package_dir): - - self.dispatch_config = {'entry_points': - {'predict': {'command': 'python3 predict.py'}, - 'train': {'command': 'python3 train.py'}, - 'validate': {'command': 'python3 validate.py'}}} + def __init__(self, package_path): + self.dispatch_config = { + "entry_points": { + "predict": {"command": "python3 predict.py"}, + "train": {"command": "python3 train.py"}, + "validate": {"command": "python3 validate.py"}, + } + } self.pkg_path = package_path self.pkg_name = None - self.dir = package_dir self.checksum = None self.expected_checksum = None def download(self, host, port, token, force_ssl=False, secure=False, name=None): - """ Download compute package from controller + """Download compute package from controller :param host: host of controller :param port: port of controller @@ -60,18 +60,16 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): if name: path = path + "?name={}".format(name) - with requests.get(path, stream=True, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: + with requests.get(path, stream=True, verify=False, headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}) as r: if 200 <= r.status_code < 204: - - params = cgi.parse_header( - r.headers.get('Content-Disposition', ''))[-1] + params = cgi.parse_header(r.headers.get("Content-Disposition", ""))[-1] try: - self.pkg_name = params['filename'] + self.pkg_name = params["filename"] except KeyError: logger.error("No package returned.") return None r.raise_for_status() - with open(os.path.join(self.pkg_path, self.pkg_name), 'wb') as f: + with open(os.path.join(self.pkg_path, self.pkg_name), "wb") as f: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) if port: @@ -81,19 +79,18 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): if name: path = path + "?name={}".format(name) - with requests.get(path, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: + with requests.get(path, verify=False, headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}) as r: if 200 <= r.status_code < 204: - data = r.json() try: - self.checksum = data['checksum'] + self.checksum = data["checksum"] except Exception: logger.error("Could not extract checksum.") return True def validate(self, expected_checksum): - """ Validate the package against the checksum provided by the controller + """Validate the package against the checksum provided by the controller :param expected_checksum: checksum provided by the controller :return: True if checksums match, False otherwise @@ -111,55 +108,55 @@ def validate(self, expected_checksum): return False def unpack(self): - """ Unpack the compute package + """Unpack the compute package :return: True if unpacking was successful, False otherwise :rtype: bool """ if self.pkg_name: f = None - if self.pkg_name.endswith('tar.gz'): - f = tarfile.open(os.path.join( - self.pkg_path, self.pkg_name), 'r:gz') - if self.pkg_name.endswith('.tgz'): - f = tarfile.open(os.path.join( - self.pkg_path, self.pkg_name), 'r:gz') - if self.pkg_name.endswith('tar.bz2'): - f = tarfile.open(os.path.join( - self.pkg_path, self.pkg_name), 'r:bz2') + if self.pkg_name.endswith("tar.gz"): + f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:gz") + if self.pkg_name.endswith(".tgz"): + f = 
tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:gz") + if self.pkg_name.endswith("tar.bz2"): + f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:bz2") else: - logger.error( - "Failed to unpack compute package, no pkg_name set." - "Has the reducer been configured with a compute package?" - ) + logger.error("Failed to unpack compute package, no pkg_name set. " "Has the reducer been configured with a compute package?") return False - os.getcwd() try: - os.chdir(self.dir) - if f: - f.extractall() - logger.info("Successfully extracted compute package content in {}".format(self.dir)) - return True + f.extractall(self.pkg_path) + logger.info("Successfully extracted compute package content in {}".format(self.pkg_path)) + # delete the tarball + logger.info("Deleting temporary package tarball file.") + os.remove(os.path.join(self.pkg_path, self.pkg_name)) + # search for file fedn.yaml in extracted package + for root, dirs, files in os.walk(self.pkg_path): + if "fedn.yaml" in files: + # Get the path to where fedn.yaml is located + logger.info("Found fedn.yaml file in {}".format(root)) + return True, root + + logger.error("No fedn.yaml file found in extracted package!") + return False, "" except Exception: logger.error("Error extracting files.") - return False + # delete the tarball + os.remove(os.path.join(self.pkg_path, self.pkg_name)) + return False, "" def dispatcher(self, run_path): - """ Dispatch the compute package + """Dispatch the compute package :param run_path: path to dispatch the compute package :type run_path: str :return: Dispatcher object :rtype: :class:`fedn.utils.dispatcher.Dispatcher` """ - from_path = os.path.join(os.getcwd(), 'client') - - # preserve_times=False ensures compatibility with Gramine LibOS - copy_tree(from_path, run_path, preserve_times=False) - self.dispatch_config = _read_yaml_file(os.path.join(run_path, 'fedn.yaml')) + self.dispatch_config = _read_yaml_file(os.path.join(run_path, "fedn.yaml")) dispatcher = Dispatcher(self.dispatch_config, run_path) return dispatcher diff --git a/fedn/fedn/network/clients/state.py b/fedn/network/clients/state.py similarity index 84% rename from fedn/fedn/network/clients/state.py rename to fedn/network/clients/state.py index dd329f96e..8d64512a5 100644 --- a/fedn/fedn/network/clients/state.py +++ b/fedn/network/clients/state.py @@ -4,7 +4,8 @@ class ClientState(Enum): - """ Enum for representing the state of a client.""" + """Enum for representing the state of a client.""" + idle = 1 training = 2 validating = 3 @@ -12,7 +13,7 @@ class ClientState(Enum): @tracer.start_as_current_span(name="ClientStateToString") def ClientStateToString(state): - """ Convert a ClientState to a string representation. + """Convert a ClientState to a string representation. 
diff --git a/fedn/fedn/network/clients/state.py b/fedn/network/clients/state.py
similarity index 84%
rename from fedn/fedn/network/clients/state.py
rename to fedn/network/clients/state.py
index dd329f96e..8d64512a5 100644
--- a/fedn/fedn/network/clients/state.py
+++ b/fedn/network/clients/state.py
@@ -4,7 +4,8 @@
 
 class ClientState(Enum):
-    """ Enum for representing the state of a client."""
+    """Enum for representing the state of a client."""
+
     idle = 1
     training = 2
     validating = 3
@@ -12,7 +13,7 @@ class ClientState(Enum):
 
 @tracer.start_as_current_span(name="ClientStateToString")
 def ClientStateToString(state):
-    """ Convert a ClientState to a string representation.
+    """Convert a ClientState to a string representation.
 
     :param state: the state to convert
     :type state: :class:`fedn.network.clients.state.ClientState`
diff --git a/fedn/fedn/network/clients/test_client.py b/fedn/network/clients/test_client.py
similarity index 100%
rename from fedn/fedn/network/clients/test_client.py
rename to fedn/network/clients/test_client.py
diff --git a/fedn/fedn/network/combiner/__init__.py b/fedn/network/combiner/__init__.py
similarity index 100%
rename from fedn/fedn/network/combiner/__init__.py
rename to fedn/network/combiner/__init__.py
diff --git a/fedn/fedn/network/combiner/aggregators/__init__.py b/fedn/network/combiner/aggregators/__init__.py
similarity index 100%
rename from fedn/fedn/network/combiner/aggregators/__init__.py
rename to fedn/network/combiner/aggregators/__init__.py
diff --git a/fedn/fedn/network/combiner/aggregators/aggregatorbase.py b/fedn/network/combiner/aggregators/aggregatorbase.py
similarity index 96%
rename from fedn/fedn/network/combiner/aggregators/aggregatorbase.py
rename to fedn/network/combiner/aggregators/aggregatorbase.py
index cb23d192e..6d4e997e1 100644
--- a/fedn/fedn/network/combiner/aggregators/aggregatorbase.py
+++ b/fedn/network/combiner/aggregators/aggregatorbase.py
@@ -36,7 +36,7 @@ def __init__(self, storage, server, modelservice, round_handler):
         self.model_updates = queue.Queue()
 
     @abstractmethod
-    def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=None, timeout=180, delete_models=True):
+    def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=None, timeout=180, delete_models=True, parameters=None):
         """Routine for combining model updates. Implemented in subclass.
 
         :param nr_expected_models: Number of expected models. If None, wait for all models.
@@ -49,7 +49,10 @@ def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=N
         :type timeout: int
         :param delete_models: Delete client models after combining.
         :type delete_models: bool
-        :return: A combined model.
+        :param parameters: Additional keyword arguments.
+        :type parameters: dict
+        :return: The global model and metadata
+        :rtype: tuple
         """
         pass
diff --git a/fedn/fedn/network/combiner/aggregators/fedavg.py b/fedn/network/combiner/aggregators/fedavg.py
similarity index 98%
rename from fedn/fedn/network/combiner/aggregators/fedavg.py
rename to fedn/network/combiner/aggregators/fedavg.py
index 305447c6b..af5007667 100644
--- a/fedn/fedn/network/combiner/aggregators/fedavg.py
+++ b/fedn/network/combiner/aggregators/fedavg.py
@@ -28,7 +28,7 @@ def __init__(self, storage, server, modelservice, round_handler):
 
         self.name = "fedavg"
 
-    def combine_models(self, helper=None, delete_models=True):
+    def combine_models(self, helper=None, delete_models=True, parameters=None):
         """Aggregate all model updates in the queue by computing an incremental
         weighted average of model parameters.
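
The new `parameters` keyword on combine_models is the hook through which server-side hyperparameters reach an aggregator. A hedged sketch of a call site, using the Parameters wrapper that the round handler imports later in this patch (the aggregator and helper objects are assumed to already exist):

    from fedn.utils.parameters import Parameters

    # Keys follow the FedOpt schema introduced in fedopt.py below.
    params = Parameters({"serveropt": "adam", "learning_rate": 1e-3})
    model, metadata = aggregator.combine_models(helper=helper, delete_models=True, parameters=params)
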
diff --git a/fedn/network/combiner/aggregators/fedopt.py b/fedn/network/combiner/aggregators/fedopt.py
new file mode 100644
index 000000000..305340f10
--- /dev/null
+++ b/fedn/network/combiner/aggregators/fedopt.py
@@ -0,0 +1,268 @@
+import math
+
+from fedn.common.exceptions import InvalidParameterError
+from fedn.common.log_config import logger
+from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase
+
+
+class Aggregator(AggregatorBase):
+    """ Federated Optimization (FedOpt) aggregator.
+
+    Implementation following: https://arxiv.org/pdf/2003.00295.pdf
+
+    This aggregator computes pseudo gradients by subtracting the model
+    update from the global model weights from the previous round.
+    A server-side scheme is then applied, currently supported schemes
+    are "adam", "yogi", "adagrad".
+
+    :param storage: Model repository for :class: `fedn.network.combiner.Combiner`
+    :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository`
+    :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner`
+    :type server: class: `fedn.network.combiner.Combiner`
+    :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService`
+    :type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
+    :param round_handler: A handle to the :class: `fedn.network.combiner.roundhandler.RoundHandler`
+    :type round_handler: class: `fedn.network.combiner.roundhandler.RoundHandler`
+
+    """
+
+    def __init__(self, storage, server, modelservice, round_handler):
+
+        super().__init__(storage, server, modelservice, round_handler)
+
+        self.name = "fedopt"
+        # To store momentum
+        self.v = None
+        self.m = None
+
+    def combine_models(self, helper=None, delete_models=True, parameters=None):
+        """Compute pseudo gradients using model updates in the queue.
+
+        :param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase`, ML framework specific helper, defaults to None
+        :type helper: class: `fedn.utils.helpers.helpers.HelperBase`, optional
+        :param delete_models: Delete models from storage after aggregation, defaults to True
+        :type delete_models: bool, optional
+        :param parameters: Aggregator hyperparameters.
+        :type parameters: `fedn.utils.parameters.Parameters`, optional
+        :return: The global model and metadata
+        :rtype: tuple
+        """
+
+        data = {}
+        data['time_model_load'] = 0.0
+        data['time_model_aggregation'] = 0.0
+
+        # Define parameter schema
+        parameter_schema = {
+            'serveropt': str,
+            'learning_rate': float,
+            'beta1': float,
+            'beta2': float,
+            'tau': float,
+        }
+
+        # Default hyperparameters. Note that these may need fine tuning.
+        default_parameters = {
+            'serveropt': 'adam',
+            'learning_rate': 1e-3,
+            'beta1': 0.9,
+            'beta2': 0.99,
+            'tau': 1e-4,
+        }
+
+        # Validate parameters
+        if parameters:
+            try:
+                parameters.validate(parameter_schema)
+            except InvalidParameterError as e:
+                logger.error("Aggregator {} received invalid parameters. Reason {}".format(self.name, e))
+                return None, data
+        else:
+            logger.info("Aggregator {} using default parameters.".format(self.name))
+            parameters = default_parameters
+
+        # Override missing parameters with defaults
+        for key, value in default_parameters.items():
+            if key not in parameters:
+                parameters[key] = value
+
+        model = None
+        nr_aggregated_models = 0
+        total_examples = 0
+
+        logger.info("AGGREGATOR({}): Aggregating model updates...".format(self.name))
+
+        while not self.model_updates.empty():
+            try:
+                # Get next model from queue
+                model_update = self.next_model_update()
+
+                # Load model parameters and metadata
+                model_next, metadata = self.load_model_update(model_update, helper)
+
+                logger.info(
+                    "AGGREGATOR({}): Processing model update {}".format(self.name, model_update.model_update_id))
+
+                # Increment total number of examples
+                total_examples += metadata['num_examples']
+
+                if nr_aggregated_models == 0:
+                    model_old = self.round_handler.load_model_update(helper, model_update.model_id)
+                    pseudo_gradient = helper.subtract(model_next, model_old)
+                else:
+                    pseudo_gradient_next = helper.subtract(model_next, model_old)
+                    pseudo_gradient = helper.increment_average(
+                        pseudo_gradient, pseudo_gradient_next, metadata['num_examples'], total_examples)
+
+                nr_aggregated_models += 1
+                # Delete model from storage
+                if delete_models:
+                    self.modelservice.temp_model_storage.delete(model_update.model_update_id)
+                    logger.info(
+                        "AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id))
+                self.model_updates.task_done()
+            except Exception as e:
+                logger.error(
+                    "AGGREGATOR({}): Error encountered while processing model update {}, skipping this update.".format(self.name, e))
+                self.model_updates.task_done()
+
+        if parameters['serveropt'] == 'adam':
+            model = self.serveropt_adam(helper, pseudo_gradient, model_old, parameters)
+        elif parameters['serveropt'] == 'yogi':
+            model = self.serveropt_yogi(helper, pseudo_gradient, model_old, parameters)
+        elif parameters['serveropt'] == 'adagrad':
+            model = self.serveropt_adagrad(helper, pseudo_gradient, model_old, parameters)
+        else:
+            logger.error("Unsupported server optimizer passed to FedOpt.")
+            return None, data
+
+        data['nr_aggregated_models'] = nr_aggregated_models
+
+        logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models))
+        return model, data
+
+    def serveropt_adam(self, helper, pseudo_gradient, model_old, parameters):
+        """ Server side optimization, FedAdam.
+
+        :param helper: instance of helper class.
+        :type helper: Helper
+        :param pseudo_gradient: The pseudo gradient.
+        :type pseudo_gradient: As defined by helper.
+        :param model_old: The current global model.
+        :type model_old: As defined in helper.
+        :param parameters: Hyperparameters for the aggregator.
+        :type parameters: dict
+        :return: new model weights.
+        :rtype: as defined by helper.
+        """
+        beta1 = parameters['beta1']
+        beta2 = parameters['beta2']
+        learning_rate = parameters['learning_rate']
+        tau = parameters['tau']
+
+        if not self.v:
+            self.v = helper.ones(pseudo_gradient, math.pow(tau, 2))
+
+        if not self.m:
+            self.m = helper.multiply(pseudo_gradient, [(1.0-beta1)]*len(pseudo_gradient))
+        else:
+            self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0-beta1))
+
+        p = helper.power(pseudo_gradient, 2)
+        self.v = helper.add(self.v, p, beta2, (1.0-beta2))
+
+        sv = helper.add(helper.sqrt(self.v), helper.ones(self.v, tau))
+        t = helper.divide(self.m, sv)
+        model = helper.add(model_old, t, 1.0, learning_rate)
+
+        return model
+
+    def serveropt_yogi(self, helper, pseudo_gradient, model_old, parameters):
+        """ Server side optimization, FedYogi.
+
+        :param helper: instance of helper class.
+        :type helper: Helper
+        :param pseudo_gradient: The pseudo gradient.
+        :type pseudo_gradient: As defined by helper.
+        :param model_old: The current global model.
+        :type model_old: As defined in helper.
+        :param parameters: Hyperparameters for the aggregator.
+        :type parameters: dict
+        :return: new model weights.
+        :rtype: as defined by helper.
+        """
+
+        beta1 = parameters['beta1']
+        beta2 = parameters['beta2']
+        learning_rate = parameters['learning_rate']
+        tau = parameters['tau']
+
+        if not self.v:
+            self.v = helper.ones(pseudo_gradient, math.pow(tau, 2))
+
+        if not self.m:
+            self.m = helper.multiply(pseudo_gradient, [(1.0-beta1)]*len(pseudo_gradient))
+        else:
+            self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0-beta1))
+
+        p = helper.power(pseudo_gradient, 2)
+        s = helper.sign(helper.add(self.v, p, 1.0, -1.0))
+        s = helper.multiply(s, p)
+        self.v = helper.add(self.v, s, 1.0, -(1.0-beta2))
+
+        sv = helper.add(helper.sqrt(self.v), helper.ones(self.v, tau))
+        t = helper.divide(self.m, sv)
+        model = helper.add(model_old, t, 1.0, learning_rate)
+
+        return model
+
+    def serveropt_adagrad(self, helper, pseudo_gradient, model_old, parameters):
+        """ Server side optimization, FedAdaGrad.
+
+        :param helper: instance of helper class.
+        :type helper: Helper
+        :param pseudo_gradient: The pseudo gradient.
+        :type pseudo_gradient: As defined by helper.
+        :param model_old: The current global model.
+        :type model_old: As defined in helper.
+        :param parameters: Hyperparameters for the aggregator.
+        :type parameters: dict
+        :return: new model weights.
+        :rtype: as defined by helper.
+        """
+
+        beta1 = parameters['beta1']
+        learning_rate = parameters['learning_rate']
+        tau = parameters['tau']
+
+        if not self.v:
+            self.v = helper.ones(pseudo_gradient, math.pow(tau, 2))
+
+        if not self.m:
+            self.m = helper.multiply(pseudo_gradient, [(1.0-beta1)]*len(pseudo_gradient))
+        else:
+            self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0-beta1))
+
+        p = helper.power(pseudo_gradient, 2)
+        self.v = helper.add(self.v, p, 1.0, 1.0)
+
+        sv = helper.add(helper.sqrt(self.v), helper.ones(self.v, tau))
+        t = helper.divide(self.m, sv)
+        model = helper.add(model_old, t, 1.0, learning_rate)
+
+        return model
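
To make the server-side schemes concrete, here is a minimal NumPy restatement of the FedAdam branch above, without the helper abstraction (models are assumed to be lists of ndarrays; passing m as zeros and v as tau**2 * ones reproduces the first-round initialisation in serveropt_adam):

    import numpy as np

    def fedadam_step(model_old, pseudo_gradient, m, v, learning_rate=1e-3, beta1=0.9, beta2=0.99, tau=1e-4):
        """One FedAdam server step per parameter tensor (cf. arXiv:2003.00295)."""
        model_new = []
        for i, (x, g) in enumerate(zip(model_old, pseudo_gradient)):
            m[i] = beta1 * m[i] + (1.0 - beta1) * g         # first moment estimate
            v[i] = beta2 * v[i] + (1.0 - beta2) * g ** 2    # second moment estimate
            model_new.append(x + learning_rate * m[i] / (np.sqrt(v[i]) + tau))
        return model_new, m, v
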
""" + """Enum for combiner roles.""" + WORKER = 1 COMBINER = 2 REDUCER = 3 @@ -34,7 +34,7 @@ class Role(Enum): def role_to_proto_role(role): - """ Convert a Role to a proto Role. + """Convert a Role to a proto Role. :param role: the role to convert :type role: :class:`fedn.network.combiner.server.Role` @@ -53,17 +53,17 @@ def role_to_proto_role(role): @trace_all_methods class Combiner(rpc.CombinerServicer, rpc.ReducerServicer, rpc.ConnectorServicer, rpc.ControlServicer): - """ Combiner gRPC server. + """Combiner gRPC server. :param config: configuration for the combiner :type config: dict """ def __init__(self, config): - """ Initialize Combiner server.""" + """Initialize Combiner server.""" - set_log_level_from_string(config.get('verbosity', "INFO")) - set_log_stream(config.get('logfile', None)) + set_log_level_from_string(config.get("verbosity", "INFO")) + set_log_stream(config.get("logfile", None)) # Client queues self.clients = {} @@ -71,24 +71,26 @@ def __init__(self, config): self.modelservice = ModelService() # Validate combiner name - match = re.search(VALID_NAME_REGEX, config['name']) + match = re.search(VALID_NAME_REGEX, config["name"]) if not match: - raise ValueError('Unallowed character in combiner name. Allowed characters: a-z, A-Z, 0-9, _, -.') + raise ValueError("Unallowed character in combiner name. Allowed characters: a-z, A-Z, 0-9, _, -.") - self.id = config['name'] + self.id = config["name"] self.role = Role.COMBINER - self.max_clients = config['max_clients'] + self.max_clients = config["max_clients"] # Connector to announce combiner to discover service (reducer) - announce_client = ConnectorCombiner(host=config['discover_host'], - port=config['discover_port'], - myhost=config['host'], - fqdn=config['fqdn'], - myport=config['port'], - token=config['token'], - name=config['name'], - secure=config['secure'], - verify=config['verify']) + announce_client = ConnectorCombiner( + host=config["discover_host"], + port=config["discover_port"], + myhost=config["host"], + fqdn=config["fqdn"], + myport=config["port"], + token=config["token"], + name=config["name"], + secure=config["secure"], + verify=config["verify"], + ) while True: # Announce combiner to discover service @@ -112,27 +114,20 @@ def __init__(self, config): logger.info("Status.Unassigned") time.sleep(5) - cert = announce_config['certificate'] - key = announce_config['key'] + cert = announce_config["certificate"] + key = announce_config["key"] - if announce_config['certificate']: - cert = base64.b64decode(announce_config['certificate']) # .decode('utf-8') - key = base64.b64decode(announce_config['key']) # .decode('utf-8') + if announce_config["certificate"]: + cert = base64.b64decode(announce_config["certificate"]) # .decode('utf-8') + key = base64.b64decode(announce_config["key"]) # .decode('utf-8') # Set up gRPC server configuration - grpc_config = {'port': config['port'], - 'secure': config['secure'], - 'certificate': cert, - 'key': key} + grpc_config = {"port": config["port"], "secure": config["secure"], "certificate": cert, "key": key} # Set up model repository - self.repository = Repository( - announce_config['storage']['storage_config']) + self.repository = Repository(announce_config["storage"]["storage_config"]) - self.statestore = MongoStateStore( - announce_config['statestore']['network_id'], - announce_config['statestore']['mongo_config'] - ) + self.statestore = MongoStateStore(announce_config["statestore"]["network_id"], announce_config["statestore"]["mongo_config"]) # Create gRPC server self.server = 
Server(self, self.modelservice, grpc_config) @@ -149,7 +144,7 @@ def __init__(self, config): self.server.start() def __whoami(self, client, instance): - """ Set the client id and role in a proto message. + """Set the client id and role in a proto message. :param client: the client to set the id and role for :type client: :class:`fedn.network.grpc.fedn_pb2.Client` @@ -163,7 +158,7 @@ def __whoami(self, client, instance): return client def request_model_update(self, config, clients=[]): - """ Ask clients to update the current global model. + """Ask clients to update the current global model. :param config: the model configuration to send to clients :type config: dict @@ -173,12 +168,12 @@ def request_model_update(self, config, clients=[]): """ # The request to be added to the client queue request = fedn.TaskRequest() - request.model_id = config['model_id'] + request.model_id = config["model_id"] request.correlation_id = str(uuid.uuid4()) request.timestamp = str(datetime.now()) request.data = json.dumps(config) request.type = fedn.StatusType.MODEL_UPDATE - request.session_id = config['session_id'] + request.session_id = config["session_id"] request.sender.name = self.id request.sender.role = fedn.COMBINER @@ -192,14 +187,12 @@ def request_model_update(self, config, clients=[]): self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE) if len(clients) < 20: - logger.info("Sent model update request for model {} to clients {}".format( - request.model_id, clients)) + logger.info("Sent model update request for model {} to clients {}".format(request.model_id, clients)) else: - logger.info("Sent model update request for model {} to {} clients".format( - request.model_id, len(clients))) + logger.info("Sent model update request for model {} to {} clients".format(request.model_id, len(clients))) def request_model_validation(self, model_id, config, clients=[]): - """ Ask clients to validate the current global model. + """Ask clients to validate the current global model. :param model_id: the model id to validate :type model_id: str @@ -230,14 +223,12 @@ def request_model_validation(self, model_id, config, clients=[]): self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE) if len(clients) < 20: - logger.info("Sent model validation request for model {} to clients {}".format( - request.model_id, clients)) + logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients)) else: - logger.info("Sent model validation request for model {} to {} clients".format( - request.model_id, len(clients))) + logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients))) def get_active_trainers(self): - """ Get a list of active trainers. + """Get a list of active trainers. :return: the list of active trainers :rtype: list @@ -246,7 +237,7 @@ def get_active_trainers(self): return trainers def get_active_validators(self): - """ Get a list of active validators. + """Get a list of active validators. :return: the list of active validators :rtype: list @@ -255,7 +246,7 @@ def get_active_validators(self): return validators def nr_active_trainers(self): - """ Get the number of active trainers. + """Get the number of active trainers. 
     :return: the number of active trainers
     :rtype: int
@@ -265,7 +256,7 @@ def nr_active_trainers(self):
     ####################################################################################################################
 
     def __join_client(self, client):
-        """ Add a client to the list of active clients.
+        """Add a client to the list of active clients.
 
         :param client: the client to add
         :type client: :class:`fedn.network.grpc.fedn_pb2.Client`
@@ -275,7 +266,7 @@ def __join_client(self, client):
             self.clients[client.name] = {"lastseen": datetime.now(), "status": "offline"}
 
     def _subscribe_client_to_queue(self, client, queue_name):
-        """ Subscribe a client to the queue.
+        """Subscribe a client to the queue.
 
         :param client: the client to subscribe
         :type client: :class:`fedn.network.grpc.fedn_pb2.Client`
@@ -287,7 +278,7 @@ def _subscribe_client_to_queue(self, client, queue_name):
             self.clients[client.name][queue_name] = queue.Queue()
 
     def __get_queue(self, client, queue_name):
-        """ Get the queue for a client.
+        """Get the queue for a client.
 
         :param client: the client to get the queue for
         :type client: :class:`fedn.network.grpc.fedn_pb2.Client`
@@ -304,7 +295,7 @@ def __get_queue(self, client, queue_name):
             raise
 
     def _list_subscribed_clients(self, queue_name):
-        """ List all clients subscribed to a queue.
+        """List all clients subscribed to a queue.
 
         :param queue_name: the name of the queue
         :type queue_name: str
@@ -318,7 +309,7 @@ def _list_subscribed_clients(self, queue_name):
         return subscribed_clients
 
     def _list_active_clients(self, channel):
-        """ List all clients that have sent a status message in the last 10 seconds.
+        """List all clients that have sent a status message in the last 10 seconds.
 
         :param channel: the name of the channel
         :type channel: str
@@ -338,7 +329,7 @@ def _list_active_clients(self, channel):
             if (now - then) < timedelta(seconds=10):
                 clients["active_clients"].append(client)
                 # If client has changed status, update statestore
-                if status == "offline":
+                if status != "online":
                     self.clients[client]["status"] = "online"
                     clients["update_active_clients"].append(client)
             else:
@@ -355,14 +346,14 @@ def _list_active_clients(self, channel):
         return clients["active_clients"]
 
     def _deamon_thread_client_status(self, timeout=10):
-        """ Deamon thread that checks for inactive clients and updates statestore. """
+        """Daemon thread that checks for inactive clients and updates statestore."""
         while True:
             time.sleep(timeout)
             # TODO: Also update validation clients
            self._list_active_clients(fedn.Queue.TASK_QUEUE)
 
     def _put_request_to_client_queue(self, request, queue_name):
-        """ Get a client specific queue and add a request to it.
+        """Get a client specific queue and add a request to it.
         The client is identified by the request.receiver.
 
         :param request: the request to send
@@ -374,14 +365,11 @@ def _put_request_to_client_queue(self, request, queue_name):
             q = self.__get_queue(request.receiver, queue_name)
             q.put(request)
         except Exception as e:
-            logger.error("Failed to put request to client queue {} for client {}: {}".format(
-                queue_name,
-                request.receiver.name,
-                str(e)))
+            logger.error("Failed to put request to client queue {} for client {}: {}".format(queue_name, request.receiver.name, str(e)))
             raise
 
     def _send_status(self, status):
-        """ Report a status to backend db.
+        """Report a status to backend db.
 
         :param status: the status to report
         :type status: :class:`fedn.network.grpc.fedn_pb2.Status`
@@ -411,7 +399,7 @@ def _flush_model_update_queue(self):
 
     # Controller Service
     def Start(self, control: fedn.ControlRequest, context):
-        """ Start a round of federated learning"
+        """Start a round of federated learning.
 
         :param control: the control request
         :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest`
@@ -439,7 +427,7 @@ def Start(self, control: fedn.ControlRequest, context):
         return response
 
     def SetAggregator(self, control: fedn.ControlRequest, context):
-        """ Set the active aggregator.
+        """Set the active aggregator.
 
         :param control: the control request
         :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest`
@@ -456,14 +444,14 @@ def SetAggregator(self, control: fedn.ControlRequest, context):
 
         response = fedn.ControlResponse()
         if status:
-            response.message = 'Success'
+            response.message = "Success"
         else:
-            response.message = 'Failed'
+            response.message = "Failed"
 
         return response
 
     def FlushAggregationQueue(self, control: fedn.ControlRequest, context):
-        """ Flush the queue.
+        """Flush the queue.
 
         :param control: the control request
         :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest`
@@ -477,16 +465,16 @@ def FlushAggregationQueue(self, control: fedn.ControlRequest, context):
 
         response = fedn.ControlResponse()
         if status:
-            response.message = 'Success'
+            response.message = "Success"
        else:
-            response.message = 'Failed'
+            response.message = "Failed"
 
         return response
 
     ##############################################################################
 
     def Stop(self, control: fedn.ControlRequest, context):
-        """ TODO: Not yet implemented.
+        """TODO: Not yet implemented.
 
         :param control: the control request
         :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest`
@@ -502,7 +490,7 @@ def Stop(self, control: fedn.ControlRequest, context):
     #####################################################################################################################
 
     def SendStatus(self, status: fedn.Status, context):
-        """ A client RPC endpoint that accepts status messages.
+        """A client RPC endpoint that accepts status messages.
 
         :param status: the status message
         :type status: :class:`fedn.network.grpc.fedn_pb2.Status`
@@ -519,7 +507,7 @@ def SendStatus(self, status: fedn.Status, context):
         return response
 
     def ListActiveClients(self, request: fedn.ListClientsRequest, context):
-        """ RPC endpoint that returns a ClientList containing the names of all active clients.
+        """RPC endpoint that returns a ClientList containing the names of all active clients.
         An active client has sent a status message / responded to a heartbeat
         request in the last 10 seconds.
 
@@ -543,7 +531,7 @@ def ListActiveClients(self, request: fedn.ListClientsRequest, context):
         return clients
 
     def AcceptingClients(self, request: fedn.ConnectionRequest, context):
-        """ RPC endpoint that returns a ConnectionResponse indicating whether the server
+        """RPC endpoint that returns a ConnectionResponse indicating whether the server
         is accepting clients or not.
 
         :param request: the request (unused)
@@ -554,8 +542,7 @@ def AcceptingClients(self, request: fedn.ConnectionRequest, context):
         :rtype: :class:`fedn.network.grpc.fedn_pb2.ConnectionResponse`
         """
         response = fedn.ConnectionResponse()
-        active_clients = self._list_active_clients(
-            fedn.Queue.TASK_QUEUE)
+        active_clients = self._list_active_clients(fedn.Queue.TASK_QUEUE)
 
         try:
             requested = int(self.max_clients)
@@ -574,7 +561,7 @@ def AcceptingClients(self, request: fedn.ConnectionRequest, context):
         return response
 
     def SendHeartbeat(self, heartbeat: fedn.Heartbeat, context):
-        """ RPC that lets clients send a hearbeat, notifying the server that
+        """RPC that lets clients send a heartbeat, notifying the server that
         the client is available.
 
         :param heartbeat: the heartbeat
@@ -599,7 +586,7 @@ def SendHeartbeat(self, heartbeat: fedn.Heartbeat, context):
 
     # Combiner Service
     def TaskStream(self, response, context):
-        """ A server stream RPC endpoint (Update model). Messages from client stream.
+        """A server stream RPC endpoint (Update model). Messages from client stream.
 
         :param response: the response
         :type response: :class:`fedn.network.grpc.fedn_pb2.ModelUpdateRequest`
@@ -611,17 +598,15 @@ def TaskStream(self, response, context):
         metadata = context.invocation_metadata()
         if metadata:
             metadata = dict(metadata)
-            logger.info("grpc.Combiner.TaskStream: Client connected: {}\n".format(metadata['client']))
+            logger.info("grpc.Combiner.TaskStream: Client connected: {}\n".format(metadata["client"]))
 
-        status = fedn.Status(
-            status="Client {} connecting to TaskStream.".format(client.name))
+        status = fedn.Status(status="Client {} connecting to TaskStream.".format(client.name))
         status.log_level = fedn.Status.INFO
         status.timestamp.GetCurrentTime()
 
         self.__whoami(status.sender, self)
 
-        self._subscribe_client_to_queue(
-            client, fedn.Queue.TASK_QUEUE)
+        self._subscribe_client_to_queue(client, fedn.Queue.TASK_QUEUE)
         q = self.__get_queue(client, fedn.Queue.TASK_QUEUE)
 
         self._send_status(status)
@@ -642,7 +627,7 @@ def TaskStream(self, response, context):
             logger.error("Error in ModelUpdateRequestStream: {}".format(e))
 
     def SendModelUpdate(self, request, context):
-        """ Send a model update response.
+        """Send a model update response.
 
         :param request: the request
         :type request: :class:`fedn.network.grpc.fedn_pb2.ModelUpdate`
@@ -654,8 +639,7 @@ def SendModelUpdate(self, request, context):
         self.round_handler.aggregator.on_model_update(request)
 
         response = fedn.Response()
-        response.response = "RECEIVED ModelUpdate {} from client {}".format(
-            response, response.sender.name)
+        response.response = "RECEIVED ModelUpdate {} from client {}".format(request.model_update_id, request.sender.name)
         return response
 
     # TODO Fill later
     def register_model_validation(self, validation):
@@ -668,7 +652,7 @@ def register_model_validation(self, validation):
         self.statestore.report_validation(validation)
 
     def SendModelValidation(self, request, context):
-        """ Send a model validation response.
+        """Send a model validation response.
 
         :param request: the request
         :type request: :class:`fedn.network.grpc.fedn_pb2.ModelValidation`
@@ -682,17 +666,15 @@ def SendModelValidation(self, request, context):
         self.register_model_validation(request)
 
         response = fedn.Response()
-        response.response = "RECEIVED ModelValidation {} from client {}".format(
-            response, response.sender.name)
+        response.response = "RECEIVED ModelValidation {} from client {}".format(request.model_id, request.sender.name)
         return response
 
     ####################################################################################################################
 
     def run(self):
-        """ Start the server."""
+        """Start the server."""
 
-        logger.info("COMBINER: {} started, ready for gRPC requests.".format(
-            self.id))
+        logger.info("COMBINER: {} started, ready for gRPC requests.".format(self.id))
         try:
             while True:
                 signal.pause()
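
For reference, the liveness rule used by _list_active_clients and SendHeartbeat above boils down to: a client counts as active if its last status message or heartbeat is less than 10 seconds old. A standalone illustration of that check (not the combiner's actual bookkeeping, which also updates the statestore):

    from datetime import datetime, timedelta

    def active_clients(clients, window_seconds=10):
        # clients: dict mapping name -> {"lastseen": datetime, "status": str}
        now = datetime.now()
        return [name for name, meta in clients.items() if now - meta["lastseen"] < timedelta(seconds=window_seconds)]
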
diff --git a/fedn/fedn/network/combiner/combiner_tests.py b/fedn/network/combiner/combiner_tests.py
similarity index 100%
rename from fedn/fedn/network/combiner/combiner_tests.py
rename to fedn/network/combiner/combiner_tests.py
diff --git a/fedn/fedn/network/combiner/connect.py b/fedn/network/combiner/connect.py
similarity index 75%
rename from fedn/fedn/network/combiner/connect.py
rename to fedn/network/combiner/connect.py
index 1a8afc776..eb3fda9aa 100644
--- a/fedn/fedn/network/combiner/connect.py
+++ b/fedn/network/combiner/connect.py
@@ -14,7 +14,8 @@
 
 class Status(enum.Enum):
-    """ Enum for representing the status of a combiner announcement."""
+    """Enum for representing the status of a combiner announcement."""
+
     Unassigned = 0
     Assigned = 1
     TryAgain = 2
@@ -24,7 +25,7 @@ class Status(enum.Enum):
 
 @trace_all_methods
 class ConnectorCombiner:
-    """ Connector for annnouncing combiner to the FEDn network.
+    """Connector for announcing combiner to the FEDn network.
 
     :param host: host of discovery service
     :type host: str
@@ -47,7 +48,7 @@ class ConnectorCombiner:
     """
 
     def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, verify=False):
-        """ Initialize the ConnectorCombiner.
+        """Initialize the ConnectorCombiner.
 
         :param host: The host of the discovery service.
         :type host: str
@@ -75,22 +76,20 @@ def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False,
         self.myhost = myhost
         self.myport = myport
         self.token = token
-        self.token_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer')
+        self.token_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Bearer")
         self.name = name
         self.secure = secure
         self.verify = verify
 
         if not self.token:
-            self.token = os.environ.get('FEDN_AUTH_TOKEN', None)
+            self.token = os.environ.get("FEDN_AUTH_TOKEN", None)
 
         # for https we assume an ingress handles permanent redirect (308)
         self.prefix = "http://"
         if port:
-            self.connect_string = "{}{}:{}".format(
-                self.prefix, self.host, self.port)
+            self.connect_string = "{}{}:{}".format(self.prefix, self.host, self.port)
         else:
-            self.connect_string = "{}{}".format(
-                self.prefix, self.host)
+            self.connect_string = "{}{}".format(self.prefix, self.host)
 
         logger.info("Setting connection string to {}".format(self.connect_string))
 
@@ -101,24 +100,21 @@ def announce(self):
 
         :return: Tuple with announcement Status, FEDn network configuration if successful, else None.
         :rtype: :class:`fedn.network.combiner.connect.Status`, str
         """
-        payload = {
-            "combiner_id": self.name,
-            "address": self.myhost,
-            "fqdn": self.fqdn,
-            "port": self.myport,
-            "secure_grpc": self.secure
-        }
-        url_prefix = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '')
+        payload = {"combiner_id": self.name, "address": self.myhost, "fqdn": self.fqdn, "port": self.myport, "secure_grpc": self.secure}
+        url_prefix = os.environ.get("FEDN_CUSTOM_URL_PREFIX", "")
         try:
-            retval = requests.post(self.connect_string + url_prefix + '/add_combiner', json=payload,
-                                   verify=self.verify,
-                                   headers={'Authorization': f'{self.token_scheme} {self.token}'})
+            retval = requests.post(
+                self.connect_string + url_prefix + "/add_combiner",
+                json=payload,
+                verify=self.verify,
+                headers={"Authorization": f"{self.token_scheme} {self.token}"},
+            )
         except Exception:
             return Status.Unassigned, {}
 
         if retval.status_code == 400:
             # Get error message from response
-            reason = retval.json()['message']
+            reason = retval.json()["message"]
             return Status.UnMatchedConfig, reason
 
         if retval.status_code == 401:
@@ -126,8 +122,8 @@ def announce(self):
             return Status.UnAuthorized, reason
 
         if retval.status_code >= 200 and retval.status_code < 204:
-            if retval.json()['status'] == 'retry':
-                reason = retval.json()['message']
+            if retval.json()["status"] == "retry":
+                reason = retval.json()["message"]
                 return Status.TryAgain, reason
 
             return Status.Assigned, retval.json()
diff --git a/fedn/fedn/network/combiner/interfaces.py b/fedn/network/combiner/interfaces.py
similarity index 77%
rename from fedn/fedn/network/combiner/interfaces.py
rename to fedn/network/combiner/interfaces.py
index 21e1ac846..f35d1fb8c 100644
--- a/fedn/fedn/network/combiner/interfaces.py
+++ b/fedn/network/combiner/interfaces.py
@@ -17,7 +17,7 @@ class CombinerUnavailableError(Exception):
 
 @trace_all_methods
 class Channel:
-    """ Wrapper for a gRPC channel.
+    """Wrapper for a gRPC channel.
 
     :param address: The address for the gRPC server.
     :type address: str
@@ -28,7 +28,7 @@ class Channel:
     """
 
     def __init__(self, address, port, certificate=None):
-        """ Create a channel.
+        """Create a channel.
 
         If a valid certificate is given, a secure channel is created, else insecure.
 
@@ -44,16 +44,13 @@ def __init__(self, address, port, certificate=None):
         self.certificate = certificate
 
         if self.certificate:
-            credentials = grpc.ssl_channel_credentials(
-                root_certificates=copy.deepcopy(certificate))
-            self.channel = grpc.secure_channel('{}:{}'.format(
-                self.address, str(self.port)), credentials)
+            credentials = grpc.ssl_channel_credentials(root_certificates=copy.deepcopy(certificate))
+            self.channel = grpc.secure_channel("{}:{}".format(self.address, str(self.port)), credentials)
         else:
-            self.channel = grpc.insecure_channel(
-                '{}:{}'.format(self.address, str(self.port)))
+            self.channel = grpc.insecure_channel("{}:{}".format(self.address, str(self.port)))
 
     def get_channel(self):
-        """ Get a channel.
+        """Get a channel.
 
         :return: An instance of a gRPC channel
         :rtype: :class:`grpc.Channel`
@@ -63,7 +60,7 @@ def get_channel(self):
 
 @trace_all_methods
 class CombinerInterface:
-    """ Interface for the Combiner (aggregation server).
+    """Interface for the Combiner (aggregation server).
     Abstraction on top of the gRPC server servicer.
 
     :param parent: The parent combiner (controller)
@@ -87,7 +84,7 @@ class CombinerInterface:
     """
 
     def __init__(self, parent, name, address, fqdn, port, certificate=None, key=None, ip=None, config=None):
-        """ Initialize the combiner interface."""
+        """Initialize the combiner interface."""
         self.parent = parent
         self.name = name
         self.address = address
@@ -98,15 +95,13 @@ def __init__(self, parent, name, address, fqdn, port, certificate=None, key=None
         self.ip = ip
 
         if not config:
-            self.config = {
-                'max_clients': 8
-            }
+            self.config = {"max_clients": 8}
         else:
             self.config = config
 
     @classmethod
-    def from_json(combiner_config):
-        """ Initialize the combiner config from a json document.
+    def from_json(cls, combiner_config):
+        """Initialize the combiner config from a json document.
 
         :param combiner_config: The combiner configuration.
         :type combiner_config: dict
@@ -116,34 +111,34 @@ def from_json(combiner_config):
         return CombinerInterface(**combiner_config)
 
     def to_dict(self):
-        """ Export combiner configuration to a dictionary.
+        """Export combiner configuration to a dictionary.
 
         :return: A dictionary with the combiner configuration.
         :rtype: dict
         """
 
         data = {
-            'parent': self.parent,
-            'name': self.name,
-            'address': self.address,
-            'fqdn': self.fqdn,
-            'port': self.port,
-            'ip': self.ip,
-            'certificate': None,
-            'key': None,
-            'config': self.config
+            "parent": self.parent,
+            "name": self.name,
+            "address": self.address,
+            "fqdn": self.fqdn,
+            "port": self.port,
+            "ip": self.ip,
+            "certificate": None,
+            "key": None,
+            "config": self.config,
         }
 
         if self.certificate:
             cert_b64 = base64.b64encode(self.certificate)
             key_b64 = base64.b64encode(self.key)
-            data['certificate'] = str(cert_b64).split('\'')[1]
-            data['key'] = str(key_b64).split('\'')[1]
+            data["certificate"] = str(cert_b64).split("'")[1]
+            data["key"] = str(key_b64).split("'")[1]
 
         return data
 
     def to_json(self):
-        """ Export combiner configuration to json.
+        """Export combiner configuration to json.
 
         :return: A json document with the combiner configuration.
         :rtype: str
@@ -151,34 +146,33 @@ def to_json(self):
         return json.dumps(self.to_dict())
 
     def get_certificate(self):
-        """ Get combiner certificate.
+        """Get combiner certificate.
 
         :return: The combiner certificate.
         :rtype: str, None if no certificate is set.
         """
         if self.certificate:
             cert_b64 = base64.b64encode(self.certificate)
-            return str(cert_b64).split('\'')[1]
+            return str(cert_b64).split("'")[1]
         else:
             return None
 
     def get_key(self):
-        """ Get combiner key.
+        """Get combiner key.
 
         :return: The combiner key.
         :rtype: str, None if no key is set.
         """
         if self.key:
             key_b64 = base64.b64encode(self.key)
-            return str(key_b64).split('\'')[1]
+            return str(key_b64).split("'")[1]
         else:
             return None
 
     def flush_model_update_queue(self):
-        """ Reset the model update queue on the combiner. """
+        """Reset the model update queue on the combiner."""
 
-        channel = Channel(self.address, self.port,
-                          self.certificate).get_channel()
+        channel = Channel(self.address, self.port, self.certificate).get_channel()
         control = rpc.ControlStub(channel)
 
         request = fedn.ControlRequest()
@@ -192,14 +186,13 @@ def flush_model_update_queue(self):
             raise
 
     def set_aggregator(self, aggregator):
-        """ Set the active aggregator module.
+        """Set the active aggregator module.
 
         :param aggregator: The name of the aggregator module.
         :type aggregator: str
         """
-        channel = Channel(self.address, self.port,
-                          self.certificate).get_channel()
+        channel = Channel(self.address, self.port, self.certificate).get_channel()
         control = rpc.ControlStub(channel)
 
         request = fedn.ControlRequest()
@@ -216,15 +209,14 @@ def set_aggregator(self, aggregator):
             raise
 
     def submit(self, config):
-        """ Submit a compute plan to the combiner.
+        """Submit a compute plan to the combiner.
 
         :param config: The job configuration.
         :type config: dict
         :return: Server ControlResponse object.
         :rtype: :class:`fedn.network.grpc.fedn_pb2.ControlResponse`
         """
-        channel = Channel(self.address, self.port,
-                          self.certificate).get_channel()
+        channel = Channel(self.address, self.port, self.certificate).get_channel()
        control = rpc.ControlStub(channel)
        request = fedn.ControlRequest()
        request.command = fedn.Command.START
@@ -244,7 +236,7 @@ def submit(self, config):
         return response
 
     def get_model(self, id, timeout=10):
-        """ Download a model from the combiner server.
+        """Download a model from the combiner server.
 
         :param id: The model id.
         :type id: str
@@ -252,8 +244,7 @@ def get_model(self, id, timeout=10):
         :rtype: :class:`io.BytesIO`, None if the model is not available.
         """
 
-        channel = Channel(self.address, self.port,
-                          self.certificate).get_channel()
+        channel = Channel(self.address, self.port, self.certificate).get_channel()
         modelservice = rpc.ModelServiceStub(channel)
 
         data = BytesIO()
@@ -279,13 +270,12 @@ def get_model(self, id, timeout=10):
                 continue
 
     def allowing_clients(self):
-        """ Check if the combiner is allowing additional client connections.
+        """Check if the combiner is allowing additional client connections.
 
         :return: True if accepting, else False.
         :rtype: bool
         """
-        channel = Channel(self.address, self.port,
-                          self.certificate).get_channel()
+        channel = Channel(self.address, self.port, self.certificate).get_channel()
         connector = rpc.ConnectorStub(channel)
         request = fedn.ConnectionRequest()
 
@@ -306,7 +296,7 @@ def allowing_clients(self):
         return False
 
     def list_active_clients(self, queue=1):
-        """ List active clients.
+        """List active clients.
 
         :param queue: The channel (queue) to use (optional). Default is 1 = MODEL_UPDATE_REQUESTS channel.
         see :class:`fedn.network.grpc.fedn_pb2.Channel`
@@ -314,8 +304,7 @@ def list_active_clients(self, queue=1):
         :return: A list of active clients.
         :rtype: json
         """
-        channel = Channel(self.address, self.port,
-                          self.certificate).get_channel()
+        channel = Channel(self.address, self.port, self.certificate).get_channel()
         control = rpc.ConnectorStub(channel)
         request = fedn.ListClientsRequest()
         request.channel = queue
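
An aside on the base64 handling in to_dict/get_certificate/get_key above: str(cert_b64).split("'")[1] extracts the text between the quotes of the bytes repr (b'...'), which for base64 output is equivalent to the conventional decode. A small self-check, with a placeholder byte string:

    import base64

    cert = b"-----BEGIN CERTIFICATE-----..."  # hypothetical PEM bytes
    cert_b64 = base64.b64encode(cert)
    # The base64 alphabet contains no single quotes, so the repr trick and decode agree.
    assert str(cert_b64).split("'")[1] == cert_b64.decode("utf-8")
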
diff --git a/fedn/fedn/network/combiner/modelservice.py b/fedn/network/combiner/modelservice.py
similarity index 84%
rename from fedn/fedn/network/combiner/modelservice.py
rename to fedn/network/combiner/modelservice.py
index 150a7f6a8..78b27b6d0 100644
--- a/fedn/fedn/network/combiner/modelservice.py
+++ b/fedn/network/combiner/modelservice.py
@@ -23,11 +23,9 @@ def upload_request_generator(mdl, id):
         while True:
             b = mdl.read(CHUNK_SIZE)
             if b:
-                result = fedn.ModelRequest(
-                    data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS)
+                result = fedn.ModelRequest(data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS)
             else:
-                result = fedn.ModelRequest(
-                    id=id, data=None, status=fedn.ModelStatus.OK)
+                result = fedn.ModelRequest(id=id, data=None, status=fedn.ModelStatus.OK)
             yield result
             if not b:
                 break
@@ -51,7 +49,7 @@ def model_as_bytesIO(model):
 
 @tracer.start_as_current_span(name="get_tmp_path")
 def get_tmp_path():
-    """ Return a temporary output path compatible with save_model, load_model. """
+    """Return a temporary output path compatible with save_model, load_model."""
     fd, path = tempfile.mkstemp()
     os.close(fd)
     return path
@@ -59,7 +57,7 @@ def get_tmp_path():
 
 @tracer.start_as_current_span(name="load_model_from_BytesIO")
 def load_model_from_BytesIO(model_bytesio, helper):
-    """ Load a model from a BytesIO object.
+    """Load a model from a BytesIO object.
 
     :param model_bytesio: A BytesIO object containing the model.
     :type model_bytesio: :class:`io.BytesIO`
     :param helper: The helper object for the model.
@@ -68,7 +66,7 @@ def load_model_from_BytesIO(model_bytesio, helper):
     :rtype: return type of helper.load
     """
     path = get_tmp_path()
-    with open(path, 'wb') as fh:
+    with open(path, "wb") as fh:
         fh.write(model_bytesio)
         fh.flush()
     model = helper.load(path)
@@ -78,7 +76,7 @@ def load_model_from_BytesIO(model_bytesio, helper):
 
 @tracer.start_as_current_span(name="serialize_model_to_BytesIO")
 def serialize_model_to_BytesIO(model, helper):
-    """ Serialize a model to a BytesIO object.
+    """Serialize a model to a BytesIO object.
 
     :param model: The model object.
     :type model: return type of helper.load
@@ -91,7 +89,7 @@ def serialize_model_to_BytesIO(model, helper):
 
     a = BytesIO()
     a.seek(0, 0)
-    with open(outfile_name, 'rb') as f:
+    with open(outfile_name, "rb") as f:
         a.write(f.read())
     a.seek(0)
     os.unlink(outfile_name)
@@ -100,15 +98,13 @@ def serialize_model_to_BytesIO(model, helper):
 
 @trace_all_methods
 class ModelService(rpc.ModelServiceServicer):
-    """ Service for handling download and upload of models to the server.
-
-    """
+    """Service for handling download and upload of models to the server."""
 
     def __init__(self):
         self.temp_model_storage = TempModelStorage()
 
     def exist(self, model_id):
-        """ Check if a model exists on the server.
+        """Check if a model exists on the server.
 
         :param model_id: The model id.
         :return: True if the model exists, else False.
@@ -116,7 +112,7 @@ def exist(self, model_id):
         return self.temp_model_storage.exist(model_id)
 
     def get_model(self, id):
-        """ Download model with id 'id' from server.
+        """Download model with id 'id' from server.
 
         :param id: The model id.
         :type id: str
@@ -138,7 +134,7 @@ def get_model(self, id):
         return None
 
     def set_model(self, model, id):
-        """ Upload model to server.
+        """Upload model to server.
 
         :param model: A model object (BytesIO)
         :type model: :class:`io.BytesIO`
@@ -151,7 +147,7 @@ def set_model(self, model, id):
 
     # Model Service
     def Upload(self, request_iterator, context):
-        """ RPC endpoints for uploading a model.
+        """RPC endpoint for uploading a model.
 
         :param request_iterator: The model request iterator.
         :type request_iterator: :class:`fedn.network.grpc.fedn_pb2.ModelRequest`
@@ -168,8 +164,7 @@ def Upload(self, request_iterator, context):
                 self.temp_model_storage.set_model_metadata(request.id, fedn.ModelStatus.IN_PROGRESS)
 
             if request.status == fedn.ModelStatus.OK and not request.data:
-                result = fedn.ModelResponse(id=request.id, status=fedn.ModelStatus.OK,
-                                            message="Got model successfully.")
+                result = fedn.ModelResponse(id=request.id, status=fedn.ModelStatus.OK, message="Got model successfully.")
                 # self.temp_model_storage_metadata.update({request.id: fedn.ModelStatus.OK})
                 self.temp_model_storage.set_model_metadata(request.id, fedn.ModelStatus.OK)
                 self.temp_model_storage.get_ptr(request.id).flush()
@@ -177,7 +172,7 @@ def Upload(self, request_iterator, context):
                 return result
 
     def Download(self, request, context):
-        """ RPC endpoints for downloading a model.
+        """RPC endpoint for downloading a model.
 
         :param request: The model request object.
         :type request: :class:`fedn.network.grpc.fedn_pb2.ModelRequest`
@@ -186,11 +181,11 @@ def Download(self, request, context):
         :return: A model response iterator.
         :rtype: :class:`fedn.network.grpc.fedn_pb2.ModelResponse`
         """
-        logger.info(f'grpc.ModelService.Download: {request.sender.role}:{request.sender.name} requested model {request.id}')
+        logger.info(f"grpc.ModelService.Download: {request.sender.role}:{request.sender.name} requested model {request.id}")
         try:
             status = self.temp_model_storage.get_model_metadata(request.id)
             if status != fedn.ModelStatus.OK:
-                logger.error(f'model file is not ready: {request.id}, status: {status}')
+                logger.error(f"model file is not ready: {request.id}, status: {status}")
                 yield fedn.ModelResponse(id=request.id, data=None, status=status)
         except Exception:
             logger.error("Error file does not exist: {}".format(request.id))
@@ -199,7 +194,7 @@ def Download(self, request, context):
         try:
             obj = self.temp_model_storage.get(request.id)
             if obj is None:
-                raise Exception(f'File not found: {request.id}')
+                raise Exception(f"File not found: {request.id}")
             with obj as f:
                 while True:
                     piece = f.read(CHUNK_SIZE)
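
The Upload/Download endpoints stream models in CHUNK_SIZE pieces, with an empty ModelStatus.OK message terminating an upload (see upload_request_generator above). A gRPC-free sketch of that chunking pattern:

    from io import BytesIO

    CHUNK_SIZE = 1024 * 1024  # assumed for illustration; the real value is defined in modelservice

    def chunk_stream(mdl: BytesIO):
        """Yield (data, done) pairs: raw chunks first, then an empty terminating marker."""
        while True:
            b = mdl.read(CHUNK_SIZE)
            if b:
                yield b, False
            else:
                yield b"", True  # mirrors the final ModelRequest with status=OK and no data
                break
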
diff --git a/fedn/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py
similarity index 78%
rename from fedn/fedn/network/combiner/roundhandler.py
rename to fedn/network/combiner/roundhandler.py
index ff224e897..72fc5bc10 100644
--- a/fedn/fedn/network/combiner/roundhandler.py
+++ b/fedn/network/combiner/roundhandler.py
@@ -1,3 +1,4 @@
+import ast
 import queue
 import random
 import sys
@@ -7,9 +8,9 @@
 
 from fedn.common.log_config import logger
 from fedn.common.telemetry import trace_all_methods
 from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator
-from fedn.network.combiner.modelservice import (load_model_from_BytesIO,
-                                                serialize_model_to_BytesIO)
+from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO
 from fedn.utils.helpers.helpers import get_helper
+from fedn.utils.parameters import Parameters
 
 
 class ModelUpdateError(Exception):
@@ -18,7 +19,7 @@ class ModelUpdateError(Exception):
 
 @trace_all_methods
 class RoundHandler:
-    """ Round handler.
+    """Round handler.
 
     The round handler processes requests from the global controller
     to produce model updates and perform model validations.
@@ -34,7 +35,7 @@ class RoundHandler:
     """
 
     def __init__(self, storage, server, modelservice):
-        """ Initialize the RoundHandler."""
+        """Initialize the RoundHandler."""
 
         self.round_configs = queue.Queue()
         self.storage = storage
@@ -53,12 +54,12 @@ def push_round_config(self, round_config):
         :rtype: str
         """
         try:
-            round_config['_job_id'] = str(uuid.uuid4())
+            round_config["_job_id"] = str(uuid.uuid4())
             self.round_configs.put(round_config)
         except Exception:
             logger.error("Failed to push round config.")
             raise
-        return round_config['_job_id']
+        return round_config["_job_id"]
 
     def load_model_update(self, helper, model_id):
         """Load model update with id model_id into its memory representation.
@@ -74,8 +75,7 @@ def load_model_update(self, helper, model_id):
             try:
                 model = load_model_from_BytesIO(model_str.getbuffer(), helper)
             except IOError:
-                logger.warning(
-                    "AGGREGATOR({}): Failed to load model!".format(self.name))
+                logger.warning("AGGREGATOR({}): Failed to load model!".format(self.name))
         else:
             raise ModelUpdateError("Failed to load model.")
 
@@ -108,7 +108,7 @@ def load_model_update_str(self, model_id, retry=3):
         return model_str
 
     def waitforit(self, config, buffer_size=100, polling_interval=0.1):
-        """ Defines the policy for how long the server should wait before starting to aggregate models.
+        """Defines the policy for how long the server should wait before starting to aggregate models.
 
         The policy is as follows:
             1. Wait a maximum of time_window time until the round times out.
@@ -122,7 +122,7 @@ def waitforit(self, config, buffer_size=100, polling_interval=0.1):
         :type polling_interval: float, optional
         """
 
-        time_window = float(config['round_timeout'])
+        time_window = float(config["round_timeout"])
 
         tt = 0.0
         while tt < time_window:
@@ -143,22 +143,21 @@ def _training_round(self, config, clients):
         :rtype: model, dict
         """
 
-        logger.info(
-            "ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients))
+        logger.info("ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients))
 
         meta = {}
-        meta['nr_expected_updates'] = len(clients)
-        meta['nr_required_updates'] = int(config['clients_required'])
-        meta['timeout'] = float(config['round_timeout'])
+        meta["nr_expected_updates"] = len(clients)
+        meta["nr_required_updates"] = int(config["clients_required"])
+        meta["timeout"] = float(config["round_timeout"])
 
         # Request model updates from all active clients.
         self.server.request_model_update(config, clients=clients)
 
         # If buffer_size is -1 (default), the round terminates when/if all clients have completed.
-        if int(config['buffer_size']) == -1:
+        if int(config["buffer_size"]) == -1:
             buffer_size = len(clients)
         else:
-            buffer_size = int(config['buffer_size'])
+            buffer_size = int(config["buffer_size"])
 
         # Wait / block until the round termination policy has been met.
        self.waitforit(config, buffer_size=buffer_size)
@@ -168,19 +167,25 @@ def _training_round(self, config, clients):
 
         data = None
         try:
-            helper = get_helper(config['helper_type'])
-            logger.info("Config delete_models_storage: {}".format(config['delete_models_storage']))
-            if config['delete_models_storage'] == 'True':
+            helper = get_helper(config["helper_type"])
+            logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"]))
+            if config["delete_models_storage"] == "True":
                 delete_models = True
             else:
                 delete_models = False
-            model, data = self.aggregator.combine_models(helper=helper,
-                                                         delete_models=delete_models)
+
+            if "aggregator_kwargs" in config.keys():
+                dict_parameters = ast.literal_eval(config["aggregator_kwargs"])
+                parameters = Parameters(dict_parameters)
+            else:
+                parameters = None
+
+            model, data = self.aggregator.combine_models(helper=helper, delete_models=delete_models, parameters=parameters)
         except Exception as e:
             logger.warning("AGGREGATION FAILED AT COMBINER! {}".format(e))
 
-        meta['time_combination'] = time.time() - tic
-        meta['aggregation_time'] = data
+        meta["time_combination"] = time.time() - tic
+        meta["aggregation_time"] = data
         return model, meta
 
     def _validation_round(self, config, clients, model_id):
@@ -229,7 +234,7 @@ def stage_model(self, model_id, timeout_retry=3, retry=2):
         self.modelservice.set_model(model, model_id)
 
     def _assign_round_clients(self, n, type="trainers"):
-        """ Obtain a list of clients(trainers or validators) to ask for updates in this round.
+        """Obtain a list of clients (trainers or validators) to ask for updates in this round.
 
         :param n: Size of a random set taken from active trainers(clients), if n > "active trainers" all is used
         :type n: int
@@ -268,30 +273,27 @@ def _check_nr_round_clients(self, config):
         """
 
         active = self.server.nr_active_trainers()
-        if active >= int(config['clients_required']):
-            logger.info("Number of clients required ({0}) to start round met {1}.".format(
-                config['clients_required'], active))
+        if active >= int(config["clients_required"]):
+            logger.info("Number of clients required ({0}) to start round met {1}.".format(config["clients_required"], active))
             return True
         else:
             logger.info("Too few clients to start round.")
             return False
 
     def execute_validation_round(self, round_config):
-        """ Coordinate validation rounds as specified in config.
+        """Coordinate validation rounds as specified in config.
 
         :param round_config: The round config object.
         :type round_config: dict
         """
-        model_id = round_config['model_id']
-        logger.info(
-            "COMBINER orchestrating validation of model {}".format(model_id))
+        model_id = round_config["model_id"]
+        logger.info("COMBINER orchestrating validation of model {}".format(model_id))
         self.stage_model(model_id)
-        validators = self._assign_round_clients(
-            self.server.max_clients, type="validators")
+        validators = self._assign_round_clients(self.server.max_clients, type="validators")
         self._validation_round(round_config, validators, model_id)
 
     def execute_training_round(self, config):
-        """ Coordinates clients to execute training tasks.
+        """Coordinates clients to execute training tasks.
 
         :param config: The round config object.
         :type config: dict
 
@@ -299,39 +301,37 @@ def execute_training_round(self, config):
         :rtype: dict
         """
 
-        logger.info("Processing training round, job_id {}".format(config['_job_id']))
+        logger.info("Processing training round, job_id {}".format(config["_job_id"]))
 
         data = {}
-        data['config'] = config
-        data['round_id'] = config['round_id']
+        data["config"] = config
+        data["round_id"] = config["round_id"]
 
         # Download model to update and set in temp storage.
-        self.stage_model(config['model_id'])
+        self.stage_model(config["model_id"])
 
         clients = self._assign_round_clients(self.server.max_clients)
         model, meta = self._training_round(config, clients)
-        data['data'] = meta
+        data["data"] = meta
 
         if model is None:
-            logger.warning(
-                "\t Failed to update global model in round {0}!".format(config['round_id']))
+            logger.warning("\t Failed to update global model in round {0}!".format(config["round_id"]))
 
         if model is not None:
-            helper = get_helper(config['helper_type'])
+            helper = get_helper(config["helper_type"])
             a = serialize_model_to_BytesIO(model, helper)
             model_id = self.storage.set_model(a.read(), is_file=False)
             a.close()
-            data['model_id'] = model_id
+            data["model_id"] = model_id
 
-            logger.info(
-                "TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config['_job_id']))
+            logger.info("TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config["_job_id"]))
 
         # Delete temp model
-        self.modelservice.temp_model_storage.delete(config['model_id'])
+        self.modelservice.temp_model_storage.delete(config["model_id"])
         return data
 
     def run(self, polling_interval=1.0):
-        """ Main control loop. Execute rounds based on round config on the queue.
+        """Main control loop. Execute rounds based on round config on the queue.
 
         :param polling_interval: The polling interval in seconds for checking if a new job/config is available.
         :type polling_interval: float
@@ -346,23 +346,22 @@ def run(self, polling_interval=1.0):
                 round_meta = {}
 
                 if ready:
-                    if round_config['task'] == 'training':
+                    if round_config["task"] == "training":
                         tic = time.time()
                         round_meta = self.execute_training_round(round_config)
-                        round_meta['time_exec_training'] = time.time() - \
-                            tic
-                        round_meta['status'] = "Success"
-                        round_meta['name'] = self.server.id
+                        round_meta["time_exec_training"] = time.time() - tic
+                        round_meta["status"] = "Success"
+                        round_meta["name"] = self.server.id
                         self.server.statestore.set_round_combiner_data(round_meta)
-                    elif round_config['task'] == 'validation' or round_config['task'] == 'inference':
+                    elif round_config["task"] == "validation" or round_config["task"] == "inference":
                         self.execute_validation_round(round_config)
                     else:
                         logger.warning("config contains unknown task type.")
                 else:
                     round_meta = {}
-                    round_meta['status'] = "Failed"
-                    round_meta['reason'] = "Failed to meet client allocation requirements for this round config."
-                    logger.warning("{0}".format(round_meta['reason']))
+                    round_meta["status"] = "Failed"
+                    round_meta["reason"] = "Failed to meet client allocation requirements for this round config."
+                    logger.warning("{0}".format(round_meta["reason"]))
 
                 self.round_configs.task_done()
         except queue.Empty:
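
Note on the new aggregator_kwargs path in _training_round above: the value arrives as a string and is parsed with ast.literal_eval, so it must be a Python dict literal. A hedged sketch of what a round config carrying FedOpt hyperparameters might look like (keys follow the fedopt.py schema; values are illustrative):

    import ast

    config = {"aggregator_kwargs": "{'serveropt': 'yogi', 'learning_rate': 1e-3, 'beta1': 0.9, 'beta2': 0.99, 'tau': 1e-4}"}
    dict_parameters = ast.literal_eval(config["aggregator_kwargs"])  # -> plain dict, ready for Parameters(...)
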
""" + """Configuration for the Reducer component.""" + compute_bundle_dir = None models_dir = None initial_model = None - storage_backend = { - 'type': 's3', 'settings': - {'bucket': 'models'} - } + storage_backend = {"type": "s3", "settings": {"bucket": "models"}} def __init__(self): pass diff --git a/fedn/fedn/network/controller/__init__.py b/fedn/network/controller/__init__.py similarity index 100% rename from fedn/fedn/network/controller/__init__.py rename to fedn/network/controller/__init__.py diff --git a/fedn/fedn/network/controller/control.py b/fedn/network/controller/control.py similarity index 86% rename from fedn/fedn/network/controller/control.py rename to fedn/network/controller/control.py index fc00df4c5..897d4175a 100644 --- a/fedn/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -3,8 +3,7 @@ import time import uuid -from tenacity import (retry, retry_if_exception_type, stop_after_delay, - wait_random) +from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_random from fedn.common.log_config import logger from fedn.common.telemetry import trace_all_methods @@ -55,10 +54,10 @@ def __init__(self, message): class CombinersNotDoneException(Exception): - """ Exception class for when model is None """ + """Exception class for when model is None""" def __init__(self, message): - """ Constructor method. + """Constructor method. :param message: The exception message. :type message: str @@ -110,9 +109,9 @@ def session(self, config): last_round = int(self.get_latest_round_id()) for combiner in self.network.get_combiners(): - combiner.set_aggregator(config['aggregator']) + combiner.set_aggregator(config["aggregator"]) - self.set_session_status(config['session_id'], 'Started') + self.set_session_status(config["session_id"], "Started") # Execute the rounds in this session for round in range(1, int(config["rounds"] + 1)): # Increment the round number @@ -131,11 +130,11 @@ def session(self, config): config["model_id"] = self.statestore.get_latest_model() # TODO: Report completion of session - self.set_session_status(config['session_id'], 'Finished') + self.set_session_status(config["session_id"], "Finished") self._state = ReducerState.idle def round(self, session_config, round_id): - """ Execute one global round. + """Execute one global round. : param session_config: The session config. 
@@ -144,11 +143,11 @@ def round(self, session_config, round_id):

         """

-        self.create_round({'round_id': round_id, 'status': "Pending"})
+        self.create_round({"round_id": round_id, "status": "Pending"})

         if len(self.network.get_combiners()) < 1:
             logger.warning("Round cannot start, no combiners connected!")
-            self.set_round_status(round_id, 'Failed')
+            self.set_round_status(round_id, "Failed")
             return None, self.statestore.get_round(round_id)

         # Assemble round config for this global round
@@ -167,11 +166,10 @@ def round(self, session_config, round_id):
         round_start = self.evaluate_round_start_policy(participating_combiners)

         if round_start:
-            logger.info("round start policy met, {} participating combiners.".format(
-                len(participating_combiners)))
+            logger.info("round start policy met, {} participating combiners.".format(len(participating_combiners)))
         else:
             logger.warning("Round start policy not met, skipping round!")
-            self.set_round_status(round_id, 'Failed')
+            self.set_round_status(round_id, "Failed")
             return None, self.statestore.get_round(round_id)

         # Ask participating combiners to coordinate model updates
@@ -183,18 +181,19 @@ def round(self, session_config, round_id):

         def do_if_round_times_out(result):
             logger.warning("Round timed out!")

-        @ retry(wait=wait_random(min=1.0, max=2.0),
-                stop=stop_after_delay(session_config['round_timeout']),
-                retry_error_callback=do_if_round_times_out,
-                retry=retry_if_exception_type(CombinersNotDoneException))
+        @retry(
+            wait=wait_random(min=1.0, max=2.0),
+            stop=stop_after_delay(session_config["round_timeout"]),
+            retry_error_callback=do_if_round_times_out,
+            retry=retry_if_exception_type(CombinersNotDoneException),
+        )
         def combiners_done():
-
             round = self.statestore.get_round(round_id)
-            if 'combiners' not in round:
+            if "combiners" not in round:
                 logger.info("Waiting for combiners to update model...")
                 raise CombinersNotDoneException("Combiners have not yet reported.")

-            if len(round['combiners']) < len(participating_combiners):
+            if len(round["combiners"]) < len(participating_combiners):
                 logger.info("Waiting for combiners to update model...")
                 raise CombinersNotDoneException("All combiners have not yet reported.")

@@ -205,11 +204,10 @@ def combiners_done():
         # Due to the distributed nature of the computation, there might be a
         # delay before combiners have reported the round data to the db,
         # so we need some robustness here.
-        @ retry(wait=wait_random(min=0.1, max=1.0),
-                retry=retry_if_exception_type(KeyError))
+        @retry(wait=wait_random(min=0.1, max=1.0), retry=retry_if_exception_type(KeyError))
        def check_combiners_done_reporting():
             round = self.statestore.get_round(round_id)
-            combiners = round['combiners']
+            combiners = round["combiners"]
             return combiners

         _ = check_combiners_done_reporting()
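The hunks above only restyle the tenacity decorators, but the pattern is worth spelling out: poll by raising a sentinel exception, jitter the wait, cap the whole wait with stop_after_delay, and substitute a timeout handler for the RetryError via retry_error_callback. A standalone sketch of the same pattern (the 3-second "deadline" simulates combiners finishing):

    import time
    from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_random

    class NotDoneException(Exception):
        """Raised while the condition we poll for has not been met yet."""

    deadline = time.time() + 3  # stand-in for combiners reporting after ~3s

    def on_timeout(retry_state):
        # Called instead of raising RetryError when stop_after_delay triggers;
        # its return value becomes the function's return value.
        print("timed out waiting")
        return False

    @retry(
        wait=wait_random(min=0.1, max=0.5),          # jittered polling interval
        stop=stop_after_delay(10),                   # overall budget, like round_timeout
        retry=retry_if_exception_type(NotDoneException),
        retry_error_callback=on_timeout,
    )
    def wait_until_done():
        if time.time() < deadline:
            raise NotDoneException("not yet")
        return True

    print(wait_until_done())  # True once the deadline has passed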
@@ -218,7 +216,7 @@ def check_combiners_done_reporting():
         round_valid = self.evaluate_round_validity_policy(round)
         if not round_valid:
             logger.error("Round failed. Invalid - evaluate_round_validity_policy: False")
-            self.set_round_status(round_id, 'Failed')
+            self.set_round_status(round_id, "Failed")
             return None, self.statestore.get_round(round_id)

         logger.info("Reducing combiner level models...")
@@ -226,12 +224,12 @@ def check_combiners_done_reporting():
         round_data = {}
         try:
             round = self.statestore.get_round(round_id)
-            model, data = self.reduce(round['combiners'])
-            round_data['reduce'] = data
+            model, data = self.reduce(round["combiners"])
+            round_data["reduce"] = data
             logger.info("Done reducing models from combiners!")
         except Exception as e:
             logger.error("Failed to reduce models from combiners, reason: {}".format(e))
-            self.set_round_status(round_id, 'Failed')
+            self.set_round_status(round_id, "Failed")
             return None, self.statestore.get_round(round_id)

         # Commit the new global model to the model trail
@@ -239,20 +237,16 @@ def check_combiners_done_reporting():
             logger.info("Committing global model to model trail...")
             tic = time.time()
             model_id = uuid.uuid4()
-            session_id = (
-                session_config["session_id"]
-                if "session_id" in session_config
-                else None
-            )
+            session_id = session_config["session_id"] if "session_id" in session_config else None
             self.commit(model_id, model, session_id)
             round_data["time_commit"] = time.time() - tic
             logger.info("Done committing global model to model trail.")
         else:
             logger.error("Failed to commit model to global model trail.")
-            self.set_round_status(round_id, 'Failed')
+            self.set_round_status(round_id, "Failed")
             return None, self.statestore.get_round(round_id)

-        self.set_round_status(round_id, 'Success')
+        self.set_round_status(round_id, "Success")

         # 4. Trigger participating combiner nodes to execute a validation round for the current model
         validate = session_config["validate"]
@@ -263,8 +257,7 @@ def check_combiners_done_reporting():
             combiner_config["task"] = "validation"
             combiner_config["helper_type"] = self.statestore.get_helper()

-            validating_combiners = self.get_participating_combiners(
-                combiner_config)
+            validating_combiners = self.get_participating_combiners(combiner_config)

             for combiner, combiner_config in validating_combiners:
                 try:
@@ -275,7 +268,7 @@ def check_combiners_done_reporting():
                     pass

         self.set_round_data(round_id, round_data)
-        self.set_round_status(round_id, 'Finished')
+        self.set_round_status(round_id, "Finished")
         return model_id, self.statestore.get_round(round_id)

     def reduce(self, combiners):
@@ -294,15 +287,15 @@ def reduce(self, combiners):
         model = None

         for combiner in combiners:
-            name = combiner['name']
-            model_id = combiner['model_id']
+            name = combiner["name"]
+            model_id = combiner["model_id"]

             logger.info("Fetching model ({}) from model repository".format(model_id))

             try:
                 tic = time.time()
                 data = self.model_repository.get_model(model_id)
-                meta['time_fetch_model'] += (time.time() - tic)
+                meta["time_fetch_model"] += time.time() - tic
             except Exception as e:
                 logger.error("Failed to fetch model from model repository {}: {}".format(name, e))
                 data = None
@@ -375,8 +368,7 @@ def inference_round(self, config):
         combiner_config["helper_type"] = self.statestore.get_framework()

         # Select combiners
-        validating_combiners = self.get_participating_combiners(
-            combiner_config)
+        validating_combiners = self.get_participating_combiners(combiner_config)

         # Test round start policy
         round_start = self.check_round_start_policy(validating_combiners)
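The reduce() hunk above fetches each combiner-level model in turn and folds it into a single global model. A minimal sketch of such an incremental (running-mean) fold over lists of numpy arrays; this is an illustration of the idea, not fedn's helper API:

    import numpy as np

    def reduce_models(models):
        """Fold models (each a list of ndarrays) into their element-wise mean,
        one model at a time, so only two models are ever held in memory."""
        avg = None
        for i, model in enumerate(models, start=1):
            if avg is None:
                avg = [layer.copy() for layer in model]
            else:
                # running mean: avg += (model - avg) / i
                avg = [a + (m - a) / i for a, m in zip(avg, model)]
        return avg

    m1 = [np.ones(3), np.zeros(2)]
    m2 = [3 * np.ones(3), 2 * np.ones(2)]
    print(reduce_models([m1, m2]))  # [array([2., 2., 2.]), array([1., 1., 1.])]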
diff --git a/fedn/fedn/network/controller/controlbase.py b/fedn/network/controller/controlbase.py
similarity index 90%
rename from fedn/fedn/network/controller/controlbase.py
rename to fedn/network/controller/controlbase.py
index 1da0a5f21..2897fb00c 100644
--- a/fedn/fedn/network/controller/controlbase.py
+++ b/fedn/network/controller/controlbase.py
@@ -63,9 +63,7 @@ def __init__(self, statestore):
                 raise MisconfiguredStorageBackend()

             if storage_config["storage_type"] == "S3":
-                self.model_repository = Repository(
-                    storage_config["storage_config"]
-                )
+                self.model_repository = Repository(storage_config["storage_config"])
             else:
                 logger.error("Unsupported storage backend, exiting.")
                 raise UnsupportedStorageBackend()
@@ -94,11 +92,7 @@ def get_helper(self):
         helper_type = self.statestore.get_helper()
         helper = fedn.utils.helpers.helpers.get_helper(helper_type)
         if not helper:
-            raise MisconfiguredHelper(
-                "Unsupported helper type {}, please configure compute_package.helper !".format(
-                    helper_type
-                )
-            )
+            raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type))
         return helper

     def get_state(self):
@@ -179,8 +173,8 @@ def get_compute_package(self, compute_package=""):
         else:
             return None

-    def create_session(self, config, status='Initialized'):
-        """ Initialize a new session in backend db. """
+    def create_session(self, config, status="Initialized"):
+        """Initialize a new session in backend db."""

         if "session_id" not in config.keys():
             session_id = uuid.uuid4()
@@ -193,7 +187,7 @@ def create_session(self, config, status='Initialized'):
         self.statestore.set_session_status(session_id, status)

     def set_session_status(self, session_id, status):
-        """ Set the round round stats.
+        """Set the session status.

         :param round_id: The round unique identifier
         :type round_id: str
@@ -203,12 +197,12 @@ def set_session_status(self, session_id, status):
         self.statestore.set_session_status(session_id, status)

     def create_round(self, round_data):
-        """Initialize a new round in backend db. """
+        """Initialize a new round in backend db."""
         self.statestore.create_round(round_data)

     def set_round_data(self, round_id, round_data):
-        """ Set round data.
+        """Set round data.

         :param round_id: The round unique identifier
         :type round_id: str
@@ -218,7 +212,7 @@ def set_round_data(self, round_id, round_data):
         self.statestore.set_round_data(round_id, round_data)

     def set_round_status(self, round_id, status):
-        """ Set the round round stats.
+        """Set the round status.

         :param round_id: The round unique identifier
         :type round_id: str
@@ -228,7 +222,7 @@ def set_round_status(self, round_id, status):
         self.statestore.set_round_status(round_id, status)

     def set_round_config(self, round_id, round_config):
-        """ Upate round in backend db.
+        """Update round in backend db.
         :param round_id: The round unique identifier
         :type round_id: str
@@ -265,9 +259,7 @@ def commit(self, model_id, model=None, session_id=None):
             logger.info("Saving model file temporarily to disk...")
             outfile_name = helper.save(model)
             logger.info("CONTROL: Uploading model to Minio...")
-            model_id = self.model_repository.set_model(
-                outfile_name, is_file=True
-            )
+            model_id = self.model_repository.set_model(outfile_name, is_file=True)

             logger.info("CONTROL: Deleting temporary model file...")
             os.unlink(outfile_name)
@@ -294,16 +286,12 @@ def get_participating_combiners(self, combiner_round_config):
                 self._handle_unavailable_combiner(combiner)
                 continue

-            is_participating = self.evaluate_round_participation_policy(
-                combiner_round_config, nr_active_clients
-            )
+            is_participating = self.evaluate_round_participation_policy(combiner_round_config, nr_active_clients)
             if is_participating:
                 combiners.append((combiner, combiner_round_config))
         return combiners

-    def evaluate_round_participation_policy(
-        self, compute_plan, nr_active_clients
-    ):
+    def evaluate_round_participation_policy(self, compute_plan, nr_active_clients):
         """Evaluate policy for combiner round-participation.
         A combiner participates if it is responsive and reports enough active clients to participate in the round.
@@ -327,7 +315,7 @@ def evaluate_round_start_policy(self, combiners):
             return False

     def evaluate_round_validity_policy(self, round):
-        """ Check if the round is valid.
+        """Check if the round is valid.

         At the end of the round, before committing a model to the global model trail,
         we check if the round validity policy has been met. This can involve
@@ -340,9 +328,9 @@ def evaluate_round_validity_policy(self, round):
         :rtype: bool
         """
         model_ids = []
-        for combiner in round['combiners']:
+        for combiner in round["combiners"]:
             try:
-                model_ids.append(combiner['model_id'])
+                model_ids.append(combiner["model_id"])
             except KeyError:
                 pass

@@ -352,7 +340,7 @@ def evaluate_round_validity_policy(self, round):
         return True

     def state(self):
-        """ Get the current state of the controller.
+        """Get the current state of the controller.

         :return: The state
         :rype: str
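The evaluate_round_validity_policy hunk above collects the model_id reported by each combiner and treats the round as valid only if at least one update exists. A compact sketch of that policy, assuming the final emptiness check that the hunk elides:

    def evaluate_round_validity_policy(round):
        """A round is valid if at least one combiner produced a model update."""
        model_ids = []
        for combiner in round.get("combiners", []):
            try:
                model_ids.append(combiner["model_id"])
            except KeyError:
                pass
        return len(model_ids) != 0

    print(evaluate_round_validity_policy({"combiners": [{"model_id": "abc"}]}))  # True
    print(evaluate_round_validity_policy({"combiners": [{}]}))                   # False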
diff --git a/fedn/fedn/network/grpc/__init__.py b/fedn/network/grpc/__init__.py
similarity index 100%
rename from fedn/fedn/network/grpc/__init__.py
rename to fedn/network/grpc/__init__.py
diff --git a/fedn/fedn/network/grpc/auth.py b/fedn/network/grpc/auth.py
similarity index 64%
rename from fedn/fedn/network/grpc/auth.py
rename to fedn/network/grpc/auth.py
index d879cd812..e57926cd7 100644
--- a/fedn/fedn/network/grpc/auth.py
+++ b/fedn/network/grpc/auth.py
@@ -6,35 +6,33 @@
 from fedn.network.api.auth import check_custom_claims

 ENDPOINT_ROLES_MAPPING = {
-    '/fedn.Combiner/TaskStream': ['client'],
-    '/fedn.Combiner/SendModelUpdate': ['client'],
-    '/fedn.Combiner/SendModelValidation': ['client'],
-    '/fedn.Connector/SendHeartbeat': ['client'],
-    '/fedn.Connector/SendStatus': ['client'],
-    '/fedn.ModelService/Download': ['client'],
-    '/fedn.ModelService/Upload': ['client'],
-    '/fedn.Control/Start': ['controller'],
-    '/fedn.Control/Stop': ['controller'],
-    '/fedn.Control/FlushAggregationQueue': ['controller'],
-    '/fedn.Control/SetAggregator': ['controller'],
+    "/fedn.Combiner/TaskStream": ["client"],
+    "/fedn.Combiner/SendModelUpdate": ["client"],
+    "/fedn.Combiner/SendModelValidation": ["client"],
+    "/fedn.Connector/SendHeartbeat": ["client"],
+    "/fedn.Connector/SendStatus": ["client"],
+    "/fedn.ModelService/Download": ["client"],
+    "/fedn.ModelService/Upload": ["client"],
+    "/fedn.Control/Start": ["controller"],
+    "/fedn.Control/Stop": ["controller"],
+    "/fedn.Control/FlushAggregationQueue": ["controller"],
+    "/fedn.Control/SetAggregator": ["controller"],
 }

 ENDPOINT_WHITELIST = [
-    '/fedn.Connector/AcceptingClients',
-    '/fedn.Connector/ListActiveClients',
-    '/fedn.Control/Start',
-    '/fedn.Control/Stop',
-    '/fedn.Control/FlushAggregationQueue',
-    '/fedn.Control/SetAggregator',
+    "/fedn.Connector/AcceptingClients",
+    "/fedn.Connector/ListActiveClients",
+    "/fedn.Control/Start",
+    "/fedn.Control/Stop",
+    "/fedn.Control/FlushAggregationQueue",
+    "/fedn.Control/SetAggregator",
 ]

-USER_AGENT_WHITELIST = [
-    'grpc_health_probe'
-]
+USER_AGENT_WHITELIST = ["grpc_health_probe"]


 def check_role_claims(payload, endpoint):
-    user_role = payload.get('role', '')
+    user_role = payload.get("role", "")

     # Perform endpoint-specific RBAC check
     allowed_roles = ENDPOINT_ROLES_MAPPING.get(endpoint)
@@ -63,33 +61,33 @@ def intercept_service(self, continuation, handler_call_details):
         if handler_call_details.method in ENDPOINT_WHITELIST:
             return continuation(handler_call_details)
         # Pass if the request comes from whitelisted user agents
-        user_agent = metadata.get('user-agent').split(' ')[0]
+        user_agent = metadata.get("user-agent").split(" ")[0]
         if user_agent in USER_AGENT_WHITELIST:
             return continuation(handler_call_details)

-        token = metadata.get('authorization')
+        token = metadata.get("authorization")
         if token is None:
-            return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token is missing')
+            return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, "Token is missing")

         if not token.startswith(FEDN_AUTH_SCHEME):
-            return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}')
+            return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f"Invalid token scheme, expected {FEDN_AUTH_SCHEME}")

-        token = token.split(' ')[1]
+        token = token.split(" ")[1]
         try:
             payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM])
             if not check_role_claims(payload, handler_call_details.method):
-                return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions')
+                return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, "Insufficient permissions")

             if not check_custom_claims(payload):
-                return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions')
+                return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, "Insufficient permissions")

             return continuation(handler_call_details)
         except jwt.InvalidTokenError:
-            return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Invalid token')
+            return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, "Invalid token")
         except jwt.ExpiredSignatureError:
-            return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token expired')
+            return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, "Token expired")
         except Exception as e:
             logger.error(str(e))
             return _unary_unary_rpc_terminator(grpc.StatusCode.UNKNOWN, str(e))
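The auth.py section above rejects a call before it reaches the servicer by returning a terminating handler from the interceptor. A standalone sketch of that pattern with grpcio and PyJWT; SECRET_KEY, the scheme, and the single-role check are stand-ins for fedn's configured values:

    import grpc
    import jwt  # PyJWT

    SECRET_KEY = "example-secret"  # stand-in; fedn reads this from its auth settings

    def _terminator(code, details):
        def terminate(request, context):
            context.abort(code, details)
        return grpc.unary_unary_rpc_method_handler(terminate)

    class JWTInterceptor(grpc.ServerInterceptor):
        def intercept_service(self, continuation, handler_call_details):
            metadata = dict(handler_call_details.invocation_metadata)
            token = metadata.get("authorization", "")
            if not token.startswith("Bearer "):
                return _terminator(grpc.StatusCode.UNAUTHENTICATED, "Token is missing")
            try:
                payload = jwt.decode(token.split(" ")[1], SECRET_KEY, algorithms=["HS256"])
            except jwt.ExpiredSignatureError:
                # ExpiredSignatureError subclasses InvalidTokenError, so catch it first.
                return _terminator(grpc.StatusCode.UNAUTHENTICATED, "Token expired")
            except jwt.InvalidTokenError:
                return _terminator(grpc.StatusCode.UNAUTHENTICATED, "Invalid token")
            if payload.get("role") != "client":  # simplified RBAC check
                return _terminator(grpc.StatusCode.PERMISSION_DENIED, "Insufficient permissions")
            return continuation(handler_call_details)

The interceptor would be passed to grpc.server(..., interceptors=[JWTInterceptor()]).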
+ """Load balancer that selects the combiner with the least number of attached training clients. :param network: A handle to the network. :type network: class: `fedn.network.api.network.Network` @@ -16,7 +16,7 @@ def __init__(self, network): def find_combiner(self): """ - Find the combiner with the least number of attached clients. + Find the combiner with the least number of attached clients. """ min_clients = None diff --git a/fedn/fedn/network/loadbalancer/loadbalancerbase.py b/fedn/network/loadbalancer/loadbalancerbase.py similarity index 80% rename from fedn/fedn/network/loadbalancer/loadbalancerbase.py rename to fedn/network/loadbalancer/loadbalancerbase.py index 6dc168dd4..aad2eb638 100644 --- a/fedn/fedn/network/loadbalancer/loadbalancerbase.py +++ b/fedn/network/loadbalancer/loadbalancerbase.py @@ -5,7 +5,7 @@ @trace_all_methods class LoadBalancerBase(ABC): - """ Abstract base class for load balancers. + """Abstract base class for load balancers. :param network: A handle to the network. :type network: class: `fedn.network.api.network.Network` @@ -17,5 +17,5 @@ def __init__(self, network): @abstractmethod def find_combiner(self): - """ Find a combiner to connect to. """ + """Find a combiner to connect to.""" pass diff --git a/fedn/fedn/network/state.py b/fedn/network/state.py similarity index 89% rename from fedn/fedn/network/state.py rename to fedn/network/state.py index 5a8932316..27f6ec501 100644 --- a/fedn/fedn/network/state.py +++ b/fedn/network/state.py @@ -4,7 +4,8 @@ class ReducerState(Enum): - """ Enum for representing the state of a reducer.""" + """Enum for representing the state of a reducer.""" + setup = 1 idle = 2 instructing = 3 @@ -13,7 +14,7 @@ class ReducerState(Enum): @tracer.start_as_current_span(name="ReducerStateToString") def ReducerStateToString(state): - """ Convert ReducerState to string. + """Convert ReducerState to string. :param state: The state. :type state: :class:`fedn.network.state.ReducerState` @@ -34,7 +35,7 @@ def ReducerStateToString(state): @tracer.start_as_current_span(name="StringToReducerState") def StringToReducerState(state): - """ Convert string to ReducerState. + """Convert string to ReducerState. :param state: The state as string. :type state: str diff --git a/fedn/fedn/network/storage/__init__.py b/fedn/network/storage/__init__.py similarity index 100% rename from fedn/fedn/network/storage/__init__.py rename to fedn/network/storage/__init__.py diff --git a/fedn/fedn/network/storage/models/__init__.py b/fedn/network/storage/models/__init__.py similarity index 100% rename from fedn/fedn/network/storage/models/__init__.py rename to fedn/network/storage/models/__init__.py diff --git a/fedn/fedn/network/storage/models/memorymodelstorage.py b/fedn/network/storage/models/memorymodelstorage.py similarity index 95% rename from fedn/fedn/network/storage/models/memorymodelstorage.py rename to fedn/network/storage/models/memorymodelstorage.py index 196e27fd4..d7fb4e77e 100644 --- a/fedn/fedn/network/storage/models/memorymodelstorage.py +++ b/fedn/network/storage/models/memorymodelstorage.py @@ -10,25 +10,22 @@ @trace_all_methods class MemoryModelStorage(ModelStorage): - """ Class for in-memory storage of model artifacts. + """Class for in-memory storage of model artifacts. Models are stored as BytesIO objects in a dictionary. 
""" def __init__(self): - self.models = defaultdict(io.BytesIO) self.models_metadata = {} def exist(self, model_id): - if model_id in self.models.keys(): return True return False def get(self, model_id): - obj = self.models[model_id] obj.seek(0, 0) # Have to copy object to not mix up the file pointers when sending... fix in better way. @@ -44,9 +41,7 @@ def get_ptr(self, model_id): return self.models[model_id] def get_model_metadata(self, model_id): - return self.models_metadata[model_id] def set_model_metadata(self, model_id, model_metadata): - self.models_metadata.update({model_id: model_metadata}) diff --git a/fedn/fedn/network/storage/models/modelstorage.py b/fedn/network/storage/models/modelstorage.py similarity index 85% rename from fedn/fedn/network/storage/models/modelstorage.py rename to fedn/network/storage/models/modelstorage.py index 801c8c489..f77acdebf 100644 --- a/fedn/fedn/network/storage/models/modelstorage.py +++ b/fedn/network/storage/models/modelstorage.py @@ -5,10 +5,9 @@ @trace_all_methods class ModelStorage(ABC): - @abstractmethod def exist(self, model_id): - """ Check if model exists in storage + """Check if model exists in storage :param model_id: The model id :type model_id: str @@ -19,7 +18,7 @@ def exist(self, model_id): @abstractmethod def get(self, model_id): - """ Get model from storage + """Get model from storage :param model_id: The model id :type model_id: str @@ -30,7 +29,7 @@ def get(self, model_id): @abstractmethod def get_model_metadata(self, model_id): - """ Get model metadata from storage + """Get model metadata from storage :param model_id: The model id :type model_id: str @@ -41,7 +40,7 @@ def get_model_metadata(self, model_id): @abstractmethod def set_model_metadata(self, model_id, model_metadata): - """ Set model metadata in storage + """Set model metadata in storage :param model_id: The model id :type model_id: str @@ -54,7 +53,7 @@ def set_model_metadata(self, model_id, model_metadata): @abstractmethod def delete(self, model_id): - """ Delete model from storage + """Delete model from storage :param model_id: The model id :type model_id: str @@ -65,7 +64,7 @@ def delete(self, model_id): @abstractmethod def delete_all(self): - """ Delete all models from storage + """Delete all models from storage :return: True if successful, False otherwise :rtype: bool diff --git a/fedn/fedn/network/storage/models/tempmodelstorage.py b/fedn/network/storage/models/tempmodelstorage.py similarity index 89% rename from fedn/fedn/network/storage/models/tempmodelstorage.py rename to fedn/network/storage/models/tempmodelstorage.py index f1a814ac9..5ae7a1c41 100644 --- a/fedn/fedn/network/storage/models/tempmodelstorage.py +++ b/fedn/network/storage/models/tempmodelstorage.py @@ -11,12 +11,10 @@ @trace_all_methods class TempModelStorage(ModelStorage): - """ Class for managing local temporary models on file on combiners.""" + """Class for managing local temporary models on file on combiners.""" def __init__(self): - - self.default_dir = os.environ.get( - 'FEDN_MODEL_DIR', '/tmp/models') # set default to tmp + self.default_dir = os.environ.get("FEDN_MODEL_DIR", "/tmp/models") # set default to tmp if not os.path.exists(self.default_dir): os.makedirs(self.default_dir) @@ -24,13 +22,11 @@ def __init__(self): self.models_metadata = {} def exist(self, model_id): - if model_id in self.models.keys(): return True return False def get(self, model_id): - try: if self.models_metadata[model_id] != fedn.ModelStatus.OK: logger.warning("File not ready! 
Try again") @@ -40,7 +36,7 @@ def get(self, model_id): return None obj = BytesIO() - with open(os.path.join(self.default_dir, str(model_id)), 'rb') as f: + with open(os.path.join(self.default_dir, str(model_id)), "rb") as f: obj.write(f.read()) obj.seek(0, 0) @@ -53,15 +49,14 @@ def get_ptr(self, model_id): :return: """ try: - f = self.models[model_id]['file'] + f = self.models[model_id]["file"] except KeyError: f = open(os.path.join(self.default_dir, str(model_id)), "wb") - self.models[model_id] = {'file': f} - return self.models[model_id]['file'] + self.models[model_id] = {"file": f} + return self.models[model_id]["file"] def get_model_metadata(self, model_id): - try: status = self.models_metadata[model_id] except KeyError: @@ -69,12 +64,10 @@ def get_model_metadata(self, model_id): return status def set_model_metadata(self, model_id, model_metadata): - self.models_metadata.update({model_id: model_metadata}) # Delete model from disk def delete(self, model_id): - try: os.remove(os.path.join(self.default_dir, str(model_id))) logger.info("TEMPMODELSTORAGE: Deleted model with id: {}".format(model_id)) @@ -88,7 +81,6 @@ def delete(self, model_id): # Delete all models from disk def delete_all(self): - ids_pop = [] for model_id in self.models.keys(): try: diff --git a/fedn/fedn/network/storage/models/tests/test_tempmodelstorage.py b/fedn/network/storage/models/tests/test_tempmodelstorage.py similarity index 100% rename from fedn/fedn/network/storage/models/tests/test_tempmodelstorage.py rename to fedn/network/storage/models/tests/test_tempmodelstorage.py diff --git a/fedn/fedn/network/storage/s3/__init__.py b/fedn/network/storage/s3/__init__.py similarity index 100% rename from fedn/fedn/network/storage/s3/__init__.py rename to fedn/network/storage/s3/__init__.py diff --git a/fedn/fedn/network/storage/s3/base.py b/fedn/network/storage/s3/base.py similarity index 86% rename from fedn/fedn/network/storage/s3/base.py rename to fedn/network/storage/s3/base.py index 8f22a2485..671b75d7d 100644 --- a/fedn/fedn/network/storage/s3/base.py +++ b/fedn/network/storage/s3/base.py @@ -6,7 +6,7 @@ class RepositoryBase(object): @abc.abstractmethod def set_artifact(self, instance_name, instance, bucket): - """ Set object with name object_name + """Set object with name object_name :param instance_name: The name of the object :tyep insance_name: str @@ -18,7 +18,7 @@ def set_artifact(self, instance_name, instance, bucket): @abc.abstractmethod def get_artifact(self, instance_name, bucket): - """ Retrive object with name instance_name. + """Retrive object with name instance_name. :param instance_name: The name of the object to retrieve :type instance_name: str @@ -29,7 +29,7 @@ def get_artifact(self, instance_name, bucket): @abc.abstractmethod def get_artifact_stream(self, instance_name, bucket): - """ Return a stream handler for object with name instance_name. + """Return a stream handler for object with name instance_name. :param instance_name: The name if the object :type instance_name: str diff --git a/fedn/fedn/network/storage/s3/miniorepository.py b/fedn/network/storage/s3/miniorepository.py similarity index 65% rename from fedn/fedn/network/storage/s3/miniorepository.py rename to fedn/network/storage/s3/miniorepository.py index 3eaca244c..124a81f88 100644 --- a/fedn/fedn/network/storage/s3/miniorepository.py +++ b/fedn/network/storage/s3/miniorepository.py @@ -11,12 +11,12 @@ @trace_all_methods class MINIORepository(RepositoryBase): - """ Class implementing Repository for MinIO. 
""" + """Class implementing Repository for MinIO.""" client = None def __init__(self, config): - """ Initialize object. + """Initialize object. :param config: Dictionary containing configuration for credentials and bucket names. :type config: dict @@ -25,34 +25,35 @@ def __init__(self, config): super().__init__() self.name = "MINIORepository" - if config['storage_secure_mode']: - manager = PoolManager( - num_pools=100, cert_reqs='CERT_NONE', assert_hostname=False) - self.client = Minio("{0}:{1}".format(config['storage_hostname'], config['storage_port']), - access_key=config['storage_access_key'], - secret_key=config['storage_secret_key'], - secure=config['storage_secure_mode'], http_client=manager) + if config["storage_secure_mode"]: + manager = PoolManager(num_pools=100, cert_reqs="CERT_NONE", assert_hostname=False) + self.client = Minio( + "{0}:{1}".format(config["storage_hostname"], config["storage_port"]), + access_key=config["storage_access_key"], + secret_key=config["storage_secret_key"], + secure=config["storage_secure_mode"], + http_client=manager, + ) else: - self.client = Minio("{0}:{1}".format(config['storage_hostname'], config['storage_port']), - access_key=config['storage_access_key'], - secret_key=config['storage_secret_key'], - secure=config['storage_secure_mode']) + self.client = Minio( + "{0}:{1}".format(config["storage_hostname"], config["storage_port"]), + access_key=config["storage_access_key"], + secret_key=config["storage_secret_key"], + secure=config["storage_secure_mode"], + ) def set_artifact(self, instance_name, instance, bucket, is_file=False): - if is_file: self.client.fput_object(bucket, instance_name, instance) else: try: - self.client.put_object( - bucket, instance_name, io.BytesIO(instance), len(instance)) + self.client.put_object(bucket, instance_name, io.BytesIO(instance), len(instance)) except Exception as e: raise Exception("Could not load data into bytes {}".format(e)) return True def get_artifact(self, instance_name, bucket): - try: data = self.client.get_object(bucket, instance_name) return data.read() @@ -63,7 +64,6 @@ def get_artifact(self, instance_name, bucket): data.release_conn() def get_artifact_stream(self, instance_name, bucket): - try: data = self.client.get_object(bucket, instance_name) return data @@ -71,7 +71,7 @@ def get_artifact_stream(self, instance_name, bucket): raise Exception("Could not fetch data from bucket, {}".format(e)) def list_artifacts(self, bucket): - """ List all objects in bucket. + """List all objects in bucket. :param bucket: Name of the bucket :type bucket: str @@ -83,12 +83,11 @@ def list_artifacts(self, bucket): for obj in objs: objects.append(obj.object_name) except Exception: - raise Exception( - "Could not list models in bucket {}".format(bucket)) + raise Exception("Could not list models in bucket {}".format(bucket)) return objects def delete_artifact(self, instance_name, bucket): - """ Delete object with name instance_name from buckets. + """Delete object with name instance_name from buckets. :param instance_name: The object name :param bucket: Buckets to delete from @@ -98,11 +97,11 @@ def delete_artifact(self, instance_name, bucket): try: self.client.remove_object(bucket, instance_name) except InvalidResponseError as err: - logger.error('Could not delete artifact: {0} err: {1}'.format(instance_name, err)) + logger.error("Could not delete artifact: {0} err: {1}".format(instance_name, err)) pass def create_bucket(self, bucket_name): - """ Create a new bucket. If bucket exists, do nothing. 
+ """Create a new bucket. If bucket exists, do nothing. :param bucket_name: The name of the bucket :type bucket_name: str diff --git a/fedn/fedn/network/storage/s3/repository.py b/fedn/network/storage/s3/repository.py similarity index 79% rename from fedn/fedn/network/storage/s3/repository.py rename to fedn/network/storage/s3/repository.py index afd38aed9..df3d6acb0 100644 --- a/fedn/fedn/network/storage/s3/repository.py +++ b/fedn/network/storage/s3/repository.py @@ -7,12 +7,11 @@ @trace_all_methods class Repository: - """ Interface for storing model objects and compute packages in S3 compatible storage. """ + """Interface for storing model objects and compute packages in S3 compatible storage.""" def __init__(self, config): - - self.model_bucket = config['storage_bucket'] - self.context_bucket = config['context_bucket'] + self.model_bucket = config["storage_bucket"] + self.context_bucket = config["context_bucket"] # TODO: Make a plug-in solution self.client = MINIORepository(config) @@ -21,27 +20,25 @@ def __init__(self, config): self.client.create_bucket(self.model_bucket) def get_model(self, model_id): - """ Retrieve a model with id model_id. + """Retrieve a model with id model_id. :param model_id: Unique identifier for model to retrive. :return: The model object """ - logger.info("Client {} trying to get model with id: {}".format( - self.client.name, model_id)) + logger.info("Client {} trying to get model with id: {}".format(self.client.name, model_id)) return self.client.get_artifact(model_id, self.model_bucket) def get_model_stream(self, model_id): - """ Retrieve a stream handle to model with id model_id. + """Retrieve a stream handle to model with id model_id. :param model_id: :return: Handle to model object """ - logger.info("Client {} trying to get model with id: {}".format( - self.client.name, model_id)) + logger.info("Client {} trying to get model with id: {}".format(self.client.name, model_id)) return self.client.get_artifact_stream(model_id, self.model_bucket) def set_model(self, model, is_file=True): - """ Upload model object. + """Upload model object. :param model: The model object :type model: BytesIO or str file name. @@ -51,15 +48,14 @@ def set_model(self, model, is_file=True): model_id = uuid.uuid4() try: - self.client.set_artifact(str(model_id), model, - bucket=self.model_bucket, is_file=is_file) + self.client.set_artifact(str(model_id), model, bucket=self.model_bucket, is_file=is_file) except Exception: logger.error("Failed to upload model with ID {} to repository.".format(model_id)) raise return str(model_id) def delete_model(self, model_id): - """ Delete model. + """Delete model. :param model_id: The id of the model to delete :type model_id: str @@ -71,7 +67,7 @@ def delete_model(self, model_id): raise def set_compute_package(self, name, compute_package, is_file=True): - """ Upload compute package. + """Upload compute package. :param name: The name of the compute package. :type name: str @@ -81,14 +77,13 @@ def set_compute_package(self, name, compute_package, is_file=True): """ try: - self.client.set_artifact(str(name), compute_package, - bucket=self.context_bucket, is_file=is_file) + self.client.set_artifact(str(name), compute_package, bucket=self.context_bucket, is_file=is_file) except Exception: logger.error("Failed to write compute_package to repository.") raise def get_compute_package(self, compute_package): - """ Retrieve compute package from object store. + """Retrieve compute package from object store. :param compute_package: The name of the compute package. 
         :type compute_pacakge: str
@@ -102,7 +97,7 @@ def get_compute_package(self, compute_package):
         return data

     def delete_compute_package(self, compute_package):
-        """ Delete a compute package from storage.
+        """Delete a compute package from storage.

         :param compute_package: The name of the compute_package
         :type compute_package: str
diff --git a/fedn/fedn/network/storage/statestore/__init__.py b/fedn/network/storage/statestore/__init__.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/__init__.py
rename to fedn/network/storage/statestore/__init__.py
diff --git a/fedn/fedn/network/storage/statestore/mongostatestore.py b/fedn/network/storage/statestore/mongostatestore.py
similarity index 87%
rename from fedn/fedn/network/storage/statestore/mongostatestore.py
rename to fedn/network/storage/statestore/mongostatestore.py
index a885f9ba0..6e3c4acb6 100644
--- a/fedn/fedn/network/storage/statestore/mongostatestore.py
+++ b/fedn/network/storage/statestore/mongostatestore.py
@@ -63,7 +63,7 @@ def __init__(self, network_id, config):
         self.init_index()

     def connect(self):
-        """ Establish client connection to MongoDB.
+        """Establish client connection to MongoDB.

         :param config: Dictionary containing connection strings and security credentials.
         :type config: dict
@@ -127,11 +127,7 @@ def transition(self, state):
                 True,
             )
         else:
-            logger.info(
-                "Not updating state, already in {}".format(
-                    ReducerStateToString(state)
-                )
-            )
+            logger.info("Not updating state, already in {}".format(ReducerStateToString(state)))

     def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo.DESCENDING):
         """Get all sessions.
@@ -153,13 +149,9 @@ def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo
             limit = int(limit)
             skip = int(skip)

-            result = self.sessions.find().limit(limit).skip(skip).sort(
-                sort_key, sort_order
-            )
+            result = self.sessions.find().limit(limit).skip(skip).sort(sort_key, sort_order)
         else:
-            result = self.sessions.find().sort(
-                sort_key, sort_order
-            )
+            result = self.sessions.find().sort(sort_key, sort_order)

         count = self.sessions.count_documents({})
@@ -206,9 +198,7 @@ def set_latest_model(self, model_id, session_id=None):
             }
         )

-        self.model.update_one(
-            {"key": "current_model"}, {"$set": {"model": model_id}}, True
-        )
+        self.model.update_one({"key": "current_model"}, {"$set": {"model": model_id}}, True)
         self.model.update_one(
             {"key": "model_trail"},
             {
@@ -227,9 +217,7 @@ def get_initial_model(self):
         :rtype: str
         """
-        result = self.model.find_one(
-            {"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)]
-        )
+        result = self.model.find_one({"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)])
         if result is None:
             return None
@@ -268,16 +256,12 @@ def set_current_model(self, model_id: str):
         """
         try:
-
             committed_at = datetime.now()

             existing_model = self.model.find_one({"key": "models", "model": model_id})

             if existing_model is not None:
-
-                self.model.update_one(
-                    {"key": "current_model"}, {"$set": {"model": model_id, "committed_at": committed_at, "session_id": None}}, True
-                )
+                self.model.update_one({"key": "current_model"}, {"$set": {"model": model_id, "committed_at": committed_at, "session_id": None}}, True)

                 return True
         except Exception as e:
@@ -336,7 +320,6 @@ def set_active_compute_package(self, id: str):
         """
         try:
-
             find = {"id": id}
             projection = {"_id": False, "key": False}

@@ -347,9 +330,7 @@ def set_active_compute_package(self, id: str):

             doc["key"] = "active"

-            self.control.package.replace_one(
-                {"key": "active"}, doc
-            )
+            self.control.package.replace_one({"key": "active"}, doc)

         except Exception as e:
             logger.error("ERROR: {}".format(e))
@@ -378,9 +359,7 @@ def set_compute_package(self, file_name: str, storage_file_name: str, helper_typ
         self.control.package.update_one(
             {"key": "active"},
-            {
-                "$set": obj
-            },
+            {"$set": obj},
             True,
         )
@@ -397,7 +376,6 @@ def get_compute_package(self):
         :rtype: ObjectID
         """
         try:
-
             find = {"key": "active"}
             projection = {"key": False, "_id": False}
             ret = self.control.package.find_one(find, projection)
@@ -451,9 +429,7 @@ def set_helper(self, helper):
         :type helper: str
         :return:
         """
-        self.control.package.update_one(
-            {"key": "active"}, {"$set": {"helper": helper}}, True
-        )
+        self.control.package.update_one({"key": "active"}, {"$set": {"helper": helper}}, True)

     def get_helper(self):
         """Get the active helper package.
@@ -468,9 +444,7 @@ def get_helper(self):
         # ret = self.control.config.find_one({'key': 'round_config'})
         try:
             retcheck = ret["helper"]
-            if (
-                retcheck == "" or retcheck == " "
-            ):  # ugly check for empty string
+            if retcheck == "" or retcheck == " ":  # ugly check for empty string
                 return None
             return retcheck
         except (KeyError, IndexError):
@@ -497,11 +471,7 @@ def list_models(
         """
         result = None

-        find_option = (
-            {"key": "models"}
-            if session_id is None
-            else {"key": "models", "session_id": session_id}
-        )
+        find_option = {"key": "models"} if session_id is None else {"key": "models", "session_id": session_id}

         projection = {"_id": False, "key": False}

@@ -509,17 +479,10 @@ def list_models(
             limit = int(limit)
             skip = int(skip)

-            result = (
-                self.model.find(find_option, projection)
-                .limit(limit)
-                .skip(skip)
-                .sort(sort_key, sort_order)
-            )
+            result = self.model.find(find_option, projection).limit(limit).skip(skip).sort(sort_key, sort_order)
         else:
-            result = self.model.find(find_option, projection).sort(
-                sort_key, sort_order
-            )
+            result = self.model.find(find_option, projection).sort(sort_key, sort_order)

         count = self.model.count_documents(find_option)
@@ -627,9 +590,7 @@ def get_events(self, **kwargs):
         projection = {"_id": False}

         if not kwargs:
-            result = self.control.status.find({}, projection).sort(
-                "timestamp", pymongo.DESCENDING
-            )
+            result = self.control.status.find({}, projection).sort("timestamp", pymongo.DESCENDING)
             count = self.control.status.count_documents({})
         else:
             limit = kwargs.pop("limit", None)
@@ -638,16 +599,9 @@ def get_events(self, **kwargs):
             if limit is not None and skip is not None:
                 limit = int(limit)
                 skip = int(skip)
-                result = (
-                    self.control.status.find(kwargs, projection)
-                    .sort("timestamp", pymongo.DESCENDING)
-                    .limit(limit)
-                    .skip(skip)
-                )
+                result = self.control.status.find(kwargs, projection).sort("timestamp", pymongo.DESCENDING).limit(limit).skip(skip)
             else:
-                result = self.control.status.find(kwargs, projection).sort(
-                    "timestamp", pymongo.DESCENDING
-                )
+                result = self.control.status.find(kwargs, projection).sort("timestamp", pymongo.DESCENDING)

             count = self.control.status.count_documents(kwargs)
@@ -663,9 +617,7 @@ def get_storage_backend(self):
         :rtype: ObjectID
         """
         try:
-            ret = self.storage.find(
-                {"status": "enabled"}, projection={"_id": False}
-            )
+            ret = self.storage.find({"status": "enabled"}, projection={"_id": False})
             return ret[0]
         except (KeyError, IndexError):
             return None
@@ -680,9 +632,7 @@ def set_storage_backend(self, config):
         config = copy.deepcopy(config)
         config["updated_at"] = str(datetime.now())
         config["status"] = "enabled"
config["storage_type"]}, {"$set": config}, True - ) + self.storage.update_one({"storage_type": config["storage_type"]}, {"$set": config}, True) def set_reducer(self, reducer_data): """Set the reducer in the statestore. @@ -692,9 +642,7 @@ def set_reducer(self, reducer_data): :return: """ reducer_data["updated_at"] = str(datetime.now()) - self.reducer.update_one( - {"name": reducer_data["name"]}, {"$set": reducer_data}, True - ) + self.reducer.update_one({"name": reducer_data["name"]}, {"$set": reducer_data}, True) def get_reducer(self): """Get reducer.config. @@ -769,9 +717,7 @@ def set_combiner(self, combiner_data): """ combiner_data["updated_at"] = str(datetime.now()) - self.combiners.update_one( - {"name": combiner_data["name"]}, {"$set": combiner_data}, True - ) + self.combiners.update_one({"name": combiner_data["name"]}, {"$set": combiner_data}, True) def delete_combiner(self, combiner): """Delete a combiner from statestore. @@ -795,9 +741,7 @@ def set_client(self, client_data): :return: """ client_data["updated_at"] = str(datetime.now()) - self.clients.update_one( - {"name": client_data["name"]}, {"$set": client_data}, True - ) + self.clients.update_one({"name": client_data["name"]}, {"$set": client_data}, True) def get_client(self, name): """Get client by name. @@ -868,15 +812,15 @@ def list_combiners_data(self, combiners, sort_key="count", sort_order=pymongo.DE result = None try: - - pipeline = [ - {"$match": {"combiner": {"$in": combiners}, "status": "online"}}, - {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, - {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} - ] if combiners is not None else [ - {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, - {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} - ] + pipeline = ( + [ + {"$match": {"combiner": {"$in": combiners}, "status": "online"}}, + {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, + {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}}, + ] + if combiners is not None + else [{"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}}] + ) result = self.clients.aggregate(pipeline) @@ -913,7 +857,7 @@ def drop_status(self): self.status.drop() def create_session(self, id=None): - """ Create a new session object. + """Create a new session object. :param id: The ID of the created session. :type id: uuid, str @@ -921,11 +865,11 @@ def create_session(self, id=None): """ if not id: id = uuid.uuid4() - data = {'session_id': str(id)} + data = {"session_id": str(id)} self.sessions.insert_one(data) def create_round(self, round_data): - """ Create a new round. + """Create a new round. :param round_data: Dictionary with round data. :type round_data: dict @@ -941,8 +885,7 @@ def set_session_config(self, id, config): :param config: Session configuration :type config: dict """ - self.sessions.update_one({'session_id': str(id)}, { - '$push': {'session_config': config}}, True) + self.sessions.update_one({"session_id": str(id)}, {"$push": {"session_config": config}}, True) def set_session_status(self, id, status): """Set session status. @@ -951,8 +894,7 @@ def set_session_status(self, id, status): :type round_id: str :param round_status: The status of the session. """ - self.sessions.update_one({'session_id': str(id)}, { - '$set': {'status': status}}, True) + self.sessions.update_one({"session_id": str(id)}, {"$set": {"status": status}}, True) def set_round_combiner_data(self, data): """Set combiner round controller data. 
@@ -913,7 +857,7 @@ def drop_status(self):
         self.status.drop()

     def create_session(self, id=None):
-        """ Create a new session object.
+        """Create a new session object.

         :param id: The ID of the created session.
         :type id: uuid, str
@@ -921,11 +865,11 @@ def create_session(self, id=None):
         """
         if not id:
             id = uuid.uuid4()
-        data = {'session_id': str(id)}
+        data = {"session_id": str(id)}
         self.sessions.insert_one(data)

     def create_round(self, round_data):
-        """ Create a new round.
+        """Create a new round.

         :param round_data: Dictionary with round data.
         :type round_data: dict
@@ -941,8 +885,7 @@ def set_session_config(self, id, config):
         :param config: Session configuration
         :type config: dict
         """
-        self.sessions.update_one({'session_id': str(id)}, {
-            '$push': {'session_config': config}}, True)
+        self.sessions.update_one({"session_id": str(id)}, {"$push": {"session_config": config}}, True)

     def set_session_status(self, id, status):
         """Set session status.

         :param round_id: The round unique identifier
         :type round_id: str
         :param round_status: The status of the session.
         """
-        self.sessions.update_one({'session_id': str(id)}, {
-            '$set': {'status': status}}, True)
+        self.sessions.update_one({"session_id": str(id)}, {"$set": {"status": status}}, True)

     def set_round_combiner_data(self, data):
         """Set combiner round controller data.
@@ -960,8 +902,7 @@ def set_round_combiner_data(self, data):
         :param data: The combiner data
         :type data: dict
         """
-        self.rounds.update_one({'round_id': str(data['round_id'])}, {
-            '$push': {'combiners': data}}, True)
+        self.rounds.update_one({"round_id": str(data["round_id"])}, {"$push": {"combiners": data}}, True)

     def set_round_config(self, round_id, round_config):
         """Set round configuration.

         :param round_id: The round id
         :type round_id: str
         :param round_config: The round configuration
         :type round_config: dict
         """
-        self.rounds.update_one({'round_id': round_id}, {
-            '$set': {'round_config': round_config}}, True)
+        self.rounds.update_one({"round_id": round_id}, {"$set": {"round_config": round_config}}, True)

     def set_round_status(self, round_id, round_status):
         """Set round status.

         :param round_id: The round id
         :type round_id: str
         :param round_status: The status of the round.
         """
-        self.rounds.update_one({'round_id': round_id}, {
-            '$set': {'status': round_status}}, True)
+        self.rounds.update_one({"round_id": round_id}, {"$set": {"status": round_status}}, True)

     def set_round_data(self, round_id, round_data):
         """Update round metadata

         :param round_id: The round id
         :type round_id: str
         :param round_data: The round metadata
         :type round_data: dict
         """
-        self.rounds.update_one({'round_id': round_id}, {
-            '$set': {'round_data': round_data}}, True)
+        self.rounds.update_one({"round_id": round_id}, {"$set": {"round_data": round_data}}, True)

     def update_client_status(self, clients, status):
-        """ Update client status in statestore.
+        """Update client status in statestore.
         :param client_name: The client name
         :type client_name: str
         :param status: The client status
diff --git a/fedn/fedn/network/storage/statestore/statestorebase.py b/fedn/network/storage/statestore/statestorebase.py
similarity index 70%
rename from fedn/fedn/network/storage/statestore/statestorebase.py
rename to fedn/network/storage/statestore/statestorebase.py
index f41e3c025..7c6681682 100644
--- a/fedn/fedn/network/storage/statestore/statestorebase.py
+++ b/fedn/network/storage/statestore/statestorebase.py
@@ -2,22 +2,19 @@

 class StateStoreBase(ABC):
-    """
-
-    """
+    """ """

     def __init__(self):
         pass

     @abstractmethod
     def state(self):
-        """ Return the current state of the statestore.
-        """
+        """Return the current state of the statestore."""
         pass

     @abstractmethod
     def transition(self, state):
-        """ Transition the statestore to a new state.
+        """Transition the statestore to a new state.

         :param state: The new state.
         :type state: str
@@ -26,7 +23,7 @@ def transition(self, state):

     @abstractmethod
     def set_latest_model(self, model_id):
-        """ Set the latest model id in the statestore.
+        """Set the latest model id in the statestore.

         :param model_id: The model id.
         :type model_id: str
@@ -35,7 +32,7 @@ def set_latest_model(self, model_id):

     @abstractmethod
     def get_latest_model(self):
-        """ Get the latest model id from the statestore.
+        """Get the latest model id from the statestore.

         :return: The model object.
         :rtype: ObjectId
@@ -44,7 +41,7 @@ def get_latest_model(self):

     @abstractmethod
     def is_inited(self):
-        """ Check if the statestore is initialized.
+        """Check if the statestore is initialized.

         :return: True if initialized, else False.
         :rtype: bool
diff --git a/fedn/fedn/network/storage/statestore/stores/__init__.py b/fedn/network/storage/statestore/stores/__init__.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/__init__.py
rename to fedn/network/storage/statestore/stores/__init__.py
diff --git a/fedn/fedn/network/storage/statestore/stores/client_store.py b/fedn/network/storage/statestore/stores/client_store.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/client_store.py
rename to fedn/network/storage/statestore/stores/client_store.py
diff --git a/fedn/fedn/network/storage/statestore/stores/combiner_store.py b/fedn/network/storage/statestore/stores/combiner_store.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/combiner_store.py
rename to fedn/network/storage/statestore/stores/combiner_store.py
diff --git a/fedn/fedn/network/storage/statestore/stores/model_store.py b/fedn/network/storage/statestore/stores/model_store.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/model_store.py
rename to fedn/network/storage/statestore/stores/model_store.py
diff --git a/fedn/fedn/network/storage/statestore/stores/package_store.py b/fedn/network/storage/statestore/stores/package_store.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/package_store.py
rename to fedn/network/storage/statestore/stores/package_store.py
diff --git a/fedn/fedn/network/storage/statestore/stores/round_store.py b/fedn/network/storage/statestore/stores/round_store.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/round_store.py
rename to fedn/network/storage/statestore/stores/round_store.py
diff --git a/fedn/fedn/network/storage/statestore/stores/session_store.py b/fedn/network/storage/statestore/stores/session_store.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/session_store.py
rename to fedn/network/storage/statestore/stores/session_store.py
diff --git a/fedn/fedn/network/storage/statestore/stores/shared.py b/fedn/network/storage/statestore/stores/shared.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/shared.py
rename to fedn/network/storage/statestore/stores/shared.py
diff --git a/fedn/fedn/network/storage/statestore/stores/status_store.py b/fedn/network/storage/statestore/stores/status_store.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/status_store.py
rename to fedn/network/storage/statestore/stores/status_store.py
diff --git a/fedn/fedn/network/storage/statestore/stores/store.py b/fedn/network/storage/statestore/stores/store.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/store.py
rename to fedn/network/storage/statestore/stores/store.py
diff --git a/fedn/fedn/network/storage/statestore/stores/validation_store.py b/fedn/network/storage/statestore/stores/validation_store.py
similarity index 100%
rename from fedn/fedn/network/storage/statestore/stores/validation_store.py
rename to fedn/network/storage/statestore/stores/validation_store.py
diff --git a/fedn/fedn/utils/__init__.py b/fedn/utils/__init__.py
similarity index 100%
rename from fedn/fedn/utils/__init__.py
rename to fedn/utils/__init__.py
diff --git a/fedn/fedn/utils/checksum.py b/fedn/utils/checksum.py
similarity index 81%
rename from fedn/fedn/utils/checksum.py
rename to fedn/utils/checksum.py
index 01f28e02c..0cf000f37 100644
--- a/fedn/fedn/utils/checksum.py
+++ b/fedn/utils/checksum.py
@@ -5,7 +5,7 @@

 @tracer.start_as_current_span(name="sha")
 def sha(fname):
-    """ Calculate the sha256 checksum of a file. Used for computing checksums of compute packages.
+    """Calculate the sha256 checksum of a file. Used for computing checksums of compute packages.

     :param fname: The file path.
     :type fname: str
diff --git a/fedn/fedn/utils/dispatcher.py b/fedn/utils/dispatcher.py
similarity index 85%
rename from fedn/fedn/utils/dispatcher.py
rename to fedn/utils/dispatcher.py
index 197fed2f0..d5345cc08 100644
--- a/fedn/fedn/utils/dispatcher.py
+++ b/fedn/utils/dispatcher.py
@@ -77,10 +77,7 @@ def _validate_virtualenv_is_available():
     on how to install virtualenv.
     """
     if not _is_virtualenv_available():
-        raise Exception(
-            "Could not find the virtualenv binary. Run `pip install virtualenv` to install "
-            "virtualenv."
-        )
+        raise Exception("Could not find the virtualenv binary. Run `pip install virtualenv` to install " "virtualenv.")

@@ -118,8 +115,7 @@ def _create_virtualenv(
     with remove_on_error(
         env_dir,
         onerror=lambda e: logger.warning(
-            "Encountered an unexpected error: %s while creating a virtualenv environment in %s, "
-            "removing the environment directory...",
+            "Encountered an unexpected error: %s while creating a virtualenv environment in %s, " "removing the environment directory...",
             repr(e),
             env_dir,
         ),
@@ -135,9 +131,7 @@ def _create_virtualenv(
         with tempfile.TemporaryDirectory() as tmpdir:
             tmp_req_file = f"requirements.{uuid.uuid4().hex}.txt"
             Path(tmpdir).joinpath(tmp_req_file).write_text("\n".join(deps))
-            cmd = _join_commands(
-                activate_cmd, f"python -m pip install --quiet -r {tmp_req_file}"
-            )
+            cmd = _join_commands(activate_cmd, f"python -m pip install -r {tmp_req_file}")
             _exec_cmd(cmd, capture_output=capture_output, cwd=tmpdir, extra_env=extra_env)

     return activate_cmd
@@ -147,21 +141,18 @@ def _create_virtualenv(
 def _read_yaml_file(file_path):
     try:
         cfg = None
-        with open(file_path, 'rb') as config_file:
-
+        with open(file_path, "rb") as config_file:
             cfg = yaml.safe_load(config_file.read())
     except Exception as e:
-        logger.error(
-            f"Error trying to read yaml file: {file_path}"
-        )
+        logger.error(f"Error trying to read yaml file: {file_path}")
         raise e
     return cfg


 @trace_all_methods
 class Dispatcher:
-    """ Dispatcher class for compute packages.
+    """Dispatcher class for compute packages.

     :param config: The configuration.
     :type config: dict
@@ -170,7 +161,7 @@ class Dispatcher:
     """

     def __init__(self, config, project_dir):
-        """ Initialize the dispatcher."""
+        """Initialize the dispatcher."""
         self.config = config
         self.project_dir = project_dir
         self.activate_cmd = ""
@@ -184,10 +175,7 @@ def _get_or_create_python_env(self, capture_output=False, pip_requirements_overr
         else:
             python_env_path = os.path.join(self.project_dir, python_env)
             if not os.path.exists(python_env_path):
-                raise Exception(
-                    "Compute package specified python_env file %s, but no such "
-                    "file was found." % python_env_path
-                )
diff --git a/fedn/fedn/utils/environment.py b/fedn/utils/environment.py
similarity index 95%
rename from fedn/fedn/utils/environment.py
rename to fedn/utils/environment.py
index 9139b74ff..465691342 100644
--- a/fedn/fedn/utils/environment.py
+++ b/fedn/utils/environment.py
@@ -15,6 +15,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import yaml

 from fedn.common.telemetry import trace_all_methods
@@ -33,6 +34,8 @@ def __init__(self, name=None, python=None, build_dependencies=None, dependencies
         Represents environment information for FEDn compute packages.

         Args:
+        ----
+        name: Name of environment. If unspecified, defaults to fedn_env
         python: Python version for the environment. If unspecified, defaults to the current
             Python version.
         build_dependencies: List of build dependencies for the environment that must
@@ -40,15 +43,14 @@
             defaults to an empty list.
         dependencies: List of dependencies for the environment. If unspecified,
             defaults to an empty list.
+
         """
         if name is not None and not isinstance(name, str):
             raise TypeError(f"`name` must be a string but got {type(name)}")
         if python is not None and not isinstance(python, str):
             raise TypeError(f"`python` must be a string but got {type(python)}")
         if build_dependencies is not None and not isinstance(build_dependencies, list):
-            raise TypeError(
-                f"`build_dependencies` must be a list but got {type(build_dependencies)}"
-            )
+            raise TypeError(f"`build_dependencies` must be a list but got {type(build_dependencies)}")
         if dependencies is not None and not isinstance(dependencies, list):
             raise TypeError(f"`dependencies` must be a list but got {type(dependencies)}")
         self.name = name or "fedn_env"
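The constructor documented above validates its argument types before applying defaults. A minimal sketch of constructing such an environment holder directly; the class name is not visible in this hunk, so `_PythonEnv` is an assumption:

    # Hedged sketch: `_PythonEnv` stands in for the class whose __init__ is
    # documented above; only the signature is visible in this patch.
    from fedn.utils.environment import _PythonEnv

    env = _PythonEnv(
        name="mnist-env",                       # defaults to "fedn_env" if omitted
        python="3.11",                          # defaults to the running Python version
        build_dependencies=["pip", "wheel"],    # per the Args description above
        dependencies=["numpy", "torch==2.2.0"],
    )
    assert env.name == "mnist-env"

Passing a non-list for `dependencies` or a non-string for `python` raises TypeError, as enforced by the checks shown in the hunk.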
+ """ if name is not None and not isinstance(name, str): raise TypeError(f"`name` must be a string but got {type(name)}") if python is not None and not isinstance(python, str): raise TypeError(f"`python` must be a string but got {type(python)}") if build_dependencies is not None and not isinstance(build_dependencies, list): - raise TypeError( - f"`build_dependencies` must be a list but got {type(build_dependencies)}" - ) + raise TypeError(f"`build_dependencies` must be a list but got {type(build_dependencies)}") if dependencies is not None and not isinstance(dependencies, list): raise TypeError(f"`dependencies` must be a list but got {type(dependencies)}") self.name = name or "fedn_env" diff --git a/fedn/fedn/utils/flowercompat/__init__.py b/fedn/utils/flowercompat/__init__.py similarity index 100% rename from fedn/fedn/utils/flowercompat/__init__.py rename to fedn/utils/flowercompat/__init__.py diff --git a/fedn/fedn/utils/flowercompat/client_app_adapter.py b/fedn/utils/flowercompat/client_app_adapter.py similarity index 76% rename from fedn/fedn/utils/flowercompat/client_app_adapter.py rename to fedn/utils/flowercompat/client_app_adapter.py index 2675dc6b9..b7d70101c 100644 --- a/fedn/fedn/utils/flowercompat/client_app_adapter.py +++ b/fedn/utils/flowercompat/client_app_adapter.py @@ -1,16 +1,27 @@ from typing import Tuple from flwr.client import ClientApp -from flwr.common import (Context, EvaluateIns, FitIns, GetParametersIns, - Message, MessageType, MessageTypeLegacy, Metadata, - NDArrays, ndarrays_to_parameters, - parameters_to_ndarrays) -from flwr.common.recordset_compat import (evaluateins_to_recordset, - fitins_to_recordset, - getparametersins_to_recordset, - recordset_to_evaluateres, - recordset_to_fitres, - recordset_to_getparametersres) +from flwr.common import ( + Context, + EvaluateIns, + FitIns, + GetParametersIns, + Message, + MessageType, + MessageTypeLegacy, + Metadata, + NDArrays, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.recordset_compat import ( + evaluateins_to_recordset, + fitins_to_recordset, + getparametersins_to_recordset, + recordset_to_evaluateres, + recordset_to_fitres, + recordset_to_getparametersres, +) from fedn.common.telemetry import trace_all_methods @@ -24,23 +35,21 @@ def __init__(self, app: ClientApp) -> None: def init_parameters(self, partition_id: int, config: dict = {}): # Construct a get_parameters message for the ClientApp - message, context = self._construct_message( - MessageTypeLegacy.GET_PARAMETERS, [], partition_id, config - ) + message, context = self._construct_message(MessageTypeLegacy.GET_PARAMETERS, [], partition_id, config) # Call client app with train message client_return_message = self.app(message, context) # return NDArrays of clients parameters parameters = self._parse_get_parameters_message(client_return_message) if len(parameters) == 0: - raise ValueError("The 'parameters' list is empty. Ensure your flower \ - client has implemented a get_parameters() function.") + raise ValueError( + "The 'parameters' list is empty. Ensure your flower \ + client has implemented a get_parameters() function." 
diff --git a/fedn/fedn/utils/helpers/__init__.py b/fedn/utils/helpers/__init__.py
similarity index 100%
rename from fedn/fedn/utils/helpers/__init__.py
rename to fedn/utils/helpers/__init__.py
diff --git a/fedn/fedn/utils/helpers/helperbase.py b/fedn/utils/helpers/helperbase.py
similarity index 69%
rename from fedn/fedn/utils/helpers/helperbase.py
rename to fedn/utils/helpers/helperbase.py
index a59ee49d8..3377d0336 100644
--- a/fedn/fedn/utils/helpers/helperbase.py
+++ b/fedn/utils/helpers/helperbase.py
@@ -4,16 +4,16 @@

 class HelperBase(ABC):
-    """ Abstract class defining helpers. """
+    """Abstract class defining helpers."""

     def __init__(self):
-        """ Initialize helper. """
+        """Initialize helper."""
         self.name = self.__class__.__name__

     @abstractmethod
     def increment_average(self, m1, m2, a, W):
-        """ Compute one increment of incremental weighted averaging.
+        """Compute one increment of incremental weighted averaging.

         :param m1: Current model weights in array-like format.
         :param m2: New model weights in array-like format.
@@ -25,7 +25,7 @@ def increment_average(self, m1, m2, a, W):

     @abstractmethod
     def save(self, model, path):
-        """ Serialize weights to file. The serialized model must be a single binary object.
+        """Serialize weights to file. The serialized model must be a single binary object.

         :param model: Weights in array-like format.
         :param path: Path to file.
@@ -35,7 +35,7 @@ def save(self, model, path):

     @abstractmethod
     def load(self, fh):
-        """ Load weights from file or filelike.
+        """Load weights from file or filelike.

         :param fh: file path, filehandle, filelike.
         :return: Weights in array-like format.
@@ -43,10 +43,10 @@ def load(self, fh):
         pass

     def get_tmp_path(self):
-        """ Return a temporary output path compatible with save_model, load_model.
+        """Return a temporary output path compatible with save_model, load_model.

         :return: Path to file.
         """
-        fd, path = tempfile.mkstemp(suffix='.npz')
+        fd, path = tempfile.mkstemp(suffix=".npz")
         os.close(fd)
         return path
diff --git a/fedn/fedn/utils/helpers/helpers.py b/fedn/utils/helpers/helpers.py
similarity index 84%
rename from fedn/fedn/utils/helpers/helpers.py
rename to fedn/utils/helpers/helpers.py
index 6fcc4631e..8f25101b5 100644
--- a/fedn/fedn/utils/helpers/helpers.py
+++ b/fedn/utils/helpers/helpers.py
@@ -8,7 +8,7 @@

 @tracer.start_as_current_span(name="get_helper")
 def get_helper(helper_module_name):
-    """ Return an instance of the helper class.
+    """Return an instance of the helper class.

     :param helper_module_name: The name of the helper plugin module.
     :type helper_module_name: str
@@ -22,25 +22,25 @@

 @tracer.start_as_current_span("save_metadata")
 def save_metadata(metadata, filename):
-    """ Save metadata to file.
+    """Save metadata to file.

     :param metadata: The metadata to save.
     :type metadata: dict
     :param filename: The name of the file to save to.
     :type filename: str
     """
-    with open(filename+'-metadata', 'w') as outfile:
+    with open(filename + "-metadata", "w") as outfile:
         json.dump(metadata, outfile)


 @tracer.start_as_current_span("save_metrics")
 def save_metrics(metrics, filename):
-    """ Save metrics to file.
+    """Save metrics to file.

     :param metrics: The metrics to save.
     :type metrics: dict
     :param filename: The name of the file to save to.
     :type filename: str
     """
-    with open(filename, 'w') as outfile:
+    with open(filename, "w") as outfile:
         json.dump(metrics, outfile)
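Taken together, get_helper, save_metadata, and save_metrics form the small persistence API clients use around a model update. A minimal sketch, assuming the "numpyhelper" plugin shipped in this patch:

    import numpy as np

    from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics

    helper = get_helper("numpyhelper")  # resolves the plugin module by name

    # Save a model update; with no path given, the helper allocates a tmp .npz file.
    path = helper.save([np.zeros((3, 3))])

    # save_metadata writes JSON next to the weights, appending "-metadata" to the name.
    save_metadata({"num_examples": 128}, path)
    save_metrics({"accuracy": 0.91}, "validation.json")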
diff --git a/fedn/fedn/utils/helpers/plugins/__init__.py b/fedn/utils/helpers/plugins/__init__.py
similarity index 100%
rename from fedn/fedn/utils/helpers/plugins/__init__.py
rename to fedn/utils/helpers/plugins/__init__.py
diff --git a/fedn/fedn/utils/helpers/plugins/androidhelper.py b/fedn/utils/helpers/plugins/androidhelper.py
similarity index 86%
rename from fedn/fedn/utils/helpers/plugins/androidhelper.py
rename to fedn/utils/helpers/plugins/androidhelper.py
index 87fc8c464..6340a1c8b 100644
--- a/fedn/fedn/utils/helpers/plugins/androidhelper.py
+++ b/fedn/utils/helpers/plugins/androidhelper.py
@@ -20,9 +20,7 @@ def __init__(self):

     # function to calculate an incremental weighted average of the weights
-    def increment_average(
-        self, model, model_next, num_examples, total_examples
-    ):
+    def increment_average(self, model, model_next, num_examples, total_examples):
         """Incremental weighted average of model weights.

         :param model: Current model weights.
@@ -42,9 +40,7 @@ def increment_average(
         return (1 - w) * model + w * model_next

     # function to calculate an incremental weighted average of the weights using numpy.add
-    def increment_average_add(
-        self, model, model_next, num_examples, total_examples
-    ):
+    def increment_average_add(self, model, model_next, num_examples, total_examples):
         """Incremental weighted average of model weights.

         :param model: Current model weights.
@@ -61,9 +57,7 @@ def increment_average_add(
         # Incremental weighted average
         w = np.add(
             model,
-            num_examples
-            * (np.array(model_next) - np.array(model))
-            / total_examples,
+            num_examples * (np.array(model_next) - np.array(model)) / total_examples,
         )

         return w
@@ -77,7 +71,7 @@ def save(self, weights, path=None):
         if not path:
             path = self.get_tmp_path()

-        byte_array = struct.pack("f"*len(weights), *weights)
+        byte_array = struct.pack("f" * len(weights), *weights)
         with open(path, "wb") as file:
             file.write(byte_array)
@@ -96,7 +90,7 @@ def load(self, fh):
         else:
             byte_data = fh.read()

-        weights = np.array(struct.unpack(f'{len(byte_data) // 4}f', byte_data))
+        weights = np.array(struct.unpack(f"{len(byte_data) // 4}f", byte_data))

         return weights
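The androidhelper's save/load pair serializes weights as a flat run of float32 values, four bytes per weight, which is what the `len(byte_data) // 4` in load() relies on. A self-contained round-trip of that byte format:

    import struct

    import numpy as np

    # Mirror of the androidhelper byte format: each weight becomes a
    # native-endian float32, so the blob length is 4 * len(weights).
    weights = [0.5, -1.25, 3.0]
    blob = struct.pack("f" * len(weights), *weights)
    restored = np.array(struct.unpack(f"{len(blob) // 4}f", blob))
    assert np.allclose(restored, weights)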
diff --git a/fedn/fedn/utils/helpers/plugins/numpyhelper.py b/fedn/utils/helpers/plugins/numpyhelper.py
similarity index 78%
rename from fedn/fedn/utils/helpers/plugins/numpyhelper.py
rename to fedn/utils/helpers/plugins/numpyhelper.py
index fc902913d..857c57871 100644
--- a/fedn/fedn/utils/helpers/plugins/numpyhelper.py
+++ b/fedn/utils/helpers/plugins/numpyhelper.py
@@ -1,4 +1,3 @@
-
 import numpy as np

 from fedn.common.telemetry import trace_all_methods
@@ -7,15 +6,15 @@

 @trace_all_methods
 class Helper(HelperBase):
-    """ FEDn helper class for models weights/parameters that can be transformed to numpy ndarrays. """
+    """FEDn helper class for model weights/parameters that can be transformed to numpy ndarrays."""

     def __init__(self):
-        """ Initialize helper. """
+        """Initialize helper."""
         super().__init__()
         self.name = "numpyhelper"

     def increment_average(self, m1, m2, n, N):
-        """ Update a weighted incremental average of model weights.
+        """Update a weighted incremental average of model weights.

         :param m1: Current parameters.
         :type m1: list of numpy ndarray
@@ -29,10 +28,10 @@ def increment_average(self, m1, m2, n, N):
         :rtype: list of numpy ndarray
         """
-        return [np.add(x, n*(y-x)/N) for x, y in zip(m1, m2)]
+        return [np.add(x, n * (y - x) / N) for x, y in zip(m1, m2)]

     def add(self, m1, m2, a=1.0, b=1.0):
-        """ m1*a + m2*b
+        """m1*a + m2*b

         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -42,10 +41,10 @@ def add(self, m1, m2, a=1.0, b=1.0):
         :rtype: list of ndarrays
         """
-        return [x*a+y*b for x, y in zip(m1, m2)]
+        return [x * a + y * b for x, y in zip(m1, m2)]

     def subtract(self, m1, m2, a=1.0, b=1.0):
-        """ m1*a - m2*b.
+        """m1*a - m2*b.

         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -57,7 +56,7 @@ def subtract(self, m1, m2, a=1.0, b=1.0):
         return self.add(m1, m2, a, -b)

     def divide(self, m1, m2):
-        """ Subtract weights.
+        """Divide m1 by m2, element-wise.

         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -70,7 +69,7 @@ def divide(self, m1, m2):
         return [np.divide(x, y) for x, y in zip(m1, m2)]

     def multiply(self, m1, m2):
-        """ Multiply m1 by m2.
+        """Multiply m1 by m2.

         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -83,7 +82,7 @@ def multiply(self, m1, m2):
         return [np.multiply(x, y) for (x, y) in zip(m1, m2)]

     def sqrt(self, m1):
-        """ Sqrt of m1, element-wise.
+        """Sqrt of m1, element-wise.

         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -96,7 +95,7 @@ def sqrt(self, m1):
         return [np.sqrt(x) for x in m1]

     def power(self, m1, a):
-        """ m1 raised to the power of m2.
+        """m1 raised to the power of a, element-wise.

         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -109,7 +108,7 @@ def power(self, m1, a):
         return [np.power(x, a) for x in m1]

     def norm(self, m):
-        """ Return the norm (L1) of model weights.
+        """Return the norm (L1) of model weights.

         :param m: Current model weights.
         :type m: list of ndarrays
@@ -121,8 +120,19 @@ def norm(self, m):
             n += np.linalg.norm(x, 1)
         return n

+    def sign(self, m):
+        """Sign of m.
+
+        :param m: Model parameters.
+        :type m: list of ndarrays
+        :return: sign(m)
+        :rtype: list of ndarrays
+        """
+
+        return [np.sign(x) for x in m]
+
     def ones(self, m1, a):
-        """ Return a list of numpy arrays of the same shape as m1, filled with ones.
+        """Return a list of numpy arrays of the same shape as m1, filled with ones.

         :param m1: Current model weights.
         :type m1: list of ndarrays
@@ -134,11 +144,11 @@ def ones(self, m1, a):

         res = []
         for x in m1:
-            res.append(np.ones(np.shape(x))*a)
+            res.append(np.ones(np.shape(x)) * a)
         return res

     def save(self, weights, path=None):
-        """ Serialize weights to file. The serialized model must be a single binary object.
+        """Serialize weights to file. The serialized model must be a single binary object.

         :param weights: List of weights in numpy format.
         :param path: Path to file.
@@ -156,7 +166,7 @@ def save(self, weights, path=None):
         return path

     def load(self, fh):
-        """ Load weights from file or filelike.
+        """Load weights from file or filelike.

         :param fh: file path, filehandle, filelike.
         :return: List of weights in numpy format.
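The numpyhelper's increment_average computes a running weighted mean per layer: x + n * (y - x) / N, where the new update y carries n of the N examples seen in total. A small worked check of that formula:

    import numpy as np

    # With N = 40 total examples and a new update carrying n = 10 of them,
    # the average moves a quarter of the way from m1 toward m2.
    m1 = [np.zeros((2, 2))]
    m2 = [np.ones((2, 2))]
    avg = [np.add(x, 10 * (y - x) / 40) for x, y in zip(m1, m2)]
    assert np.allclose(avg[0], 0.25)

The new sign() method added in this hunk fits the same pattern: it maps np.sign over each layer, which aggregators such as sign-based update rules can build on.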
diff --git a/fedn/fedn/utils/helpers/tests/test_numpyhelper.py b/fedn/utils/helpers/tests/test_numpyhelper.py
similarity index 100%
rename from fedn/fedn/utils/helpers/tests/test_numpyhelper.py
rename to fedn/utils/helpers/tests/test_numpyhelper.py
diff --git a/fedn/utils/parameters.py b/fedn/utils/parameters.py
new file mode 100644
index 000000000..53d3080b6
--- /dev/null
+++ b/fedn/utils/parameters.py
@@ -0,0 +1,51 @@
+from fedn.common.exceptions import InvalidParameterError
+
+
+class Parameters(dict):
+    """Represents a collection of parameters.
+
+    Extends dict and adds functionality to validate
+    parameter types against a user-provided schema.
+
+    Example of use:
+    p = Parameters({'n_iter': 10, 'beta': 1e-2})
+    p.validate({'n_iter': int, 'beta': float})
+
+    """
+
+    def __init__(self, parameters=None):
+        """ """
+        if parameters:
+            for key, value in parameters.items():
+                self.__setitem__(key, value)
+
+    def validate(self, parameter_schema):
+        """Validate parameters against a schema.
+
+        :param parameter_schema: mapping of parameter name and data type.
+        :type parameter_schema: dict
+        :return: True if the parameters validate
+        :rtype: bool
+        """
+        for key, value in self.items():
+            if key not in parameter_schema.keys():
+                raise InvalidParameterError("Parameter {} not in parameter schema".format(key))
+            else:
+                type = parameter_schema[key]
+                self._validate_parameter_type(key, value, type)
+
+        return True
+
+    def _validate_parameter_type(self, key, value, type):
+        """Validate that a parameter value matches its expected data type.
+
+        :param key: the parameter name.
+        :param value: the parameter value to check.
+        :param type: the expected data type.
+        :return: True if the value matches the expected type
+        :rtype: bool
+        """
+        if not isinstance(value, type):
+            raise InvalidParameterError("Parameter {} has invalid type, expecting {}.".format(key, type))
+
+        return True
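Validation fails loudly both for keys missing from the schema and for type mismatches, so callers typically wrap it in a try/except. A short usage sketch of the class added above:

    from fedn.common.exceptions import InvalidParameterError
    from fedn.utils.parameters import Parameters

    params = Parameters({"learning_rate": 1e-3, "rounds": 10})

    try:
        params.validate({"learning_rate": float, "rounds": int})
    except InvalidParameterError as e:
        # Raised for unknown keys as well as for type mismatches.
        print(f"Invalid hyperparameters: {e}")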
diff --git a/fedn/utils/plots.py b/fedn/utils/plots.py
new file mode 100644
index 000000000..7901e2374
--- /dev/null
+++ b/fedn/utils/plots.py
@@ -0,0 +1,432 @@
+import json
+from datetime import datetime
+
+import numpy
+import plotly
+import plotly.graph_objs as go
+from plotly.subplots import make_subplots
+
+from fedn.common.log_config import logger
+from fedn.network.storage.statestore.mongostatestore import MongoStateStore
+
+
+class Plot:
+    """ """
+
+    def __init__(self, statestore):
+        try:
+            statestore_config = statestore.get_config()
+            statestore = MongoStateStore(statestore_config["network_id"], statestore_config["mongo_config"])
+            self.mdb = statestore.connect()
+            self.status = self.mdb["control.status"]
+            self.round_time = self.mdb["control.round_time"]
+            self.combiner_round_time = self.mdb["control.combiner_round_time"]
+            self.psutil_usage = self.mdb["control.psutil_monitoring"]
+            self.network_clients = self.mdb["network.clients"]
+
+        except Exception as e:
+            logger.error("FAILED TO CONNECT TO MONGO, {}".format(e))
+            self.collection = None
+            raise
+
+    # plot metrics from DB
+    def _scalar_metrics(self, metrics):
+        """Extract all scalar valued metrics from a MODEL_VALIDATION."""
+
+        data = json.loads(metrics["data"])
+        data = json.loads(data["data"])
+
+        valid_metrics = []
+        for metric, val in data.items():
+            # If it can be converted to a float it is a valid, scalar metric
+            try:
+                val = float(val)
+                valid_metrics.append(metric)
+            except Exception:
+                pass
+
+        return valid_metrics
+
+    def create_table_plot(self):
+        """
+
+        :return:
+        """
+        metrics = self.status.find_one({"type": "MODEL_VALIDATION"})
+        if metrics is None:
+            fig = go.Figure(data=[])
+            fig.update_layout(title_text="No data currently available for table mean metrics")
+            table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+            return False
+
+        valid_metrics = self._scalar_metrics(metrics)
+        if valid_metrics == []:
+            fig = go.Figure(data=[])
+            fig.update_layout(title_text="No scalar metrics found")
+            table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+            return False
+
+        all_vals = []
+        models = []
+        for metric in valid_metrics:
+            validations = {}
+            for post in self.status.find({"type": "MODEL_VALIDATION"}):
+                e = json.loads(post["data"])
+                try:
+                    validations[e["modelId"]].append(float(json.loads(e["data"])[metric]))
+                except KeyError:
+                    validations[e["modelId"]] = [float(json.loads(e["data"])[metric])]
+
+            vals = []
+            models = []
+            for model, data in validations.items():
+                vals.append(numpy.mean(data))
+                models.append(model)
+            all_vals.append(vals)
+
+        header_vals = valid_metrics
+        models.reverse()
+        values = [models]
+
+        for vals in all_vals:
+            vals.reverse()
+            values.append(vals)
+
+        fig = go.Figure(
+            data=[
+                go.Table(
+                    header=dict(values=["Model ID"] + header_vals, line_color="darkslategray", fill_color="lightskyblue", align="left"),
+                    cells=dict(
+                        values=values,  # 2nd column
+                        line_color="darkslategray",
+                        fill_color="lightcyan",
+                        align="left",
+                    ),
+                )
+            ]
+        )
+
+        fig.update_layout(title_text="Summary: mean metrics")
+        table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+        return table
+
+    def create_timeline_plot(self):
+        """
+
+        :return:
+        """
+        trace_data = []
+        x = []
+        y = []
+        base = []
+        for p in self.status.find({"type": "MODEL_UPDATE_REQUEST"}):
+            e = json.loads(p["data"])
+            cid = e["correlationId"]
+            for cc in self.status.find({"sender": p["sender"], "type": "MODEL_UPDATE"}):
+                da = json.loads(cc["data"])
+                if da["correlationId"] == cid:
+                    cp = cc
+
+            cd = json.loads(cp["data"])
+            tr = datetime.strptime(e["timestamp"], "%Y-%m-%d %H:%M:%S.%f")
+            tu = datetime.strptime(cd["timestamp"], "%Y-%m-%d %H:%M:%S.%f")
+            ts = tu - tr
+            base.append(tr.timestamp())
+            x.append(ts.total_seconds() / 60.0)
+            y.append(p["sender"]["name"])
+
+        trace_data.append(
+            go.Bar(
+                x=y,
+                y=x,
+                marker=dict(color="royalblue"),
+                name="Training",
+            )
+        )
+
+        x = []
+        y = []
+        base = []
+        for p in self.status.find({"type": "MODEL_VALIDATION_REQUEST"}):
+            e = json.loads(p["data"])
+            cid = e["correlationId"]
+            for cc in self.status.find({"sender": p["sender"], "type": "MODEL_VALIDATION"}):
+                da = json.loads(cc["data"])
+                if da["correlationId"] == cid:
+                    cp = cc
+            cd = json.loads(cp["data"])
+            tr = datetime.strptime(e["timestamp"], "%Y-%m-%d %H:%M:%S.%f")
+            tu = datetime.strptime(cd["timestamp"], "%Y-%m-%d %H:%M:%S.%f")
+            ts = tu - tr
+            base.append(tr.timestamp())
+            x.append(ts.total_seconds() / 60.0)
+            y.append(p["sender"]["name"])
+
+        trace_data.append(
+            go.Bar(
+                x=y,
+                y=x,
+                marker=dict(color="lightskyblue"),
+                name="Validation",
+            )
+        )
+
+        layout = go.Layout(
+            barmode="stack",
+            showlegend=True,
+        )
+
+        fig = go.Figure(data=trace_data, layout=layout)
+        fig.update_xaxes(title_text="Alliance/client")
+        fig.update_yaxes(title_text="Time (Min)")
+        fig.update_layout(title_text="Alliance timeline")
+        timeline = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+        return timeline
+
+    def create_client_training_distribution(self):
+        """
+
+        :return:
+        """
+        training = []
+        for p in self.status.find({"type": "MODEL_UPDATE"}):
+            e = json.loads(p["data"])
+            meta = json.loads(e["meta"])
+            training.append(meta["exec_training"])
+
+        if not training:
+            return False
+        fig = go.Figure(data=go.Histogram(x=training))
+        fig.update_layout(title_text="Client model training time, mean: {}".format(numpy.mean(training)))
+        histogram = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+        return histogram
+
+    def create_client_histogram_plot(self):
+        """
+
+        :return:
+        """
+        training = []
+        for p in self.status.find({"type": "MODEL_UPDATE"}):
+            e = json.loads(p["data"])
+            meta = json.loads(e["meta"])
+            training.append(meta["exec_training"])
+
+        fig = go.Figure()
+
+        fig.update_layout(
+            template="simple_white",
+            xaxis=dict(title_text="Time (s)"),
+            yaxis=dict(title_text="Number of updates"),
+            title="Mean client training time: {}".format(numpy.mean(training)),
+            # showlegend=True
+        )
+        if not training:
+            return False
+
+        fig.add_trace(go.Histogram(x=training))
+
+        histogram_plot = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+        return histogram_plot
+
+    def create_client_plot(self):
+        """
+
+        :return:
+        """
+        processing = []
+        upload = []
+        download = []
+        training = []
+        for p in self.status.find({"type": "MODEL_UPDATE"}):
+            e = json.loads(p["data"])
+            meta = json.loads(e["meta"])
+            upload.append(meta["upload_model"])
+            download.append(meta["fetch_model"])
+            training.append(meta["exec_training"])
+            processing.append(meta["processing_time"])
+
+        fig = go.Figure()
+        fig.update_layout(template="simple_white", title="Mean client processing time: {}".format(numpy.mean(processing)), showlegend=True)
+        if not processing:
+            return False
+        data = [numpy.mean(training), numpy.mean(upload), numpy.mean(download)]
+        labels = ["Training execution", "Model upload (to combiner)", "Model download (from combiner)"]
+        fig.add_trace(go.Pie(labels=labels, values=data))
+
+        client_plot = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+        return client_plot
+
+    def create_combiner_plot(self):
+        """
+
+        :return:
+        """
+        waiting = []
+        aggregation = []
+        model_load = []
+        combination = []
+        for round in self.mdb["control.round"].find():
+            try:
+                for combiner in round["combiners"]:
+                    data = combiner
+                    stats = data["local_round"]["1"]
+                    ml = stats["aggregation_time"]["time_model_load"]
+                    ag = stats["aggregation_time"]["time_model_aggregation"]
+                    combination.append(stats["time_combination"])
+                    waiting.append(stats["time_combination"] - ml - ag)
+                    model_load.append(ml)
+                    aggregation.append(ag)
+            except Exception:
+                pass
+
+        labels = ["Waiting for client updates", "Aggregation", "Loading model updates from disk"]
+        val = [numpy.mean(waiting), numpy.mean(aggregation), numpy.mean(model_load)]
+        fig = go.Figure()
+
+        fig.update_layout(template="simple_white", title="Mean combiner round time: {}".format(numpy.mean(combination)), showlegend=True)
+        if not combination:
+            return False
+        fig.add_trace(go.Pie(labels=labels, values=val))
+        combiner_plot = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+        return combiner_plot
+
+    def fetch_valid_metrics(self):
+        """
+
+        :return:
+        """
+        metrics = self.status.find_one({"type": "MODEL_VALIDATION"})
+        valid_metrics = self._scalar_metrics(metrics)
+        return valid_metrics
+
+    def create_box_plot(self, metric):
+        """
+
+        :param metric:
+        :return:
+        """
+        metrics = self.status.find_one({"type": "MODEL_VALIDATION"})
+        if metrics is None:
+            fig = go.Figure(data=[])
+            fig.update_layout(title_text="No data currently available for metric distribution over " "participants")
+            box = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+            return box
+
+        valid_metrics = self._scalar_metrics(metrics)
+        if valid_metrics == []:
+            fig = go.Figure(data=[])
+            fig.update_layout(title_text="No scalar metrics found")
+            box = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+            return box
+
+        validations = {}
+        for post in self.status.find({"type": "MODEL_VALIDATION"}):
+            e = json.loads(post["data"])
+            try:
+                validations[e["modelId"]].append(float(json.loads(e["data"])[metric]))
+            except KeyError:
+                validations[e["modelId"]] = [float(json.loads(e["data"])[metric])]
+
+        # Make sure validations are plotted in chronological order
+        model_trail = self.mdb.control.model.find_one({"key": "model_trail"})
+        model_trail_ids = model_trail["model"]
+        validations_sorted = []
+        for model_id in model_trail_ids:
+            try:
+                validations_sorted.append(validations[model_id])
+            except Exception:
+                pass
+
+        validations = validations_sorted
+
+        box = go.Figure()
+
+        y = []
+        for j, acc in enumerate(validations):
+            # x.append(j)
+            y.append(numpy.mean([float(i) for i in acc]))
+            if len(acc) >= 2:
+                box.add_trace(go.Box(y=acc, name=str(j), marker_color="royalblue", showlegend=False, boxpoints=False))
+            else:
+                box.add_trace(go.Scatter(x=[str(j)], y=[y[j]], showlegend=False))
+
+        rounds = list(range(len(y)))
+        box.add_trace(go.Scatter(x=rounds, y=y, name="Mean"))
+
+        box.update_xaxes(title_text="Rounds")
+        box.update_yaxes(tickvals=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
+        box.update_layout(title_text="Metric distribution over clients: {}".format(metric), margin=dict(l=20, r=20, t=45, b=20))
+        box = json.dumps(box, cls=plotly.utils.PlotlyJSONEncoder)
+        return box
+
+    def create_round_plot(self):
+        """
+
+        :return:
+        """
+        trace_data = []
+        metrics = self.round_time.find_one({"key": "round_time"})
+        if metrics is None:
+            fig = go.Figure(data=[])
+            fig.update_layout(title_text="No data currently available for round time")
+            return False
+
+        for post in self.round_time.find({"key": "round_time"}):
+            rounds = post["round"]
+            traces_data = post["round_time"]
+
+        trace_data.append(go.Scatter(x=rounds, y=traces_data, mode="lines+markers", name="Reducer"))
+
+        for rec in self.combiner_round_time.find({"key": "combiner_round_time"}):
+            c_traces_data = rec["round_time"]
+
+        trace_data.append(go.Scatter(x=rounds, y=c_traces_data, mode="lines+markers", name="Combiner"))
+
+        fig = go.Figure(data=trace_data)
+        fig.update_xaxes(title_text="Round")
+        fig.update_yaxes(title_text="Time (s)")
+        fig.update_layout(title_text="Round time")
+        round_t = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+        return round_t
+
+    def create_cpu_plot(self):
+        """
+
+        :return:
+        """
+        metrics = self.psutil_usage.find_one({"key": "cpu_mem_usage"})
+        if metrics is None:
+            fig = go.Figure(data=[])
+            fig.update_layout(title_text="No data currently available for MEM and CPU usage")
+            cpu = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+            return False
+
+        for post in self.psutil_usage.find({"key": "cpu_mem_usage"}):
+            cpu = post["cpu"]
+            mem = post["mem"]
+            ps_time = post["time"]
+            round = post["round"]
+
+        # Create figure with secondary y-axis
+        fig = make_subplots(specs=[[{"secondary_y": True}]])
+        fig.add_trace(go.Scatter(x=ps_time, y=cpu, mode="lines+markers", name="CPU (%)"))
+
+        fig.add_trace(go.Scatter(x=ps_time, y=mem, mode="lines+markers", name="MEM (%)"))
+
+        fig.add_trace(
+            go.Scatter(
+                x=ps_time,
+                y=round,
+                mode="lines+markers",
+                name="Round",
+            ),
+            secondary_y=True,
+        )
+
+        fig.update_xaxes(title_text="Date Time")
+        fig.update_yaxes(title_text="Percentage (%)")
+        fig.update_yaxes(title_text="Round", secondary_y=True)
+        fig.update_layout(title_text="CPU loads and memory usage")
+        cpu = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
+        return cpu
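Each create_* method in plots.py returns a Plotly figure serialized with PlotlyJSONEncoder, or False when the statestore holds no matching records. A hedged usage sketch; `statestore` is assumed to be an already-configured statestore object exposing get_config(), as the constructor requires:

    from fedn.utils.plots import Plot

    plots = Plot(statestore)  # `statestore` must provide get_config()
    table_json = plots.create_table_plot()
    if table_json:  # False when no MODEL_VALIDATION records exist yet
        with open("mean_metrics.json", "w") as f:
            f.write(table_json)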
diff --git a/fedn/fedn/utils/process.py b/fedn/utils/process.py
similarity index 92%
rename from fedn/fedn/utils/process.py
rename to fedn/utils/process.py
index e4ac62ccb..99caad2c8 100644
--- a/fedn/fedn/utils/process.py
+++ b/fedn/utils/process.py
@@ -15,6 +15,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import os
 import subprocess
 import sys
@@ -69,7 +70,9 @@ def _exec_cmd(
     """A convenience wrapper of `subprocess.Popen` for running a command from a Python script.

     Args:
+    ----
     cmd: The command to run, as a string or a list of strings.
+    cwd: The current working directory.
     throw_on_error: If True, raises an Exception if the exit code of the program is nonzero.
     extra_env: Extra environment variables to be defined when running the child process.
         If this argument is specified, `kwargs` cannot contain `env`.
@@ -84,6 +87,7 @@ def _exec_cmd(
     kwargs: Keyword arguments (except `text`) passed to `subprocess.Popen`.

     Returns:
+    -------
     If synchronous is True, return a `subprocess.CompletedProcess` instance,
     otherwise return a Popen instance.
@@ -97,9 +101,7 @@ def _exec_cmd(
         raise ValueError("`extra_env` and `env` cannot be used at the same time")

     if capture_output and stream_output:
-        raise ValueError(
-            "`capture_output=True` and `stream_output=True` cannot be specified at the same time"
-        )
+        raise ValueError("`capture_output=True` and `stream_output=True` cannot be specified at the same time")

     env = env if extra_env is None else {**os.environ, **extra_env}

@@ -111,9 +113,7 @@ def _exec_cmd(

     if capture_output or stream_output:
         if kwargs.get("stdout") is not None or kwargs.get("stderr") is not None:
-            raise ValueError(
-                "stdout and stderr arguments may not be used with capture_output or stream_output"
-            )
+            raise ValueError("stdout and stderr arguments may not be used with capture_output or stream_output")
         kwargs["stdout"] = subprocess.PIPE
         if capture_output:
             kwargs["stderr"] = subprocess.PIPE
@@ -152,7 +152,7 @@ def _exec_cmd(

 @tracer.start_as_current_span("run_process")
 def run_process(args, cwd):
-    """ Run a process and log the output.
+    """Run a process and log the output.

     :param args: The arguments to the process.
     :type args: list
@@ -160,11 +160,10 @@ def run_process(args, cwd):
     :type cwd: str
     :return:
     """
-    status = subprocess.Popen(
-        args, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    status = subprocess.Popen(args, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

     def check_io():
-        """ Check stdout/stderr of the child process.
+        """Check stdout/stderr of the child process.

         :return:
         """
diff --git a/fedn/fedn/utils/tests/test_helpers.py b/fedn/utils/tests/test_helpers.py
similarity index 100%
rename from fedn/fedn/utils/tests/test_helpers.py
rename to fedn/utils/tests/test_helpers.py
diff --git a/fedn/utils/tests/test_parameters.py b/fedn/utils/tests/test_parameters.py
new file mode 100644
index 000000000..a30855f90
--- /dev/null
+++ b/fedn/utils/tests/test_parameters.py
@@ -0,0 +1,69 @@
+import unittest
+
+from fedn.common.exceptions import InvalidParameterError
+from fedn.utils.parameters import Parameters
+
+
+class TestParameters(unittest.TestCase):
+
+    def test_parameters_invalidkey(self):
+
+        parameters = {
+            'serverop': 'adam',
+            'learning_rate': 1e-3,
+        }
+        param = Parameters(parameters)
+
+        parameter_schema = {
+            'serveropt': str,
+            'learning_rate': float,
+        }
+
+        self.assertRaises(InvalidParameterError, param.validate, parameter_schema)
+
+    def test_parameters_valid(self):
+
+        parameters = {
+            'serveropt': 'adam',
+            'learning_rate': 1e-3,
+            'beta1': 0.9,
+            'beta2': 0.99,
+            'tau': 1e-4,
+        }
+
+        param = Parameters(parameters)
+
+        parameter_schema = {
+            'serveropt': str,
+            'learning_rate': float,
+            'beta1': float,
+            'beta2': float,
+            'tau': float,
+        }
+
+        self.assertTrue(param.validate(parameter_schema))
+
+    def test_parameters_invalid(self):
+
+        parameters = {
+            'serveropt': 'adam',
+            'learning_rate': 1e-3,
+            'beta1': 0.9,
+            'beta2': 0.99,
+            'tau': 1e-4,
+        }
+        param = Parameters(parameters)
+
+        parameter_schema = {
+            'serveropt': str,
+            'learning_rate': float,
+            'beta1': float,
+            'beta2': str,
+            'tau': float,
+        }
+
+        self.assertRaises(InvalidParameterError, param.validate, parameter_schema)
+
+
+if __name__ == '__main__':
+    unittest.main()
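The new suite can be run on its own with `python -m unittest fedn.utils.tests.test_parameters -v` (or via the repository's normal test runner); the three cases above cover the unknown-key, all-valid, and type-mismatch paths of Parameters.validate().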
"Scaleout Federated Learning" +authors = [{ name = "Scaleout Systems AB", email = "contact@scaleoutsystems.com" }] +readme = "README.rst" +license = {file="LICENSE"} +keywords = [ + "Scaleout", + "FEDn", + "Federated learning", + "FL", + "Machine learning", +] +classifiers = [ + "Natural Language :: English", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] + +requires-python = '>=3.8,<3.12' +dependencies = [ + "requests", + "urllib3>=1.26.4", + "minio", + "grpcio~=1.60.0", + "grpcio-tools~=1.60.0", + "numpy>=1.21.6", + "protobuf~=4.25.2", + "pymongo", + "Flask==3.0.3", + "pyjwt", + "pyopenssl", + "psutil", + "click==8.1.7", + "grpcio-health-checking~=1.60.0", + "pyyaml", + "plotly", + "virtualenv", + "opentelemetry-api", + "opentelemetry-sdk", + "opentelemetry-exporter-otlp" +] + +[project.urls] +homepage = "https://www.scaleoutsystems.com" +documentation = 'https://fedn.readthedocs.io/en/stable/' +repository = 'https://github.com/scaleoutsystems/fedn' + +[project.scripts] +fedn = "fedn.cli.main:main" + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] +include = ["fedn*"] +exclude = ["tests", "tests.*"] + +[tool.ruff] +line-length = 160 +target-version = "py39" + +lint.select = [ + "ANN", # flake8-annotations + "ARG", # flake8-unused-arguments + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "C90", # mccabe + "D", # pydocstyle + "DTZ", # flake8-datetimez + "E", # pycodestyle + "ERA", # eradicate + "F", # Pyflakes + "I", # isort + "N", # pep8-naming + "PD", # pandas-vet + "PGH", # pygrep-hooks + "PLC", # Pylint + "PLE", # Pylint + "PLR", # Pylint + "PLW", # Pylint + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RET", # flake8-return + "S", # flake8-bandit + "SIM", # flake8-simplify + "T20", # flake8-print + "TID", # flake8-tidy-imports + "W", # pycodestyle +] + +exclude = [ + ".venv", + ".mnist-keras", + ".mnist-pytorch", + "fedn_pb2.py", + "fedn_pb2_grpc.py", + ".ci", + "test*" +] + +lint.ignore = [ + "ANN002", # Missing type annotation for *args + "ANN003", # Missing type annotation for **kwargs + "ANN101", # Missing type annotation for self in method + "ANN102", # Missing type annotation for cls in method + "D107", # Missing docstring in __init__ + "D100", # Missing docstring in public module + "D200", # One-line docstring should fit on one line with quotes + "D210", # [*] No whitespaces allowed surrounding docstring text (100+) + "D104", # Missing docstring in public package (17) + "ANN201", # Missing return type annotation for public function (100+) + "ANN001", # Missing type annotation for function argument (100+) + "ANN205", # Missing return type annotation for staticmethod (5) + "RET504", # Unnecessary assignment to `settings` before `return` statement (72) + "ANN204", # Missing return type annotation for special method `__init__` (61) + "D205", # 1 blank line required between summary line and description (100+) + "T201", # `print` found (31) + "SIM401", # Use `result.get("id", "")` instead of an `if` block (72) + "D400", # First line should end with a period (80) + "D415", # First line should end with a period, question mark, or exclamation point (80) + "D101", # Missing docstring in public class (30) + "S113", # Probable use of requests call without timeout (41) + "PLR2004", # Magic value used in comparison, consider replacing `200` with a constant variable + "PLR0913", # Too many 
arguments in function definition (31) + "ANN202", # Missing return type annotation for private function (41) + "D102", # Missing docstring in public method (64) + "SIM108", # Use ternary operator instead of `if`-`else`-block (20) + "RET505", # Unnecessary `else` after `return` statement (20) + "D103", # Missing docstring in public function (17) + "D401", # First line of docstring should be in imperative mood (24) + "N818", # Exception name should be named with an Error suffix (8) + "B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling (11) + "DTZ005", # The use of `datetime.datetime.now()` without `tz` argument is not allowed (18) + "ANN206", # Missing return type annotation for classmethod (1) + "S110", # `try`-`except`-`pass` detected, consider logging the exception (3) + "N803", # Argument name should be lowercase + "N805", # First argument of a method should be named `self` + "SIM118", # Use `key in dict` instead of `key in dict.keys()` + "SIM115", # Use context handler for opening files + "B027", # `StateStoreBase.__init__` is an empty method in an abstract base class, but has no abstract decorator + "ARG002", # Unused method argument: `use_typing` + "B006", # Do not use mutable data structures for argument defaults + "PLR1714", # Consider merging multiple comparisons: `retcheck in ("", " ")`. Use a `set` if the elements are hashable. + "ERA001", # Found commented-out code + "N802", # Function name should be lowercase + "SIM116", # Use a dictionary instead of consecutive `if` statements + "RET503", # Missing explicit `return` at the end of function able to return non-`None` value + "PLR0911", # Too many return statements (11 > 6) + "C901", # function is too complex (11 > 10) + "ARG001", # Unused function argument: + "SIM105", # Use `contextlib.suppress(KeyError)` instead of `try`-`except`-`pass` + "PLR0915", # Too many statements + "B024", # `Config` is an abstract base class, but it has no abstract methods + "RET506", # Unnecessary `else` after `raise` statement + "N804", # First argument of a class method should be named `cls` + "S202", # Uses of `tarfile.extractall()` + "PLR0912", # Too many branches + "SIM211", # Use `not ...` instead of `False if ... else True` + "D404", # First word of the docstring should not be "This" + "PLW0603", # Using the global statement to update ... 
is discouraged + "D105", # Missing docstring in magic method + "PLR1722", # Use `sys.exit()` instead of `exit` + "C408", # Unnecessary `dict` call (rewrite as a literal) + "DTZ007", # The use of `datetime.datetime.strptime()` without %z must be followed by `.replace(tzinfo=)` or `.astimezone()` + "PLW2901", # `for` loop variable `val` overwritten by assignment target + "D419", # Docstring is empty + "C416", # Unnecessary `list` comprehension (rewrite using `list()`) + "SIM102", # Use a single `if` statement instead of nested `if` statements + "PLW1508", # Invalid type for environment variable default; expected `str` or `None` + "B007", # Loop control variable `v` not used within loop body + "N806", # Variable `X_test` in function should be lowercase + + # solved with --fix + "Q000", # [*] Single quotes found but double quotes preferred + "D212", # [*] Multi-line docstring summary should start at the first line + "D213", # [*] Multi-line docstring summary should start at the second line + "D202", # [*] No blank lines allowed after function docstring (found 1) + "D209", # [*] Multi-line docstring closing quotes should be on a separate line + "D204", # [*] 1 blank line required after class docstring + "SIM114", # [*] Combine `if` branches using logical `or` operator + "D208", # [*] Docstring is over-indented + "I001", # [*] Import block is un-sorted or un-formatted + "SIM103", # Return the condition directly + "PLR5501", # [*] Use `elif` instead of `else` then `if`, to reduce indentation + "RET501", # [*] Do not explicitly `return None` in function if it is the only possible return value + "PLW0120", # [*] `else` clause on loop without a `break` statement; remove the `else` and dedent its contents + + # unsafe? + "S104", # Possible binding to all interfaces + + "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes + "S501", # Probable use of `requests` call with `verify=False` disabling SSL certificate checks + "S108", # Probable insecure usage of temporary file or directory: "/tmp/models" + "S603", # `subprocess` call: check for execution of untrusted input +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index b98843069..000000000 --- a/setup.cfg +++ /dev/null @@ -1,5 +0,0 @@ -[flake8] -max-line-length = 160 - -[pep8] -max-line-length = 160 \ No newline at end of file
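With setup.cfg removed, the flake8/pep8 settings it carried now live in the `[tool.ruff]` table above. Locally, `ruff check .` lints with this configuration (line-length 160, the select/ignore lists shown), and `ruff check . --fix` applies the rules grouped under "solved with --fix" automatically.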