diff --git a/.vscode/settings.json b/.vscode/settings.json index 07cfc57ae..d4c2ea8ad 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,7 +1,7 @@ { "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" }, "python.linting.enabled": true, "python.linting.flake8Enabled": true, diff --git a/README.rst b/README.rst index 1e3c42cb6..03c851a6a 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,3 @@ -.. figure:: https://thumb.tildacdn.com/tild6637-3937-4565-b861-386330386132/-/resize/560x/-/format/webp/FEDn_logo.png - :alt: FEDn logo .. image:: https://github.com/scaleoutsystems/fedn/actions/workflows/integration-tests.yaml/badge.svg :target: https://github.com/scaleoutsystems/fedn/actions/workflows/integration-tests.yaml @@ -10,73 +8,77 @@ .. image:: https://readthedocs.org/projects/fedn/badge/?version=latest&style=flat :target: https://fedn.readthedocs.io -FEDn is a modular and model agnostic framework for -federated machine learning. FEDn is designed to scale from pseudo-distributed -development on your laptop to real-world production setups in geographically distributed environments. +FEDn +-------- + +FEDn empowers developers, researchers, and data scientists to create federated learning applications that seamlessly transition from local proofs-of-concept to real-world distributed deployments. Develop your federated learning use case in a pseudo-local environment, and deploy it to FEDn Studio for real-world Federated Learning without any need for code changes. Core Features ============= -- **Scalable and resilient.** FEDn is scalable and resilient via a tiered - architecture where multiple aggregation servers (combiners) divide up the work to coordinate clients and aggregate models. - Benchmarks show high performance both for thousands of clients in a cross-device - setting and for large model updates in a cross-silo setting. 
- FEDn has the ability to recover from failure in all critical components. - -- **Security**. FEDn is built using secure industry standard communication protocols (gRPC). A key feature is that - clients do not have to expose any ingress ports. +- **Scalable and resilient.** FEDn facilitates the coordination of clients and model aggregation through multiple aggregation servers sharing the workload. This design makes the framework highly scalable, accommodating large numbers of clients. The system is engineered to seamlessly recover from failures, ensuring robust deployment in production environments. Furthermore, FEDn adeptly manages asynchronous federated learning scenarios, accommodating clients that may connect or drop out during training. -- **Track events and training progress in real-time**. FEDn tracks events for clients and aggregation servers, logging to MongoDB. This - helps developers monitor traning progress in real-time, and to troubleshoot the distributed computation. - Tracking and model validation data can easily be retrieved using the API enabling development of custom dashboards and visualizations. +- **Security**. FL clients do not need to open any ingress ports, facilitating real-world deployments across a wide variety of settings. Additionally, FEDn utilizes secure, industry-standard communication protocols and supports token-based authentication for FL clients, enhancing security and ease of integration in diverse environments. -- **Flexible handling of asynchronous clients**. FEDn supports flexible experimentation - with clients coming in and dropping out during training sessions. Extend aggregators to experiment - with different strategies to handle so called stragglers. +- **Track events and training progress in real-time**. Extensive event logging and distributed tracing enable developers to monitor experiments in real-time, simplifying troubleshooting and auditing processes. 
Machine learning validation metrics from clients can be accessed via the API, allowing for flexible analysis of federated experiments. -- **ML-framework agnostic**. Model updates are treated as black-box - computations. This means that it is possible to support any - ML model type or framework. Support for Keras and PyTorch is +- **ML-framework agnostic**. FEDn is compatible with all major ML frameworks. Examples for Keras, PyTorch and scikit-learn are available out-of-the-box. +From development to real-world FL: + +- Develop a FEDn project in a local development environment, and then deploy it to FEDn Studio +- The FEDn server-side as a managed, production-grade service on Kubernetes. +- Token-based authentication for FL clients +- Role-based access control (RBAC) +- REST API +- Dashboard for orchestrating runs, visualizing and downloading results +- Admin dashboard for managing and scaling the FEDn network +- Collaborate with other data-scientists in a shared workspace. +- Cloud or on-premise deployment -Getting started + +Getting started with FEDn =============== The best way to get started is to take the quickstart tutorial: -- `Quickstart PyTorch `__ +- `Quickstart `__ Documentation ============= -You will find more details about the architecture, compute package and how to deploy FEDn fully distributed in the documentation: + +More details about the architecture, deployment, and how to develop your own application and framework extensions (such as custom aggregators) are found in the documentation: - `Documentation `__ -- `Paper `__ -FEDn Studio +Deploying a project to FEDn Studio =============== -Scaleout also develops FEDn Studio, a web application that extends the FEDn SDK with a UI, production-grade deployment of the FEDn server side on Kubernetes, user authentication/authorization, client identity/API-token management, and project-based multitenancy for segmenting work and resources into collaboration workspaces. 
FEDn Studio is available as a fully managed service. -There is also additional tooling and charts for self-managed deployment on Kubernetes including integration with several projects from the cloud native landscape. -See `FEDn Framework `__ for more information. +Studio offers a production-grade deployment of the FEDn server-side infrastructure on Kubernetes. With Studio, you can also manage token-based authentication for clients and collaborate with other users in joint project workspaces. In addition to a REST API, Studio features intuitive dashboards that allow you to orchestrate FL experiments and visualize and manage global models, event logs and metrics. These features enhance your ability to monitor and analyze federated learning projects. Studio is available as a service hosted by Scaleout and one project is provided for free for testing and research. -Making contributions -==================== +- `Register for a project in Studio `__ +- `Deploy your project to FEDn Studio `__ -All pull requests will be considered and are much appreciated. Reach out -to one of the maintainers if you are interested in making contributions, -and we will help you find a good first issue to get you started. For -more details please refer to our `contribution -guidelines `__. +Options and charts are also available for self-managed deployment of FEDn Studio; reach out to the Scaleout team for more information. -Community support + +Support ================= Community support in available in our `Discord server `__. +Options are also available for `Enterprise support `__. + +Making contributions +==================== + +All pull requests will be considered and are much appreciated. For +more details please refer to our `contribution +guidelines `__. 
+ Citation ======== @@ -91,10 +93,6 @@ If you use FEDn in your research, please cite: year={2021} } -Organizational collaborators, contributors and supporters -========================================================= - -|FEDn logo| |UU logo| |AI Sweden logo| |Zenseact logo| |Scania logo| License ======= @@ -102,13 +100,4 @@ License FEDn is licensed under Apache-2.0 (see `LICENSE `__ file for full information). -.. |FEDn logo| image:: https://github.com/scaleoutsystems/fedn/raw/master/docs/img/logos/Scaleout.png - :width: 15% -.. |UU logo| image:: https://github.com/scaleoutsystems/fedn/raw/master/docs/img/logos/UU.png - :width: 15% -.. |AI Sweden logo| image:: https://github.com/scaleoutsystems/fedn/raw/master/docs/img/logos/ai-sweden-logo.png - :width: 15% -.. |Zenseact logo| image:: https://github.com/scaleoutsystems/fedn/raw/master/docs/img/logos/zenseact-logo.png - :width: 15% -.. |Scania logo| image:: https://github.com/scaleoutsystems/fedn/raw/master/docs/img/logos/Scania.png - :width: 15% +Use of FEDn Studio (SaaS) is subject to the `Terms of Use `__. diff --git a/docs/auth.rst b/docs/auth.rst new file mode 100644 index 000000000..1866bd4f1 --- /dev/null +++ b/docs/auth.rst @@ -0,0 +1,90 @@ +.. _auth-label: + +Authentication and Authorization (RBAC) +============================================= +.. warning:: The FEDn RBAC system is an experimental feature and may change in the future. + +FEDn supports Role-Based Access Control (RBAC) for controlling access to the FEDn API and gRPC endpoints. The RBAC system is based on JSON Web Tokens (JWT) and is implemented using the `jwt` package. The JWT tokens are used to authenticate users and to control access to the FEDn API. +There are two types of JWT tokens used in the FEDn RBAC system: +- Access tokens: Used to authenticate users and to control access to the FEDn API. +- Refresh tokens: Used to obtain new access tokens when the old ones expire. + +.. 
note:: Please note that the FEDn RBAC system is not enabled by default and does not issue JWT tokens. It is used to integrate with external authentication and authorization systems such as FEDn Studio. + +The FEDn RBAC system is by default configured with four types of roles: +- `admin`: Has full access to the FEDn API. This role is used to manage the FEDn network using the API client or the FEDn CLI. +- `combiner`: Has access to the /add_combiner endpoint in the API. +- `client`: Has access to the /add_client endpoint in the API and various gRPC endpoints to participate in federated learning sessions. + +A full list of the "roles to endpoint" mappings for gRPC can be found in `fedn/network/grpc/auth.py`. For the API, the mappings are defined using custom decorators defined in `fedn/network/api/auth.py`. + +.. note:: The roles are handled by a custom claim in the JWT token called `role`. The claim is used to control access to the FEDn API and gRPC endpoints. + +To enable the FEDn RBAC system, you need to set the following environment variables in the controller and combiner: + +.. envvar:: FEDN_JWT_SECRET_KEY + :type: str + :required: yes + :default: None + :description: The secret key used for JWT token signing. + +.. envvar:: FEDN_JWT_ALGORITHM + :type: str + :required: no + :default: "HS256" + :description: The algorithm used for JWT token signing. + +.. envvar:: FEDN_AUTH_SCHEME + :type: str + :required: no + :default: "Token" + :description: The authentication scheme used in the FEDn API and gRPC interceptors. + +For further flexibility, you can also set the following environment variables: + +.. envvar:: FEDN_CUSTOM_URL_PREFIX + :type: str + :required: no + :default: None + :description: Add a custom URL prefix used in the FEDn API, such as /internal or /v1. + +.. envvar:: FEDN_AUTH_WHITELIST_URL + :type: str + :required: no + :default: None + :description: A URL pattern in the API that should be excluded from the FEDn RBAC system. 
For example /internal (to enable internal API calls). + +.. envvar:: FEDN_JWT_CUSTOM_CLAIM_KEY + :type: str + :required: no + :default: None + :description: The custom claim key used in the JWT token. + +.. envvar:: FEDN_JWT_CUSTOM_CLAIM_VALUE + :type: str + :required: no + :default: None + :description: The custom claim value used in the JWT token. + + +For the client you need to set the following environment variables: + +.. envvar:: FEDN_JWT_ACCESS_TOKEN + :type: str + :required: yes + :default: None + :description: The access token used to authenticate the client to the FEDn API. + +.. envvar:: FEDN_JWT_REFRESH_TOKEN + :type: str + :required: no + :default: None + :description: The refresh token used to obtain new access tokens when the old ones expire. + +.. envvar:: FEDN_AUTH_SCHEME + :type: str + :required: no + :default: "Token" + :description: The authentication scheme used in the FEDn API and gRPC interceptors. + +You can also use the `--token` flag in the FEDn CLI to set the access token. \ No newline at end of file diff --git a/docs/distributed.rst b/docs/distributed.rst new file mode 100644 index 000000000..179369eb2 --- /dev/null +++ b/docs/distributed.rst @@ -0,0 +1,70 @@ +Distributed Deployment +=================================== + +This tutorial outlines the steps for deploying the FEDn framework over a **local network**, using a workstation as +the host and different devices as clients. For general steps on how to run FEDn, see one of the quickstart tutorials. + + +.. note:: + For a secure and production-grade deployment solution over **public networks**, explore the FEDn Studio service at + **studio.scaleoutsystems.com**. + + Alternatively, follow this tutorial substituting the host's local IP with your public IP, open the necessary + ports (see which ports are used in docker-compose.yaml), and ensure you have taken additional necessary security + precautions. 
+ +Prerequisites +------------- +- `One host workstation and at least one client device` +- `Python 3.8, 3.9 or 3.10 `__ +- `Docker `__ +- `Docker Compose `__ + +Launch a distributed FEDn Network +--------------------------------- + + +Start by noting your host's local IP address, used within your network. Discover it by running ifconfig on UNIX or +ipconfig on Windows, typically listed under inet for Unix and IPv4 for Windows. + +Continue with following the standard procedure to initiate a FEDn network, for example using docker-compose. +Once the network is active, upload your compute package and seed (for comprehensive details, see the quickstart tutorials). + + +Configuring and Attaching Clients +--------------------------------- + +On your client device, continue with initializing your client. To connect to the host machine we need to ensure we are +routing the correct DNS to our host's local IP address. We can do this using the standard FEDn `client.yaml`: + +.. code-block:: + + network_id: fedn-network + discover_host: api-server + discover_port: 8092 + + +We can then run using docker by adding the hosts in the docker run command: + +.. code-block:: + + docker run \ + -v $PWD/client.yaml: \ + + --add-host=api-server: \ + --add-host=combiner: \ + run client -in client.yaml --name client1 + + +Alternatively, update the `/etc/hosts` file, appending the following lines for running natively: + +.. code-block:: + + api-server + combiner + + +Start a training session +------------------------ + +After connecting with your clients, you are ready to start training sessions from the host machine. \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 3f09dbda5..098b37a2e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,11 +10,13 @@ :caption: Documentation studio + distributed apiclient tutorial architecture aggregators helpers + auth .. 
toctree:: :maxdepth: 2 diff --git a/examples/async-simulation/.gitignore b/examples/async-clients/.gitignore similarity index 90% rename from examples/async-simulation/.gitignore rename to examples/async-clients/.gitignore index 4ab9fa59f..a3e7562db 100644 --- a/examples/async-simulation/.gitignore +++ b/examples/async-clients/.gitignore @@ -2,5 +2,6 @@ data *.npz *.tgz *.tar.gz +*.log .async-simulation client.yaml \ No newline at end of file diff --git a/examples/async-clients/Experiment.ipynb b/examples/async-clients/Experiment.ipynb new file mode 100644 index 000000000..1035eb3e4 --- /dev/null +++ b/examples/async-clients/Experiment.ipynb @@ -0,0 +1,387 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "622f7047", + "metadata": {}, + "source": [ + "## Cross-device FL with FEDn Part I: A local development environment for intermittent and asynchronous clients \n", + "\n", + "In this example we set up a local development environment for experimenting with cross-device use-cases. We will here use FEDn in pseudo-local mode and simulate a fleet of intermittent and asynchronous clients solving a classification problem using incremental learning.\n", + "\n", + "A key feature of this research sandbox is that while experiments are able to run on a single laptop or workstation - the same code will seamlessly transition to distributed and real-world deployments (this we will study in future posts in this series)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "1a2686dd", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from sklearn.datasets import make_classification\n", + "from sklearn.neural_network import MLPClassifier\n", + "from sklearn.metrics import accuracy_score\n", + "from sklearn.model_selection import train_test_split\n", + "from client.entrypoint import compile_model, load_parameters, make_data \n", + "\n", + "\n", + "from fedn import APIClient\n", + "import uuid\n", + "import json\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import copy\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "markdown", + "id": "e4ab4a64", + "metadata": {}, + "source": [ + "### The ML model\n", + "\n", + "As a centralized model baseline we generate synthetic data for a classification problem with 4 features. We train a MLPClassifier using ReLU activation and Adam as optimizer, on 80k training points, then test on 20k points. A maximum of 1000 epochs is used for centralized training. " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "70c5f5c9", + "metadata": {}, + "outputs": [], + "source": [ + "X, y = make_classification(n_samples=100000, n_features=4, n_informative=4, n_redundant=0, random_state=42)\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)" + ] + }, + { + "cell_type": "markdown", + "id": "1a121b39", + "metadata": {}, + "source": [ + "We train a centralized baseline model for a maximum of 1000 epochs. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a985c6b3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training accuracy: 0.9208875\n", + "Test accuracy: 0.9193\n" + ] + } + ], + "source": [ + "clf = MLPClassifier(max_iter=1000)\n", + "clf.fit(X_train, y_train)\n", + "central_test_acc = accuracy_score(y_test, clf.predict(X_test))\n", + "\n", + "print(\"Training accuracy: \", accuracy_score(y_train, clf.predict(X_train)))\n", + "print(\"Test accuracy: \", accuracy_score(y_test, clf.predict(X_test)))" + ] + }, + { + "cell_type": "markdown", + "id": "b4976986", + "metadata": {}, + "source": [ + "Next we simulate the FL training procedure each individual FL client will follow. The client will in each iteration (=simulated global round) draw a random number of data points in the interval (n_min, n_max) from (X_train, y_train) and perform 'n_epochs' partial fits on the sampled dataset. Then for each global round we test on the centralized test set (X_test, y_test). In this experiment we simulate 600 global rounds. The client performs 10 local epochs in each round. " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "34ce6b7d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "clf = compile_model()\n", + "\n", + "n_global_rounds=600\n", + "n_epochs = 10\n", + "central_acc_one_client = []\n", + "for i in range(n_global_rounds):\n", + " x,y,_,_ = make_data(n_min=10,n_max=100)\n", + " for j in range(n_epochs):\n", + " clf.partial_fit(x, y)\n", + " central_acc_one_client.append(accuracy_score(y_test, clf.predict(X_test)))\n", + "\n", + "plt.plot(range(n_global_rounds),[central_test_acc]*n_global_rounds)\n", + "plt.plot(range(n_global_rounds), central_acc_one_client)" + ] + }, + { + "cell_type": "markdown", + "id": "1dfb237f", + "metadata": {}, + "source": [ + "We proceed by simulating the scenario that a number 'n_clients' clients in a fleet of devices send their locally collected/sampled datasets to a central server (we emulate this by scaling n_min and n_max by n_clients). The server then performs incremental learning using the collected data batches (which are thus larger than in the experiment above by a factor n_clients). " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "79f02df8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "clf = compile_model()\n", + "\n", + "n_global_rounds=600\n", + "n_epochs = 10\n", + "n_clients = 10\n", + "central_acc_all_clients = []\n", + "for i in range(n_global_rounds):\n", + " x,y,_,_ = make_data(n_min=n_clients*10, n_max=n_clients*100)\n", + " for j in range(n_epochs):\n", + " clf.partial_fit(x, y)\n", + " central_acc_all_clients.append(accuracy_score(y_test, clf.predict(X_test)))\n", + "\n", + "plt.plot(range(n_global_rounds),[central_test_acc]*n_global_rounds)\n", + "plt.plot(range(n_global_rounds), central_acc_one_client)\n", + "plt.plot(range(n_global_rounds), central_acc_all_clients)\n", + "plt.legend(['Central baseline, all data','Incremental learning, one client','Inceremental learning, all clients'])" + ] + }, + { + "cell_type": "markdown", + "id": "06c05caf", + "metadata": {}, + "source": [ + "### Federated learning with clients connecting and disconnecting intermittently \n", + "\n", + "The figure below illustrates the federated learning scenario. We will run clients that: \n", + "\n", + "- Connect to the server at a random time t_{on}\n", + "- Stay online for training for a fixed period of time (e.g one minute).\n", + "- Disconnect at time t_{off}. \n", + "\n", + "This completes one cycle in our setup, which we then repeat a configurable number of cycles. " + ] + }, + { + "cell_type": "markdown", + "id": "dfadf3e3", + "metadata": {}, + "source": [ + "![title](img/async-clients.png)" + ] + }, + { + "cell_type": "markdown", + "id": "037cae62", + "metadata": {}, + "source": [ + "### Running the experiment\n", + "\n", + "Now we run federated learning experiments over a FEDn network. For this we first need to start a pseudo-distributed FEDn network (we can use the provided docker-compose template). To run clients that follow the logic in the illustration above will use the script 'run_clients.py'. 
This script will start clients running in subprocesses on the host machine. Once clients are up and running, you can proceed below and execute experiments using the script 'run_experiment.py'. Note that these runs can take some time to complete (600 global rounds for 10 clients took about 1.5h on a 2020 MacBook Pro). " + ] + }, + { + "cell_type": "markdown", + "id": "1046a4e5", + "metadata": {}, + "source": [ + "### Analyzing the results using the FEDn APIClient\n", + "\n", + "Once the experiment has been started we can use the FEDn APIClient to analyze the result. We make a client connection to the FEDn API service. Here we assume that FEDn is deployed locally in pseudo-distributed mode with default ports." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "1061722d", + "metadata": {}, + "outputs": [], + "source": [ + "DISCOVER_HOST = '127.0.0.1'\n", + "DISCOVER_PORT = 8092\n", + "client = APIClient(DISCOVER_HOST, DISCOVER_PORT)" + ] + }, + { + "cell_type": "markdown", + "id": "29552af9", + "metadata": {}, + "source": [ + "Next, we retrieve global models for this session and score the models on the central test set. We can use the API client to download model parameters for a given global model_id. We then initialize a model and set the loaded parameters. " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5fdabea3", + "metadata": {}, + "outputs": [], + "source": [ + "def load_fedn_model(model_id):\n", + "\n", + " data = client.download_model(model_id, 'temp.npz')\n", + " parameters = load_parameters('temp.npz')\n", + " model = compile_model()\n", + " n = len(parameters)//2\n", + " model.coefs_ = parameters[:n]\n", + " model.intercepts_ = parameters[n:]\n", + " return model" + ] + }, + { + "cell_type": "markdown", + "id": "296e04ee", + "metadata": {}, + "source": [ + "By default, 'get_model_trail' returns a list of all models leading up to the current active global model. 
We iterate over all these models and score them on the centralized test set. " + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "13b75b1c", + "metadata": {}, + "outputs": [], + "source": [ + "model_trail_fedavg = client.get_model_trail()\n", + "\n", + "acc_fedavg = []\n", + "for model in model_trail_fedavg: \n", + " model = load_fedn_model(model['id'])\n", + " acc_fedavg.append(accuracy_score(y_test, model.predict(X_test)))" + ] + }, + { + "cell_type": "markdown", + "id": "40db4542", + "metadata": {}, + "source": [ + "Plot the result. " + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "f0c3c51c", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x = range(1,len(acc_fedavg)+1)\n", + "plt.plot(x,[central_test_acc]*len(x))\n", + "plt.plot(range(n_global_rounds), central_acc_one_client)\n", + "plt.plot(range(n_global_rounds), central_acc_all_clients)\n", + "plt.plot(range(len(acc_fedavg)),acc_fedavg)\n", + "plt.xlabel('Global round')\n", + "plt.ylabel('Accuracy score')\n", + "plt.legend(['Centralized baseline', 'Incremental learning, one client','Incremental learning, all clients', 'FL (10 clients, FedAdam)'])" + ] + }, + { + "cell_type": "markdown", + "id": "11241c81", + "metadata": {}, + "source": [ + "As can be seen, FEDn trains a federated model that reaches the same level of performance as the centralized baseline, with convergence close to the simulated case where 10 clients send data to a central server. Here we used FedAdam with a fixed learning rate 1e-2 as the server-side aggregator. It is possible that hyperparameter tuning, or adapting the learning rate, could improve convergence further. This was not the focus of this experiment though - the objective was to set up a pseudo-local experiment environment where clients connect intermittently and validate the robustness of FEDn in this scenario. In future parts of this series we will build on this in different ways as we explore various aspects of cross-device FL. 
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "vscode": { + "interpreter": { + "hash": "21345b455230dd04cf84c108e7c182ecfe8d1aa1242b8b64881a6d2c0a5951ac" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/async-clients/README.md b/examples/async-clients/README.md new file mode 100644 index 000000000..2e2724104 --- /dev/null +++ b/examples/async-clients/README.md @@ -0,0 +1,77 @@ +# ASYNC CLIENTS +This example shows how to experiment with intermittent and asynchronous client workflows. + +## Prerequisites +- [Python 3.8, 3.9 or 3.10](https://www.python.org/downloads) +- [Docker](https://docs.docker.com/get-docker) +- [Docker Compose](https://docs.docker.com/compose/install) + +## Running the example (pseudo-distributed, single host) + +First, make sure that FEDn is installed (we recommend using a virtual environment) + +Clone FEDn +```sh +git clone https://github.com/scaleoutsystems/fedn.git +``` + +Install FEDn and dependencies + +``` +pip install fedn +``` + +Or from source, standing in the folder 'fedn/fedn' + +``` +pip install . +``` + +### Prepare the example environment, the compute package and seed model + +Standing in the folder fedn/examples/async-clients +``` +pip install -r requirements.txt +``` + +Create the compute package and seed model: +``` +tar -czvf package.tgz client +``` + +``` +python client/entrypoint.py init_seed +``` + +You will now have a file 'seed.npz' in the directory. + +### Running a simulation + +Deploy FEDn on localhost. 
Standing in the FEDn root directory: + +``` +docker-compose up +``` + +Initialize FEDn with the compute package and seed model + +``` +python init_fedn.py +``` + +Start simulating clients +``` +python run_clients.py +``` + +Start the experiment / training sessions: + +``` +python run_experiment.py +``` + +Once global models start being produced, you can start analyzing results using the API Client; refer to the notebook "Experiment.ipynb" for instructions. + + + + diff --git a/examples/async-clients/client/entrypoint.py b/examples/async-clients/client/entrypoint.py new file mode 100644 index 000000000..4ddddd956 --- /dev/null +++ b/examples/async-clients/client/entrypoint.py @@ -0,0 +1,142 @@ +# /bin/python +import fire +import numpy as np +from sklearn.datasets import make_classification +from sklearn.metrics import accuracy_score +from sklearn.model_selection import train_test_split +from sklearn.neural_network import MLPClassifier + +from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics + +HELPER_MODULE = 'numpyhelper' +ARRAY_SIZE = 10000 + + +def compile_model(max_iter=1): + clf = MLPClassifier(max_iter=max_iter) + # This is needed to initialize some state variables needed to make predictions + # We will overwrite weights and biases during FL training + X_train, y_train, _, _ = make_data() + clf.fit(X_train, y_train) + return clf + + +def save_parameters(model, out_path): + """ Save model to disk. + + :param model: The model to save. + :type model: torch.nn.Module + :param out_path: The path to save to. + :type out_path: str + """ + helper = get_helper(HELPER_MODULE) + parameters = model.coefs_ + model.intercepts_ + + helper.save(parameters, out_path) + + +def load_parameters(model_path): + """ Load model from disk. + + param model_path: The path to load from. + :type model_path: str + :return: The loaded model. 
+ :rtype: torch.nn.Module + """ + helper = get_helper(HELPER_MODULE) + parameters = helper.load(model_path) + + return parameters + + +def init_seed(out_path='seed.npz'): + """ Initialize seed model. + + :param out_path: The path to save the seed model to. + :type out_path: str + """ + # Init and save + model = compile_model() + save_parameters(model, out_path) + + +def make_data(n_min=50, n_max=100): + """ Generate / simulate a random number n data points. + + n will fall in the interval (n_min, n_max) + + """ + n_samples = 100000 + X, y = make_classification(n_samples=n_samples, n_features=4, n_informative=4, n_redundant=0, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + n = np.random.randint(n_min, n_max, 1)[0] + ind = np.random.choice(len(X_train), n) + X_train = X_train[ind, :] + y_train = y_train[ind] + return X_train, y_train, X_test, y_test + + +def train(in_model_path, out_model_path): + """ Train model. + + """ + + # Load model + parameters = load_parameters(in_model_path) + model = compile_model() + n = len(parameters)//2 + model.coefs_ = parameters[:n] + model.intercepts_ = parameters[n:] + + # Train + X_train, y_train, _, _ = make_data() + epochs = 10 + for i in range(epochs): + model.partial_fit(X_train, y_train) + + # Metadata needed for aggregation server side + metadata = { + 'num_examples': len(X_train), + } + + # Save JSON metadata file + save_metadata(metadata, out_model_path) + + # Save model update + save_parameters(model, out_model_path) + + +def validate(in_model_path, out_json_path): + """ Validate model. + + :param in_model_path: The path to the input model. + :type in_model_path: str + :param out_json_path: The path to save the output JSON to. + :type out_json_path: str + :param data_path: The path to the data file. 
+ :type data_path: str + """ + parameters = load_parameters(in_model_path) + model = compile_model() + n = len(parameters)//2 + model.coefs_ = parameters[:n] + model.intercepts_ = parameters[n:] + + X_train, y_train, X_test, y_test = make_data() + + # JSON schema + report = { + "accuracy_score": accuracy_score(y_test, model.predict(X_test)), + } + + # Save JSON + save_metrics(report, out_json_path) + + +if __name__ == '__main__': + fire.Fire({ + 'init_seed': init_seed, + 'train': train, + 'validate': validate + }) diff --git a/examples/async-clients/client/fedn.yaml b/examples/async-clients/client/fedn.yaml new file mode 100644 index 000000000..09002ea0c --- /dev/null +++ b/examples/async-clients/client/fedn.yaml @@ -0,0 +1,5 @@ +entry_points: + train: + command: python entrypoint.py train $ENTRYPOINT_OPTS + validate: + command: python entrypoint.py validate $ENTRYPOINT_OPTS \ No newline at end of file diff --git a/examples/async-clients/img/async-clients.png b/examples/async-clients/img/async-clients.png new file mode 100644 index 000000000..71ae8c123 Binary files /dev/null and b/examples/async-clients/img/async-clients.png differ diff --git a/examples/async-simulation/init_fedn.py b/examples/async-clients/init_fedn.py similarity index 57% rename from examples/async-simulation/init_fedn.py rename to examples/async-clients/init_fedn.py index 23078fcd9..2aa298602 100644 --- a/examples/async-simulation/init_fedn.py +++ b/examples/async-clients/init_fedn.py @@ -4,5 +4,5 @@ DISCOVER_PORT = 8092 client = APIClient(DISCOVER_HOST, DISCOVER_PORT) -client.set_package('package.tgz', 'numpyhelper') -client.set_initial_model('seed.npz') +client.set_active_package('package.tgz', 'numpyhelper') +client.set_active_model('seed.npz') diff --git a/examples/async-clients/requirements.txt b/examples/async-clients/requirements.txt new file mode 100644 index 000000000..7529e3699 --- /dev/null +++ b/examples/async-clients/requirements.txt @@ -0,0 +1,3 @@ +fire==0.3.1 +numpy +scikit-learn 
\ No newline at end of file diff --git a/examples/async-clients/run_clients.py b/examples/async-clients/run_clients.py new file mode 100644 index 000000000..2293c6be2 --- /dev/null +++ b/examples/async-clients/run_clients.py @@ -0,0 +1,77 @@ +"""This script starts N_CLIENTS using the SDK. + + + + + +If you are running with a local deploy of FEDn +using docker compose, you need to make sure that clients +are able to resolve the name "combiner" to 127.0.0.1 + +One way to accomplish this is to edit your /etc/hosts, +adding the line: + +127.0.0.1 combiner + +(this requires root privileges) +""" + +import copy +import time +from multiprocessing import Process + +import numpy as np + +from fedn.network.clients.client import Client + +settings = { + 'DISCOVER_HOST': '127.0.0.1', + 'DISCOVER_PORT': 8092, + 'N_CLIENTS': 10, + 'N_CYCLES': 100, + 'CLIENTS_MAX_DELAY': 10, + 'CLIENTS_ONLINE_FOR_SECONDS': 120 +} + +client_config = {'discover_host': settings['DISCOVER_HOST'], 'discover_port': settings['DISCOVER_PORT'], 'token': None, 'name': 'testclient', + 'client_id': 1, 'remote_compute_context': True, 'force_ssl': False, 'dry_run': False, 'secure': False, + 'preshared_cert': False, 'verify': False, 'preferred_combiner': False, + 'validator': True, 'trainer': True, 'init': None, 'logfile': 'test.log', 'heartbeat_interval': 2, + 'reconnect_after_missed_heartbeat': 30} + + +def run_client(online_for=120, name='client'): + """ Simulates a client that starts and stops + at random intervals. + + The client will start after a random delay of less than CLIENTS_MAX_DELAY seconds, + stay online for 'online_for' seconds (deterministic), + then disconnect. + + This is repeated for N_CYCLES. 
+ + """ + + conf = copy.deepcopy(client_config) + conf['name'] = name + + for i in range(settings['N_CYCLES']): + # Sample a delay until the client starts + t_start = np.random.randint(0, settings['CLIENTS_MAX_DELAY']) + time.sleep(t_start) + fl_client = Client(conf) + time.sleep(online_for) + fl_client.disconnect() + + +if __name__ == '__main__': + + # We start N_CLIENTS independent client processes + processes = [] + for i in range(settings['N_CLIENTS']): + p = Process(target=run_client, args=(settings['CLIENTS_ONLINE_FOR_SECONDS'], 'client{}'.format(i),)) + processes.append(p) + p.start() + + for p in processes: + p.join() diff --git a/examples/async-clients/run_experiment.py b/examples/async-clients/run_experiment.py new file mode 100644 index 000000000..d8d12dca2 --- /dev/null +++ b/examples/async-clients/run_experiment.py @@ -0,0 +1,34 @@ +import time +import uuid + +from fedn import APIClient + +DISCOVER_HOST = '127.0.0.1' +DISCOVER_PORT = 8092 +client = APIClient(DISCOVER_HOST, DISCOVER_PORT) + +if __name__ == '__main__': + + # Run six sessions, each with 100 rounds. 
+ num_sessions = 6 + for s in range(num_sessions): + + session_config = { + "helper": "numpyhelper", + "id": str(uuid.uuid4()), + "aggregator": "fedopt", + "round_timeout": 20, + "rounds": 100, + "validate": False, + } + + session = client.start_session(**session_config) + if session['success'] is False: + print(session['message']) + exit(0) + + print("Started session: {}".format(session)) + + # Wait for session to finish + while not client.session_is_finished(session_config['id']): + time.sleep(2) diff --git a/examples/async-simulation/Experiment.ipynb b/examples/async-simulation/Experiment.ipynb deleted file mode 100644 index 51ec4e9e7..000000000 --- a/examples/async-simulation/Experiment.ipynb +++ /dev/null @@ -1,186 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "622f7047", - "metadata": {}, - "source": [ - "## FEDn API Example\n", - "\n", - "This notebook provides an example of how to use the FEDn API to organize experiments and to analyze validation results. We will here run one training session using FedAvg and one session using FedAdam and compare the results.\n", - "\n", - "When you start this tutorial you should have a deployed FEDn Network up and running, and you should have created the compute package and the initial model, see the README for instructions." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "743dfe47", - "metadata": {}, - "outputs": [], - "source": [ - "from fedn import APIClient\n", - "from fedn.network.clients.client import Client\n", - "import uuid\n", - "import json\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import collections\n", - "import copy" - ] - }, - { - "cell_type": "markdown", - "id": "1046a4e5", - "metadata": {}, - "source": [ - "We make a client connection to the FEDn API service. Here we assume that FEDn is deployed locally in pseudo-distributed mode with default ports." 
- ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "1061722d", - "metadata": {}, - "outputs": [], - "source": [ - "DISCOVER_HOST = '127.0.0.1'\n", - "DISCOVER_PORT = 8092\n", - "client = APIClient(DISCOVER_HOST, DISCOVER_PORT)" - ] - }, - { - "cell_type": "markdown", - "id": "07f69f5f", - "metadata": {}, - "source": [ - "Initialize FEDn with the compute package and seed model. Note that these files needs to be created separately by follwing instructions in the README." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "5107f6f9", - "metadata": {}, - "outputs": [], - "source": [ - "client.set_package('package.tgz', 'numpyhelper')\n", - "client.set_initial_model('seed.npz')\n", - "seed_model = client.get_initial_model()" - ] - }, - { - "cell_type": "markdown", - "id": "4e26c50b", - "metadata": {}, - "source": [ - "Next we start a training session using FedAvg:" - ] - }, - { - "cell_type": "code", -<<<<<<< HEAD - "execution_count": 9, -======= - "execution_count": 74, ->>>>>>> master - "id": "f0380d35", - "metadata": {}, - "outputs": [], - "source": [ - "session_config_fedavg = {\n", - " \"helper\": \"numpyhelper\",\n", - " \"session_id\": \"experiment_fedavg6\",\n", - " \"aggregator\": \"fedavg\",\n", - " \"model_id\": seed_model['model_id'],\n", - " \"rounds\": 1,\n", - " }\n", - "\n", - "result_fedavg = client.start_session(**session_config_fedavg)" - ] - }, - { - "cell_type": "markdown", - "id": "29552af9", - "metadata": {}, - "source": [ - "Next, we retrive all model validations from all clients, extract the training accuracy metric, and compute its mean value accross all clients" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "11fd17ef", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "models = client.list_models(session_id = \"experiment_fedavg\")\n", - "\n", - "validations = []\n", - "acc = collections.OrderedDict()\n", - "for model in models[\"result\"]:\n", - " model_id = 
model[\"model\"]\n", - " validations = client.list_validations(modelId=model_id)\n", - "\n", - " for _ , validation in validations.items(): \n", - " metrics = json.loads(validation['data'])\n", - " try:\n", - " acc[model_id].append(metrics['training_accuracy'])\n", - " except KeyError: \n", - " acc[model_id] = [metrics['training_accuracy']]\n", - " \n", - "mean_acc_fedavg = []\n", - "for model, data in acc.items():\n", - " mean_acc_fedavg.append(np.mean(data))\n", - "mean_acc_fedavg.reverse()" - ] - }, - { - "cell_type": "markdown", - "id": "40db4542", - "metadata": {}, - "source": [ - "Finally, plot the result." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d064aaf9", - "metadata": {}, - "outputs": [], - "source": [ - "x = range(1,len(mean_acc_fedavg)+1)\n", - "plt.plot(x, mean_acc_fedavg)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - }, - "vscode": { - "interpreter": { - "hash": "21345b455230dd04cf84c108e7c182ecfe8d1aa1242b8b64881a6d2c0a5951ac" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/async-simulation/README.md b/examples/async-simulation/README.md deleted file mode 100644 index b5cbfe2ed..000000000 --- a/examples/async-simulation/README.md +++ /dev/null @@ -1,53 +0,0 @@ -# ASYNC SIMULATION -This example is intended as a test for asynchronous clients. - -## Prerequisites -- [Python 3.8, 3.9 or 3.10](https://www.python.org/downloads) -- [Docker](https://docs.docker.com/get-docker) -- [Docker Compose](https://docs.docker.com/compose/install) - -## Running the example (pseudo-distributed, single host) - -Clone FEDn and locate into this directory. 
-```sh -git clone https://github.com/scaleoutsystems/fedn.git -cd fedn/examples/async-simulation -``` - -### Preparing the environment, the local data, the compute package and seed model - -Install FEDn and dependencies (we recommend using a virtual environment): - -Standing in the folder 'fedn/fedn' - -``` -pip install -e . -``` - -From examples/async-simulation -``` -pip install -r requirements.txt -``` - -Create the compute package and a seed model that you will be asked to upload in the next step. -``` -tar -czvf package.tgz client -``` - -``` -python client/entrypoint init_seed -``` - -### Deploy FEDn and two clients -docker-compose -f ../../docker-compose.yaml -f docker-compose.override.yaml up - -### Initialize the federated model -See 'Experiments.pynb' or 'launch_client.py' to set the package and seed model. - -> **Note**: run with `--scale client=N` to start *N* clients. - -### Run federated training -See 'Experiment.ipynb'. - -## Clean up -You can clean up by running `docker-compose down -v`. diff --git a/examples/async-simulation/client/entrypoint b/examples/async-simulation/client/entrypoint deleted file mode 100644 index dd2216fc0..000000000 --- a/examples/async-simulation/client/entrypoint +++ /dev/null @@ -1,98 +0,0 @@ -# /bin/python -import time - -import fire -import numpy as np - -from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics - -HELPER_MODULE = 'numpyhelper' -ARRAY_SIZE = 1000000 - - -def save_model(weights, out_path): - """ Save model to disk. - - :param model: The model to save. - :type model: torch.nn.Module - :param out_path: The path to save to. - :type out_path: str - """ - helper = get_helper(HELPER_MODULE) - helper.save(weights, out_path) - - -def load_model(model_path): - """ Load model from disk. - - param model_path: The path to load from. - :type model_path: str - :return: The loaded model. 
- :rtype: torch.nn.Module - """ - helper = get_helper(HELPER_MODULE) - weights = helper.load(model_path) - return weights - - -def init_seed(out_path='seed.npz'): - """ Initialize seed model. - - :param out_path: The path to save the seed model to. - :type out_path: str - """ - # Init and save - weights = [np.random.rand(1, ARRAY_SIZE)] - save_model(weights, out_path) - - -def train(in_model_path, out_model_path): - """ Train model. - - """ - - # Load model - weights = load_model(in_model_path) - - # Train - time.sleep(np.random.randint(4, 15)) - - # Metadata needed for aggregation server side - metadata = { - 'num_examples': ARRAY_SIZE, - } - - # Save JSON metadata file - save_metadata(metadata, out_model_path) - - # Save model update - save_model(weights, out_model_path) - - -def validate(in_model_path, out_json_path): - """ Validate model. - - :param in_model_path: The path to the input model. - :type in_model_path: str - :param out_json_path: The path to save the output JSON to. - :type out_json_path: str - :param data_path: The path to the data file. 
- :type data_path: str - """ - weights = load_model(in_model_path) - - # JSON schema - report = { - "mean": np.mean(weights), - } - - # Save JSON - save_metrics(report, out_json_path) - - -if __name__ == '__main__': - fire.Fire({ - 'init_seed': init_seed, - 'train': train, - 'validate': validate - }) diff --git a/examples/async-simulation/client/fedn.yaml b/examples/async-simulation/client/fedn.yaml deleted file mode 100644 index 68cb70cef..000000000 --- a/examples/async-simulation/client/fedn.yaml +++ /dev/null @@ -1,5 +0,0 @@ -entry_points: - train: - command: /venv/bin/python entrypoint train $ENTRYPOINT_OPTS - validate: - command: /venv/bin/python entrypoint validate $ENTRYPOINT_OPTS \ No newline at end of file diff --git a/examples/async-simulation/docker-compose.override.yaml b/examples/async-simulation/docker-compose.override.yaml deleted file mode 100644 index 61034ce69..000000000 --- a/examples/async-simulation/docker-compose.override.yaml +++ /dev/null @@ -1,15 +0,0 @@ -# Compose schema version -version: '3.3' - -# Overriding requirements -services: - client: - build: - args: - REQUIREMENTS: examples/async-simulation/requirements.txt - deploy: - replicas: 2 - volumes: - - ${HOST_REPO_DIR:-.}/fedn:/app/fedn - - ${HOST_REPO_DIR:-.}/examples/async-simulation/data:/var/data - - /var/run/docker.sock:/var/run/docker.sock diff --git a/examples/async-simulation/launch_clients.py b/examples/async-simulation/launch_clients.py deleted file mode 100644 index 6cffbedd3..000000000 --- a/examples/async-simulation/launch_clients.py +++ /dev/null @@ -1,41 +0,0 @@ -"""This scripts starts N_CLIENTS using the SDK. 
- -If you are running with a local deploy of FEDn -using docker compose, you need to make sure that clients -are able to resolve the name "combiner" to 127.0.0.1 - -One way to accomplish this is to edit your /etc/host, -adding the line: - -combiner 127.0.0.1 - -""" - - -import copy -import time - -from fedn.network.clients.client import Client - -DISCOVER_HOST = '127.0.0.1' -DISCOVER_PORT = 8092 -N_CLIENTS = 5 -CLIENTS_AVAILABLE_TIME = 120 - -config = {'discover_host': DISCOVER_HOST, 'discover_port': DISCOVER_PORT, 'token': None, 'name': 'testclient', - 'client_id': 1, 'remote_compute_context': True, 'force_ssl': False, 'dry_run': False, 'secure': False, - 'preshared_cert': False, 'verify': False, 'preferred_combiner': False, - 'validator': True, 'trainer': True, 'init': None, 'logfile': 'test.log', 'heartbeat_interval': 2, - 'reconnect_after_missed_heartbeat': 30} - -# Start up N_CLIENTS clients -clients = [] -for i in range(N_CLIENTS): - config_i = copy.deepcopy(config) - config['name'] = 'client{}'.format(i) - clients.append(Client(config)) - -# Disconnect clients after some time -time.sleep(CLIENTS_AVAILABLE_TIME) -for client in clients: - client.detach() diff --git a/examples/async-simulation/requirements.txt b/examples/async-simulation/requirements.txt deleted file mode 100644 index c6bceff1d..000000000 --- a/examples/async-simulation/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -fire==0.3.1 \ No newline at end of file diff --git a/fedn/fedn/common/config.py b/fedn/fedn/common/config.py index f6c827d0d..0a261c4af 100644 --- a/fedn/fedn/common/config.py +++ b/fedn/fedn/common/config.py @@ -5,6 +5,17 @@ global STATESTORE_CONFIG global MODELSTORAGE_CONFIG +SECRET_KEY = os.environ.get('FEDN_JWT_SECRET_KEY', False) +FEDN_JWT_CUSTOM_CLAIM_KEY = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_KEY', False) +FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_VALUE', False) + +FEDN_AUTH_WHITELIST_URL_PREFIX = os.environ.get('FEDN_AUTH_WHITELIST_URL_PREFIX', 
False) +FEDN_JWT_ALGORITHM = os.environ.get('FEDN_JWT_ALGORITHM', 'HS256') +FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Token') +FEDN_AUTH_REFRESH_TOKEN_URI = os.environ.get('FEDN_AUTH_REFRESH_TOKEN_URI', False) +FEDN_AUTH_REFRESH_TOKEN = os.environ.get('FEDN_AUTH_REFRESH_TOKEN', False) +FEDN_CUSTOM_URL_PREFIX = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') + def get_environment_config(): """ Get the configuration from environment variables. diff --git a/fedn/fedn/network/api/auth.py b/fedn/fedn/network/api/auth.py new file mode 100644 index 000000000..bf43c2f69 --- /dev/null +++ b/fedn/fedn/network/api/auth.py @@ -0,0 +1,70 @@ +from functools import wraps + +import jwt +from flask import jsonify, request + +from fedn.common.config import (FEDN_AUTH_SCHEME, + FEDN_AUTH_WHITELIST_URL_PREFIX, + FEDN_JWT_ALGORITHM, FEDN_JWT_CUSTOM_CLAIM_KEY, + FEDN_JWT_CUSTOM_CLAIM_VALUE, SECRET_KEY) + + +def check_role_claims(payload, role): + if 'role' not in payload: + return False + if payload['role'] != role: + return False + + return True + + +def check_custom_claims(payload): + if FEDN_JWT_CUSTOM_CLAIM_KEY and FEDN_JWT_CUSTOM_CLAIM_VALUE: + if payload.get(FEDN_JWT_CUSTOM_CLAIM_KEY) != FEDN_JWT_CUSTOM_CLAIM_VALUE: # missing claim => reject, not KeyError + return False + return True + + +def if_whitelisted_url_prefix(path): + if FEDN_AUTH_WHITELIST_URL_PREFIX and path.startswith(FEDN_AUTH_WHITELIST_URL_PREFIX): + return True + else: + return False + + +def jwt_auth_required(role=None): + def actual_decorator(func): + if not SECRET_KEY: + return func + + @wraps(func) + def decorated(*args, **kwargs): + if if_whitelisted_url_prefix(request.path): + return func(*args, **kwargs) + token = request.headers.get('Authorization') + if not token: + return jsonify({'message': 'Missing token'}), 401 + # Strip the scheme prefix; a bare scheme yields '' which jwt.decode rejects below + if token.startswith(FEDN_AUTH_SCHEME): + token = token.partition(' ')[2] + else: + return jsonify({'message': + f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}' + }), 401 + try: + payload = 
jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) + if not check_role_claims(payload, role): + return jsonify({'message': 'Invalid token'}), 401 + if not check_custom_claims(payload): + return jsonify({'message': 'Invalid token'}), 401 + + except jwt.ExpiredSignatureError: + return jsonify({'message': 'Token expired'}), 401 + + except jwt.InvalidTokenError: + return jsonify({'message': 'Invalid token'}), 401 + + return func(*args, **kwargs) + + return decorated + return actual_decorator diff --git a/fedn/fedn/network/api/client.py b/fedn/fedn/network/api/client.py index b33c38626..58337137f 100644 --- a/fedn/fedn/network/api/client.py +++ b/fedn/fedn/network/api/client.py @@ -27,7 +27,7 @@ def __init__(self, host, port=None, secure=False, verify=False, token=None, auth # Auth scheme passed as argument overrides environment variable. # "Token" is the default auth scheme. if not auth_scheme: - auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Token") + auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Bearer") # Override potential env variable if token is passed as argument. 
if not token: token = os.environ.get("FEDN_AUTH_TOKEN", False) diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index 597feca20..4e6d1a7e2 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -1,17 +1,13 @@ +import os + from flasgger import Swagger from flask import Flask, jsonify, request from fedn.common.config import (get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config) +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.interface import API -from fedn.network.api.v1.client_routes import bp as client_bp -from fedn.network.api.v1.combiner_routes import bp as combiner_bp -from fedn.network.api.v1.model_routes import bp as model_bp -from fedn.network.api.v1.package_routes import bp as package_bp -from fedn.network.api.v1.round_routes import bp as round_bp -from fedn.network.api.v1.session_routes import bp as session_bp -from fedn.network.api.v1.status_routes import bp as status_bp -from fedn.network.api.v1.validation_routes import bp as validation_bp +from fedn.network.api.v1 import _routes from fedn.network.controller.control import Control from fedn.network.storage.statestore.mongostatestore import MongoStateStore @@ -21,30 +17,31 @@ statestore = MongoStateStore(network_id, statestore_config["mongo_config"]) statestore.set_storage_backend(modelstorage_config) control = Control(statestore=statestore) + +custom_url_prefix = os.environ.get("FEDN_CUSTOM_URL_PREFIX", False) api = API(statestore, control) app = Flask(__name__) -app.register_blueprint(client_bp) -app.register_blueprint(status_bp) -app.register_blueprint(model_bp) -app.register_blueprint(validation_bp) -app.register_blueprint(package_bp) -app.register_blueprint(session_bp) -app.register_blueprint(combiner_bp) -app.register_blueprint(round_bp) +for bp in _routes: + app.register_blueprint(bp) + if custom_url_prefix: + app.register_blueprint(bp, + name=f"{bp.name}_custom", + 
url_prefix=f"{custom_url_prefix}{bp.url_prefix}") template = { - "swagger": "2.0", - "info": { - "title": "FEDn API", - "description": "API for the FEDn network.", - "version": "0.0.1" - } + "swagger": "2.0", + "info": { + "title": "FEDn API", + "description": "API for the FEDn network.", + "version": "0.0.1" + } } swagger = Swagger(app, template=template) @app.route("/get_model_trail", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model_trail(): """Get the model trail for a given session. param: session: The session id to get the model trail for. @@ -55,7 +52,12 @@ def get_model_trail(): return api.get_model_trail() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model_trail", view_func=get_model_trail, methods=["GET"]) + + @app.route("/get_model_ancestors", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model_ancestors(): """Get the ancestors of a model. param: model: The model id to get the ancestors for. @@ -71,7 +73,12 @@ def get_model_ancestors(): return api.get_model_ancestors(model, limit) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model_ancestors", view_func=get_model_ancestors, methods=["GET"]) + + @app.route("/get_model_descendants", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model_descendants(): """Get the ancestors of a model. param: model: The model id to get the child for. @@ -87,7 +94,12 @@ def get_model_descendants(): return api.get_model_descendants(model, limit) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model_descendants", view_func=get_model_descendants, methods=["GET"]) + + @app.route("/list_models", methods=["GET"]) +@jwt_auth_required(role="admin") def list_models(): """Get models from the statestore. 
param: @@ -108,7 +120,12 @@ def list_models(): return api.get_models(session_id, limit, skip, include_active) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_models", view_func=list_models, methods=["GET"]) + + @app.route("/get_model", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model(): """Get a model from the statestore. param: model: The model id to get. @@ -123,7 +140,12 @@ def get_model(): return api.get_model(model) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model", view_func=get_model, methods=["GET"]) + + @app.route("/delete_model_trail", methods=["GET", "POST"]) +@jwt_auth_required(role="admin") def delete_model_trail(): """Delete the model trail for a given session. param: session: The session id to delete the model trail for. @@ -134,7 +156,12 @@ def delete_model_trail(): return jsonify({"message": "Not implemented"}), 501 +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/delete_model_trail", view_func=delete_model_trail, methods=["GET", "POST"]) + + @app.route("/list_clients", methods=["GET"]) +@jwt_auth_required(role="admin") def list_clients(): """Get all clients from the statestore. return: All clients as a json object. @@ -148,7 +175,12 @@ def list_clients(): return api.get_clients(limit, skip, status) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_clients", view_func=list_clients, methods=["GET"]) + + @app.route("/get_active_clients", methods=["GET"]) +@jwt_auth_required(role="admin") def get_active_clients(): """Get all active clients from the statestore. param: combiner_id: The combiner id to get active clients for. 
@@ -165,7 +197,12 @@ def get_active_clients(): return api.get_active_clients(combiner_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_active_clients", view_func=get_active_clients, methods=["GET"]) + + @app.route("/list_combiners", methods=["GET"]) +@jwt_auth_required(role="admin") def list_combiners(): """Get all combiners in the network. return: All combiners as a json object. @@ -178,7 +215,12 @@ def list_combiners(): return api.get_all_combiners(limit, skip) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_combiners", view_func=list_combiners, methods=["GET"]) + + @app.route("/get_combiner", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiner(): """Get a combiner from the statestore. param: combiner_id: The combiner id to get. @@ -195,7 +237,12 @@ def get_combiner(): return api.get_combiner(combiner_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_combiner", view_func=get_combiner, methods=["GET"]) + + @app.route("/list_rounds", methods=["GET"]) +@jwt_auth_required(role="admin") def list_rounds(): """Get all rounds from the statestore. return: All rounds as a json object. @@ -204,7 +251,12 @@ def list_rounds(): return api.get_all_rounds() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_rounds", view_func=list_rounds, methods=["GET"]) + + @app.route("/get_round", methods=["GET"]) +@jwt_auth_required(role="admin") def get_round(): """Get a round from the statestore. param: round_id: The round id to get. @@ -218,7 +270,12 @@ def get_round(): return api.get_round(round_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_round", view_func=get_round, methods=["GET"]) + + @app.route("/start_session", methods=["GET", "POST"]) +@jwt_auth_required(role="admin") def start_session(): """Start a new session. return: The response from control. 
@@ -228,7 +285,12 @@ def start_session(): return api.start_session(**json_data) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/start_session", view_func=start_session, methods=["GET", "POST"]) + + @app.route("/list_sessions", methods=["GET"]) +@jwt_auth_required(role="admin") def list_sessions(): """Get all sessions from the statestore. return: All sessions as a json object. @@ -240,7 +302,12 @@ def list_sessions(): return api.get_all_sessions(limit, skip) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_sessions", view_func=list_sessions, methods=["GET"]) + + @app.route("/get_session", methods=["GET"]) +@jwt_auth_required(role="admin") def get_session(): """Get a session from the statestore. param: session_id: The session id to get. @@ -257,13 +324,23 @@ def get_session(): return api.get_session(session_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_session", view_func=get_session, methods=["GET"]) + + @app.route("/set_active_package", methods=["PUT"]) +@jwt_auth_required(role="admin") def set_active_package(): id = request.args.get("id", None) return api.set_active_compute_package(id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_active_package", view_func=set_active_package, methods=["PUT"]) + + @app.route("/set_package", methods=["POST"]) +@jwt_auth_required(role="admin") def set_package(): """ Set the compute package in the statestore. Usage with curl: @@ -295,7 +372,12 @@ def set_package(): ) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_package", view_func=set_package, methods=["POST"]) + + @app.route("/get_package", methods=["GET"]) +@jwt_auth_required(role="admin") def get_package(): """Get the compute package from the statestore. return: The compute package as a json object. 
@@ -304,7 +386,12 @@ def get_package(): return api.get_compute_package() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_package", view_func=get_package, methods=["GET"]) + + @app.route("/list_compute_packages", methods=["GET"]) +@jwt_auth_required(role="admin") def list_compute_packages(): """Get the compute package from the statestore. return: The compute package as a json object. @@ -320,7 +407,12 @@ def list_compute_packages(): ) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_compute_packages", view_func=list_compute_packages, methods=["GET"]) + + @app.route("/download_package", methods=["GET"]) +@jwt_auth_required(role="client") def download_package(): """Download the compute package. return: The compute package as a json object. @@ -330,13 +422,23 @@ def download_package(): return api.download_compute_package(name) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/download_package", view_func=download_package, methods=["GET"]) + + @app.route("/get_package_checksum", methods=["GET"]) +@jwt_auth_required(role="client") def get_package_checksum(): name = request.args.get("name", None) return api.get_checksum(name) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_package_checksum", view_func=get_package_checksum, methods=["GET"]) + + @app.route("/get_latest_model", methods=["GET"]) +@jwt_auth_required(role="admin") def get_latest_model(): """Get the latest model from the statestore. return: The initial model as a json object. @@ -345,7 +447,12 @@ def get_latest_model(): return api.get_latest_model() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_latest_model", view_func=get_latest_model, methods=["GET"]) + + @app.route("/set_current_model", methods=["PUT"]) +@jwt_auth_required(role="admin") def set_current_model(): """Set the initial model in the statestore and upload to model repository. 
Usage with curl: @@ -364,10 +471,14 @@ def set_current_model(): return api.set_current_model(id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_current_model", view_func=set_current_model, methods=["PUT"]) + # Get initial model endpoint @app.route("/get_initial_model", methods=["GET"]) +@jwt_auth_required(role="admin") def get_initial_model(): """Get the initial model from the statestore. return: The initial model as a json object. @@ -376,7 +487,12 @@ def get_initial_model(): return api.get_initial_model() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_initial_model", view_func=get_initial_model, methods=["GET"]) + + @app.route("/set_initial_model", methods=["POST"]) +@jwt_auth_required(role="admin") def set_initial_model(): """Set the initial model in the statestore and upload to model repository. Usage with curl: @@ -396,7 +512,12 @@ def set_initial_model(): return api.set_initial_model(file) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_initial_model", view_func=set_initial_model, methods=["POST"]) + + @app.route("/get_controller_status", methods=["GET"]) +@jwt_auth_required(role="admin") def get_controller_status(): """Get the status of the controller. return: The status as a json object. @@ -405,7 +526,12 @@ def get_controller_status(): return api.get_controller_status() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_controller_status", view_func=get_controller_status, methods=["GET"]) + + @app.route("/get_client_config", methods=["GET"]) +@jwt_auth_required(role="admin") def get_client_config(): """Get the client configuration. return: The client configuration as a json object. 
@@ -416,7 +542,12 @@ def get_client_config(): return api.get_client_config(checksum) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_client_config", view_func=get_client_config, methods=["GET"]) + + @app.route("/get_events", methods=["GET"]) +@jwt_auth_required(role="admin") def get_events(): """Get the events from the statestore. return: The events as a json object. @@ -428,7 +559,12 @@ def get_events(): return api.get_events(**kwargs) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_events", view_func=get_events, methods=["GET"]) + + @app.route("/list_validations", methods=["GET"]) +@jwt_auth_required(role="admin") def list_validations(): """Get all validations from the statestore. return: All validations as a json object. @@ -439,7 +575,12 @@ def list_validations(): return api.get_all_validations(**kwargs) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_validations", view_func=list_validations, methods=["GET"]) + + @app.route("/add_combiner", methods=["POST"]) +@jwt_auth_required(role="combiner") def add_combiner(): """Add a combiner to the network. return: The response from the statestore. @@ -454,7 +595,12 @@ def add_combiner(): return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/add_combiner", view_func=add_combiner, methods=["POST"]) + + @app.route("/add_client", methods=["POST"]) +@jwt_auth_required(role="client") def add_client(): """Add a client to the network. return: The response from control. @@ -470,7 +616,12 @@ def add_client(): return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/add_client", view_func=add_client, methods=["POST"]) + + @app.route("/list_combiners_data", methods=["POST"]) +@jwt_auth_required(role="admin") def list_combiners_data(): """List data from combiners. return: The response from control.
@@ -489,7 +640,12 @@ def list_combiners_data(): return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_combiners_data", view_func=list_combiners_data, methods=["POST"]) + + @app.route("/get_plot_data", methods=["GET"]) +@jwt_auth_required(role="admin") def get_plot_data(): """Get plot data from the statestore. rtype: json @@ -503,6 +659,9 @@ def get_plot_data(): return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_plot_data", view_func=get_plot_data, methods=["GET"]) + if __name__ == "__main__": config = get_controller_config() port = config["port"] diff --git a/fedn/fedn/network/api/v1/__init__.py b/fedn/fedn/network/api/v1/__init__.py index e69de29bb..bb8d8d33c 100644 --- a/fedn/fedn/network/api/v1/__init__.py +++ b/fedn/fedn/network/api/v1/__init__.py @@ -0,0 +1,10 @@ +from fedn.network.api.v1.client_routes import bp as client_bp +from fedn.network.api.v1.combiner_routes import bp as combiner_bp +from fedn.network.api.v1.model_routes import bp as model_bp +from fedn.network.api.v1.package_routes import bp as package_bp +from fedn.network.api.v1.round_routes import bp as round_bp +from fedn.network.api.v1.session_routes import bp as session_bp +from fedn.network.api.v1.status_routes import bp as status_bp +from fedn.network.api.v1.validation_routes import bp as validation_bp + +_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp] diff --git a/fedn/fedn/network/api/v1/client_routes.py b/fedn/fedn/network/api/v1/client_routes.py index c4215bd54..30322a9b7 100644 --- a/fedn/fedn/network/api/v1/client_routes.py +++ b/fedn/fedn/network/api/v1/client_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.client_store import ClientStore @@ 
-11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_clients(): """Get clients Retrieves a list of clients based on the provided parameters. @@ -127,6 +129,7 @@ def get_clients(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_clients(): """List clients Retrieves a list of clients based on the provided parameters. @@ -213,6 +216,7 @@ def list_clients(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_clients_count(): """Clients count Retrieves the total number of clients based on the provided parameters. @@ -273,6 +277,7 @@ def get_clients_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def clients_count(): """Clients count Retrieves the total number of clients based on the provided parameters. @@ -325,6 +330,7 @@ def clients_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_client(id: str): """Get client Retrieves a client based on the provided id. diff --git a/fedn/fedn/network/api/v1/combiner_routes.py b/fedn/fedn/network/api/v1/combiner_routes.py index ba6bf5dbd..7d1761bee 100644 --- a/fedn/fedn/network/api/v1/combiner_routes.py +++ b/fedn/fedn/network/api/v1/combiner_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.combiner_store import CombinerStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiners(): """Get combiners Retrieves a list of combiners based on the provided parameters. @@ -119,6 +121,7 @@ def get_combiners(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_combiners(): """List combiners Retrieves a list of combiners based on the provided parameters. 
@@ -203,6 +206,7 @@ def list_combiners(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiners_count(): """Combiners count Retrieves the count of combiners based on the provided parameters. @@ -249,6 +253,7 @@ def get_combiners_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def combiners_count(): """Combiners count Retrieves the count of combiners based on the provided parameters. @@ -297,6 +302,7 @@ def combiners_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiner(id: str): """Get combiner Retrieves a combiner based on the provided id. diff --git a/fedn/fedn/network/api/v1/model_routes.py b/fedn/fedn/network/api/v1/model_routes.py index 2db99083c..8e9308408 100644 --- a/fedn/fedn/network/api/v1/model_routes.py +++ b/fedn/fedn/network/api/v1/model_routes.py @@ -3,6 +3,7 @@ import numpy as np from flask import Blueprint, jsonify, request, send_file +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers, mdb, @@ -22,6 +23,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_models(): """Get models Retrieves a list of models based on the provided parameters. @@ -124,6 +126,7 @@ def get_models(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_models(): """List models Retrieves a list of models based on the provided parameters. @@ -210,6 +213,7 @@ def list_models(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_models_count(): """Models count Retrieves the count of models based on the provided parameters. @@ -257,6 +261,7 @@ def get_models_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def models_count(): """Models count Retrieves the count of models based on the provided parameters. 
@@ -308,6 +313,7 @@ def models_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model(id: str): """Get model Retrieves a model based on the provided id. @@ -353,6 +359,7 @@ def get_model(id: str): @bp.route("//descendants", methods=["GET"]) +@jwt_auth_required(role="admin") def get_descendants(id: str): """Get model descendants Retrieves a list of model descendants of the provided model id/model property. @@ -406,6 +413,7 @@ def get_descendants(id: str): @bp.route("//ancestors", methods=["GET"]) +@jwt_auth_required(role="admin") def get_ancestors(id: str): """Get model ancestors Retrieves a list of model ancestors of the provided model id/model property. diff --git a/fedn/fedn/network/api/v1/package_routes.py b/fedn/fedn/network/api/v1/package_routes.py index b62aa96a9..30ac4d51e 100644 --- a/fedn/fedn/network/api/v1/package_routes.py +++ b/fedn/fedn/network/api/v1/package_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb) @@ -12,6 +13,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_packages(): """Get packages Retrieves a list of packages based on the provided parameters. @@ -132,6 +134,7 @@ def get_packages(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_packages(): """List packages Retrieves a list of packages based on the provided parameters. @@ -221,6 +224,7 @@ def list_packages(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_packages_count(): """Package count Retrieves the count of packages based on the provided parameters. 
@@ -281,6 +285,7 @@ def get_packages_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def packages_count(): """Package count Retrieves the count of packages based on the provided parameters. @@ -342,6 +347,7 @@ def packages_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_package(id: str): """Get package Retrieves a package based on the provided id. @@ -388,6 +394,7 @@ def get_package(id: str): @bp.route("/active", methods=["GET"]) +@jwt_auth_required(role="admin") def get_active_package(): """Get active package Retrieves the active package diff --git a/fedn/fedn/network/api/v1/round_routes.py b/fedn/fedn/network/api/v1/round_routes.py index 317c767ee..8890c510a 100644 --- a/fedn/fedn/network/api/v1/round_routes.py +++ b/fedn/fedn/network/api/v1/round_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.round_store import RoundStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_rounds(): """Get rounds Retrieves a list of rounds based on the provided parameters. @@ -107,6 +109,7 @@ def get_rounds(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_rounds(): """List rounds Retrieves a list of rounds based on the provided parameters. @@ -187,6 +190,7 @@ def list_rounds(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_rounds_count(): """Rounds count Retrieves the count of rounds based on the provided parameters. @@ -227,6 +231,7 @@ def get_rounds_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def rounds_count(): """Rounds count Retrieves the count of rounds based on the provided parameters. 
@@ -271,6 +276,7 @@ def rounds_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_round(id: str): """Get round Retrieves a round based on the provided id. diff --git a/fedn/fedn/network/api/v1/session_routes.py b/fedn/fedn/network/api/v1/session_routes.py index 4d3fe493f..99c52d8db 100644 --- a/fedn/fedn/network/api/v1/session_routes.py +++ b/fedn/fedn/network/api/v1/session_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.session_store import SessionStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_sessions(): """Get sessions Retrieves a list of sessions based on the provided parameters. @@ -99,6 +101,7 @@ def get_sessions(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_sessions(): """List sessions Retrieves a list of sessions based on the provided parameters. @@ -178,6 +181,7 @@ def list_sessions(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_sessions_count(): """Sessions count Retrieves the count of sessions based on the provided parameters. @@ -218,6 +222,7 @@ def get_sessions_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def sessions_count(): """Sessions count Retrieves the count of sessions based on the provided parameters. @@ -262,6 +267,7 @@ def sessions_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_session(id: str): """Get session Retrieves a session based on the provided id. 
diff --git a/fedn/fedn/network/api/v1/status_routes.py b/fedn/fedn/network/api/v1/status_routes.py index 562b971db..e78c18533 100644 --- a/fedn/fedn/network/api/v1/status_routes.py +++ b/fedn/fedn/network/api/v1/status_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb) @@ -12,6 +13,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_statuses(): """Get statuses Retrieves a list of statuses based on the provided parameters. @@ -144,6 +146,7 @@ def get_statuses(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_statuses(): """Get statuses Retrieves a list of statuses based on the provided parameters. @@ -246,6 +249,7 @@ def list_statuses(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_statuses_count(): """Statuses count Retrieves the count of statuses based on the provided parameters. @@ -307,6 +311,7 @@ def get_statuses_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def statuses_count(): """Statuses count Retrieves the count of statuses based on the provided parameters. @@ -368,6 +373,7 @@ def statuses_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_status(id: str): """Get status Retrieves a status based on the provided id. 
diff --git a/fedn/fedn/network/api/v1/validation_routes.py b/fedn/fedn/network/api/v1/validation_routes.py index 874154dbc..96fbac55c 100644 --- a/fedn/fedn/network/api/v1/validation_routes.py +++ b/fedn/fedn/network/api/v1/validation_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb) @@ -13,6 +14,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_validations(): """Get validations Retrieves a list of validations based on the provided parameters. @@ -152,6 +154,7 @@ def get_validations(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_validations(): """Get validations Retrieves a list of validations based on the provided parameters. @@ -257,6 +260,7 @@ def list_validations(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_validations_count(): """Validations count Retrieves the count of validations based on the provided parameters. @@ -322,6 +326,7 @@ def get_validations_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def validations_count(): """Validations count Retrieves the count of validations based on the provided parameters. @@ -386,6 +391,7 @@ def validations_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_validation(id: str): """Get validation Retrieves a validation based on the provided id. 
diff --git a/fedn/fedn/network/clients/client.py b/fedn/fedn/network/clients/client.py index a6408d66c..54c06ca51 100644 --- a/fedn/fedn/network/clients/client.py +++ b/fedn/fedn/network/clients/client.py @@ -18,9 +18,11 @@ from cryptography.hazmat.primitives.serialization import Encoding from google.protobuf.json_format import MessageToJson from OpenSSL import SSL +from tenacity import retry, stop_after_attempt import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc +from fedn.common.config import FEDN_AUTH_SCHEME from fedn.common.log_config import (logger, set_log_level_from_string, set_log_stream) from fedn.network.clients.connect import ConnectorClient, Status @@ -39,7 +41,7 @@ def __init__(self, key): self._key = key def __call__(self, context, callback): - callback((('authorization', f'Token {self._key}'),), None) + callback((('authorization', f'{FEDN_AUTH_SCHEME} {self._key}'),), None) class Client: @@ -55,7 +57,7 @@ def __init__(self, config): """Initialize the client.""" self.state = None self.error_state = False - self._attached = False + self._connected = False self._missed_heartbeat = 0 self.config = config @@ -77,8 +79,9 @@ def __init__(self, config): if not match: raise ValueError('Unallowed character in client name. 
Allowed characters: a-z, A-Z, 0-9, _, -.') + # Folder where the client will store downloaded compute package and logs self.name = config['name'] - dirname = time.strftime("%Y%m%d-%H%M%S") + dirname = self.name+"-"+time.strftime("%Y%m%d-%H%M%S") self.run_path = os.path.join(os.getcwd(), dirname) os.mkdir(self.run_path) @@ -88,20 +91,21 @@ def __init__(self, config): self.inbox = queue.Queue() # Attach to the FEDn network (get combiner) - client_config = self._attach() + combiner_config = self.assign() + self.connect(combiner_config) self._initialize_dispatcher(config) - self._initialize_helper(client_config) + self._initialize_helper(combiner_config) if not self.helper: logger.warning("Failed to retrieve helper class settings: {}".format( - client_config)) + combiner_config)) self._subscribe_to_combiner(config) self.state = ClientState.idle - def _assign(self): + def assign(self): """Contacts the controller and asks for combiner assignment. :return: A configuration dictionary containing connection information for combiner. @@ -112,11 +116,12 @@ def _assign(self): while True: status, response = self.connector.assign() if status == Status.TryAgain: - logger.info(response) + logger.warning(response) + logger.info("Assignment request failed. Retrying in 5 seconds.") time.sleep(5) continue if status == Status.Assigned: - client_config = response + combiner_config = response break if status == Status.UnAuthorized: logger.critical(response) @@ -125,10 +130,9 @@ def _assign(self): logger.critical(response) sys.exit("Exiting: UnMatchedConfig") time.sleep(5) - logger.info("Assignment successfully received.") - logger.info("Received combiner configuration: {}".format(client_config)) - return client_config + logger.info("Received combiner configuration: {}".format(combiner_config)) + return combiner_config def _add_grpc_metadata(self, key, value): """Add metadata for gRPC calls. 
@@ -166,32 +170,36 @@ def _get_ssl_certificate(self, domain, port=443): cert = cert.to_cryptography().public_bytes(Encoding.PEM).decode() return cert - def _connect(self, client_config): - """Connect to assigned combiner. + def connect(self, combiner_config): + """Connect to combiner. - :param client_config: A configuration dictionary containing connection information for + :param combiner_config: A configuration dictionary containing connection information for the combiner. - :type client_config: dict + :type combiner_config: dict """ - # TODO use the client_config['certificate'] for setting up secure comms' - host = client_config['host'] + if self._connected: + logger.info("Client is already attached. ") + return None + + # TODO use the combiner_config['certificate'] for setting up secure comms' + host = combiner_config['host'] # Add host to gRPC metadata self._add_grpc_metadata('grpc-server', host) - logger.info("Client using metadata: {}.".format(self.metadata)) - port = client_config['port'] + logger.debug("Client using metadata: {}.".format(self.metadata)) + port = combiner_config['port'] secure = False - if client_config['fqdn'] is not None: - host = client_config['fqdn'] + if combiner_config['fqdn'] is not None: + host = combiner_config['fqdn'] # assuming https if fqdn is used port = 443 logger.info(f"Initiating connection to combiner host at: {host}:{port}") - if client_config['certificate']: + if combiner_config['certificate']: logger.info("Utilizing CA certificate for GRPC channel authentication.") secure = True cert = base64.b64decode( - client_config['certificate']) # .decode('utf-8') + combiner_config['certificate']) # .decode('utf-8') credentials = grpc.ssl_channel_credentials(root_certificates=cert) channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) elif os.getenv("FEDN_GRPC_ROOT_CERT_PATH"): @@ -231,48 +239,31 @@ def _connect(self, client_config): port)) logger.info("Using {} compute package.".format( - 
client_config["package"])) - - def _disconnect(self): - """Disconnect from the combiner.""" - self.channel.close() - - def detach(self): - """Detach from the FEDn network (disconnect from combiner)""" - # Setting _attached to False will make all processing threads return - if not self._attached: - logger.info("Client is not attached.") + combiner_config["package"])) - self._attached = False - # Close gRPC connection to combiner - self._disconnect() + self._connected = True - def _attach(self): - """Attach to the FEDn network (connect to combiner)""" - # Ask controller for a combiner and connect to that combiner. - if self._attached: - logger.info("Client is already attached. ") - return None - - client_config = self._assign() - self._connect(client_config) + def disconnect(self): + """Disconnect from the combiner.""" + if not self._connected: + logger.info("Client is not connected.") - if client_config: - self._attached = True - return client_config + self.channel.close() + self._connected = False + logger.info("Client {} disconnected.".format(self.name)) - def _initialize_helper(self, client_config): + def _initialize_helper(self, combiner_config): """Initialize the helper class for the client. - :param client_config: A configuration dictionary containing connection information for + :param combiner_config: A configuration dictionary containing connection information for | the discovery service (controller) and settings governing e.g. | client-combiner assignment behavior. - :type client_config: dict + :type combiner_config: dict :return: """ - if 'helper_type' in client_config.keys(): - self.helper = get_helper(client_config['helper_type']) + if 'helper_type' in combiner_config.keys(): + self.helper = get_helper(combiner_config['helper_type']) def _subscribe_to_combiner(self, config): """Listen to combiner message stream and start all processing threads. 
@@ -289,11 +280,15 @@ def _subscribe_to_combiner(self, config): # Start listening for combiner training and validation messages threading.Thread( target=self._listen_to_task_stream, daemon=True).start() - self._attached = True + self._connected = True # Start processing the client message inbox threading.Thread(target=self.process_request, daemon=True).start() + @retry(stop=stop_after_attempt(3)) + def untar_package(self, package_runtime): + package_runtime.unpack() + def _initialize_dispatcher(self, config): """ Initialize the dispatcher for the client. @@ -334,7 +329,8 @@ def _initialize_dispatcher(self, config): return if retval: - pr.unpack() + self.untar_package(pr) + # pr.unpack() self.dispatcher = pr.dispatcher(self.run_path) try: @@ -370,21 +366,24 @@ def get_model_from_combiner(self, id, timeout=20): request.sender.name = self.name request.sender.role = fedn.WORKER - for part in self.modelStub.Download(request, metadata=self.metadata): + try: + for part in self.modelStub.Download(request, metadata=self.metadata): - if part.status == fedn.ModelStatus.IN_PROGRESS: - data.write(part.data) + if part.status == fedn.ModelStatus.IN_PROGRESS: + data.write(part.data) - if part.status == fedn.ModelStatus.OK: - return data + if part.status == fedn.ModelStatus.OK: + return data - if part.status == fedn.ModelStatus.FAILED: - return None - - if part.status == fedn.ModelStatus.UNKNOWN: - if time.time() - time_start >= timeout: + if part.status == fedn.ModelStatus.FAILED: return None - continue + + if part.status == fedn.ModelStatus.UNKNOWN: + if time.time() - time_start >= timeout: + return None + continue + except grpc.RpcError as e: + logger.critical(f"GRPC: An error occurred during model download: {e}") return data @@ -409,7 +408,10 @@ def send_model_to_combiner(self, model, id): bt.seek(0, 0) - result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata) + try: + result = self.modelStub.Upload(upload_request_generator(bt, id), 
metadata=self.metadata) + except grpc.RpcError as e: + logger.critical(f"GRPC: An error occurred during model upload: {e}") return result @@ -426,15 +428,15 @@ def _listen_to_task_stream(self): # Add client to metadata self._add_grpc_metadata('client', self.name) - while self._attached: + while self._connected: try: for request in self.combinerStub.TaskStream(r, metadata=self.metadata): if request: logger.debug("Received model update request from combiner: {}.".format(request)) if request.sender.role == fedn.COMBINER: # Process training request - self._send_status("Received model update request.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request, sesssion_id=request.session_id) + self.send_status("Received model update request.", log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request, sesssion_id=request.session_id) logger.info("Received model update request of type {} for model_id {}".format(request.type, request.model_id)) if request.type == fedn.StatusType.MODEL_UPDATE and self.config['trainer']: @@ -448,19 +450,29 @@ def _listen_to_task_stream(self): # Handle gRPC errors status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: - logger.warning("GRPC server unavailable during model update request stream. Retrying.") + logger.warning("GRPC TaskStream: server unavailable during model update request stream. Retrying.") # Retry after a delay time.sleep(5) + if status_code == grpc.StatusCode.UNAUTHENTICATED: + details = e.details() + if details == 'Token expired': + logger.warning("GRPC TaskStream: Token expired. Reconnecting.") + self.disconnect() + + if status_code == grpc.StatusCode.CANCELLED: + # Expected if the client is detached + logger.critical("GRPC TaskStream: Client detached from combiner.
Attempting to reconnect.") + else: # Log the error and continue - logger.error(f"An error occurred during model update request stream: {e}") + logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {e}") except Exception as ex: # Handle other exceptions - logger.error(f"An error occurred during model update request stream: {ex}") + logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {ex}") # Detach if not attached - if not self._attached: + if not self._connected: return def _process_training_request(self, model_id: str, session_id: str = None): @@ -474,7 +486,7 @@ def _process_training_request(self, model_id: str, session_id: str = None): :rtype: tuple """ - self._send_status( + self.send_status( "\t Starting processing of training request for model_id {}".format(model_id), sesssion_id=session_id) self.state = ClientState.training @@ -546,7 +558,7 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session else: cmd = 'validate' - self._send_status( + self.send_status( f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) self.state = ClientState.validating try: @@ -580,7 +592,7 @@ def process_request(self): """Process training and validation tasks.
""" while True: - if not self._attached: + if not self._connected: return try: @@ -609,15 +621,24 @@ def process_request(self): update.timestamp = str(datetime.now()) update.correlation_id = request.correlation_id update.meta = json.dumps(meta) - # TODO: Check responses - _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) - self._send_status("Model update completed.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE, request=update, sesssion_id=request.session_id) + try: + _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) + self.send_status("Model update completed.", log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE, request=update, sesssion_id=request.session_id) + except grpc.RpcError as e: + status_code = e.code() + logger.error("GRPC error, {}.".format( + status_code.name)) + logger.debug(e) + except ValueError as e: + logger.error("GRPC error, RPC channel closed. {}".format(e)) + logger.debug(e) else: - self._send_status("Client {} failed to complete model update.", - log_level=fedn.Status.WARNING, - request=request, sesssion_id=request.session_id) + self.send_status("Client {} failed to complete model update.".format(self.name), + log_level=fedn.Status.WARNING, + request=request, sesssion_id=request.session_id) + self.state = ClientState.idle self.inbox.task_done() @@ -639,27 +660,37 @@ def process_request(self): validation.correlation_id = request.correlation_id validation.session_id = request.session_id - _ = self.combinerStub.SendModelValidation( - validation, metadata=self.metadata) - - status_type = fedn.StatusType.MODEL_VALIDATION - - self._send_status("Model validation completed.", log_level=fedn.Status.AUDIT, - type=status_type, request=validation, sesssion_id=request.session_id) + try: + _ = self.combinerStub.SendModelValidation( + validation, metadata=self.metadata) + + status_type = fedn.StatusType.MODEL_VALIDATION + self.send_status("Model validation completed.", log_level=fedn.Status.AUDIT, +
type=status_type, request=validation, sesssion_id=request.session_id) + except grpc.RpcError as e: + status_code = e.code() + logger.error("GRPC error, {}.".format( + status_code.name)) + logger.debug(e) + except ValueError as e: + logger.error("GRPC error, RPC channel closed. {}".format(e)) + logger.debug(e) else: - self._send_status("Client {} failed to complete model validation.".format(self.name), - log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id) + self.send_status("Client {} failed to complete model validation.".format(self.name), + log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id) self.state = ClientState.idle self.inbox.task_done() except queue.Empty: pass + except grpc.RpcError as e: + logger.critical(f"GRPC process_request: An error occurred during process request: {e}") def _handle_combiner_failure(self): """ Register failed combiner connection.""" self._missed_heartbeat += 1 if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']: - self.detach()() + self.disconnect() def _send_heartbeat(self, update_frequency=2.0): """Send a heartbeat to the combiner. @@ -677,16 +708,22 @@ def _send_heartbeat(self, update_frequency=2.0): self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() - logger.warning("Client heartbeat: GRPC error, {}. Retrying.".format( - status_code.name)) + if status_code == grpc.StatusCode.UNAVAILABLE: + logger.warning("GRPC heartbeat: server unavailable during send heartbeat. Retrying.") + if status_code == grpc.StatusCode.UNAUTHENTICATED: + details = e.details() + if details == 'Token expired': + logger.warning("GRPC heartbeat: Token expired.
Reconnecting.") + self.disconnect() logger.debug(e) self._handle_combiner_failure() time.sleep(update_frequency) - if not self._attached: + if not self._connected: + logger.info("SendHeartbeat: Client disconnected.") return - def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None): + def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None): """Send status message. :param msg: The message to send. @@ -698,6 +735,11 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, :param request: The request message. :type request: fedn.Request """ + + if not self._connected: + logger.info("SendStatus: Client disconnected.") + return + status = fedn.Status() status.timestamp.GetCurrentTime() status.sender.name = self.name @@ -714,7 +756,16 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, self.logs.append( "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) - _ = self.connectorStub.SendStatus(status, metadata=self.metadata) + try: + _ = self.connectorStub.SendStatus(status, metadata=self.metadata) + except grpc.RpcError as e: + status_code = e.code() + if status_code == grpc.StatusCode.UNAVAILABLE: + logger.warning("GRPC SendStatus: server unavailable during send status.") + if status_code == grpc.StatusCode.UNAUTHENTICATED: + details = e.details() + if details == 'Token expired': + logger.warning("GRPC SendStatus: Token expired.") def run(self): """ Run the client. """ @@ -728,10 +779,10 @@ def run(self): cnt = 1 if self.state != old_state: logger.info("Client in {} state.".format(ClientStateToString(self.state))) - if not self._attached: + if not self._connected: logger.info("Detached from combiner.") # TODO: Implement a check/condition to ulitmately close down if too many reattachment attepts have failed.
s - self._attach() + self.attach() self._subscribe_to_combiner(self.config) if self.error_state: return diff --git a/fedn/fedn/network/clients/connect.py b/fedn/fedn/network/clients/connect.py index 478844d26..2e9345ebb 100644 --- a/fedn/fedn/network/clients/connect.py +++ b/fedn/fedn/network/clients/connect.py @@ -8,6 +8,9 @@ import requests +from fedn.common.config import (FEDN_AUTH_REFRESH_TOKEN, + FEDN_AUTH_REFRESH_TOKEN_URI, FEDN_AUTH_SCHEME, + FEDN_CUSTOM_URL_PREFIX) from fedn.common.log_config import logger @@ -77,12 +80,11 @@ def assign(self): try: retval = None payload = {'client_id': self.name, 'preferred_combiner': self.preferred_combiner} - - retval = requests.post(self.connect_string + '/add_client', + retval = requests.post(self.connect_string + FEDN_CUSTOM_URL_PREFIX + '/add_client', json=payload, verify=self.verify, allow_redirects=True, - headers={'Authorization': 'Token {}'.format(self.token)}) + headers={'Authorization': f"{FEDN_AUTH_SCHEME} {self.token}"}) except Exception as e: print('***** {}'.format(e), flush=True) return Status.Unassigned, {} @@ -93,6 +95,16 @@ def assign(self): return Status.UnMatchedConfig, reason if retval.status_code == 401: + if 'message' in retval.json(): + reason = retval.json()['message'] + logger.warning(reason) + if reason == 'Token expired': + status_code = self.refresh_token() + if status_code >= 200 and status_code < 204: + logger.info("Token refreshed.") + return Status.TryAgain, reason + else: + return Status.UnAuthorized, "Could not refresh token" reason = "Unauthorized connection to reducer, make sure the correct token is set" return Status.UnAuthorized, reason @@ -115,3 +127,21 @@ def assign(self): return Status.Assigned, retval.json() return Status.Unassigned, None + + def refresh_token(self): + """ + Refresh client token. + + :return: Tuple with assingment status, combiner connection information if sucessful, else None. 
+ :rtype: tuple(:class:`fedn.network.clients.connect.Status`, str) + """ + if not FEDN_AUTH_REFRESH_TOKEN_URI or not FEDN_AUTH_REFRESH_TOKEN: + logger.error("No refresh token URI/Token set, cannot refresh token.") + return 401 + + payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, + verify=self.verify, + allow_redirects=True, + json={'refresh': FEDN_AUTH_REFRESH_TOKEN}) + self.token = payload.json()['access'] + return payload.status_code diff --git a/fedn/fedn/network/clients/package.py b/fedn/fedn/network/clients/package.py index d56296de8..cc98eae94 100644 --- a/fedn/fedn/network/clients/package.py +++ b/fedn/fedn/network/clients/package.py @@ -9,6 +9,7 @@ import requests import yaml +from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger from fedn.utils.checksum import sha from fedn.utils.dispatcher import Dispatcher @@ -52,13 +53,13 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): else: scheme = "http" if port: - path = f"{scheme}://{host}:{port}/download_package" + path = f"{scheme}://{host}:{port}{FEDN_CUSTOM_URL_PREFIX}/download_package" else: - path = f"{scheme}://{host}/download_package" + path = f"{scheme}://{host}{FEDN_CUSTOM_URL_PREFIX}/download_package" if name: path = path + "?name={}".format(name) - with requests.get(path, stream=True, verify=False, headers={'Authorization': 'Token {}'.format(token)}) as r: + with requests.get(path, stream=True, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: if 200 <= r.status_code < 204: params = cgi.parse_header( @@ -73,13 +74,13 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): for chunk in r.iter_content(chunk_size=8192): f.write(chunk) if port: - path = f"{scheme}://{host}:{port}/get_package_checksum" + path = f"{scheme}://{host}:{port}{FEDN_CUSTOM_URL_PREFIX}/get_package_checksum" else: - path = f"{scheme}://{host}/get_package_checksum" + path = 
f"{scheme}://{host}{FEDN_CUSTOM_URL_PREFIX}/get_package_checksum" if name: path = path + "?name={}".format(name) - with requests.get(path, verify=False, headers={'Authorization': 'Token {}'.format(token)}) as r: + with requests.get(path, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: if 200 <= r.status_code < 204: data = r.json() diff --git a/fedn/fedn/network/combiner/aggregators/fedopt.py b/fedn/fedn/network/combiner/aggregators/fedopt.py index d3152c957..ccabb2789 100644 --- a/fedn/fedn/network/combiner/aggregators/fedopt.py +++ b/fedn/fedn/network/combiner/aggregators/fedopt.py @@ -34,13 +34,13 @@ def __init__(self, storage, server, modelservice, round_handler): self.m = None # Server side hyperparameters. Note that these may need extensive fine tuning. - self.eta = 0.1 + self.eta = 1e-2 self.beta1 = 0.9 self.beta2 = 0.99 self.tau = 1e-4 def combine_models(self, helper=None, delete_models=True): - """Compute pseudo gradients usigng model updates in the queue. + """Compute pseudo gradients using model updates in the queue. 
:param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase`, ML framework specific helper, defaults to None :type helper: class: `fedn.utils.helpers.helpers.HelperBase`, optional @@ -88,7 +88,7 @@ def combine_models(self, helper=None, delete_models=True): pseudo_gradient = helper.increment_average( pseudo_gradient, pseudo_gradient_next, metadata['num_examples'], total_examples) - print("NORM PSEUDOGRADIENT: ", helper.norm(pseudo_gradient), flush=True) + logger.info("NORM PSEUDOGRADIENT: {}".format(helper.norm(pseudo_gradient))) nr_aggregated_models += 1 # Delete model from storage diff --git a/fedn/fedn/network/combiner/connect.py b/fedn/fedn/network/combiner/connect.py index 4c1c94266..7dc388261 100644 --- a/fedn/fedn/network/combiner/connect.py +++ b/fedn/fedn/network/combiner/connect.py @@ -5,6 +5,7 @@ # # import enum +import os import requests @@ -72,10 +73,14 @@ def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, self.myhost = myhost self.myport = myport self.token = token + self.token_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') self.name = name self.secure = secure self.verify = verify + if not self.token: + self.token = os.environ.get('FEDN_AUTH_TOKEN', None) + # for https we assume a an ingress handles permanent redirect (308) self.prefix = "http://" if port: @@ -101,10 +106,11 @@ def announce(self): "port": self.myport, "secure_grpc": self.secure } + url_prefix = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') try: - retval = requests.post(self.connect_string + '/add_combiner', json=payload, + retval = requests.post(self.connect_string + url_prefix + '/add_combiner', json=payload, verify=self.verify, - headers={'Authorization': 'Token {}'.format(self.token)}) + headers={'Authorization': f'{self.token_scheme} {self.token}'}) except Exception: return Status.Unassigned, {} diff --git a/fedn/fedn/network/combiner/roundhandler.py b/fedn/fedn/network/combiner/roundhandler.py index 8d87b2b69..d27b7ffdb 100644 
--- a/fedn/fedn/network/combiner/roundhandler.py +++ b/fedn/fedn/network/combiner/roundhandler.py @@ -254,7 +254,7 @@ def _assign_round_clients(self, n, type="trainers"): return clients - def _check_nr_round_clients(self, config, timeout=0.0): + def _check_nr_round_clients(self, config): """Check that the minimal number of clients required to start a round are available. :param config: The round config object. @@ -265,27 +265,14 @@ def _check_nr_round_clients(self, config, timeout=0.0): :rtype: bool """ - ready = False - t = 0.0 - while not ready: - active = self.server.nr_active_trainers() - - if active >= int(config['clients_requested']): - return True - else: - logger.info("waiting for {} clients to get started, currently: {}".format( - int(config['clients_requested']) - active, - active)) - if t >= timeout: - if active >= int(config['clients_required']): - return True - else: - return False - - time.sleep(1.0) - t += 1.0 - - return ready + active = self.server.nr_active_trainers() + if active >= int(config['clients_required']): + logger.info("Number of clients required ({0}) to start round met {1}.".format( + config['clients_required'], active)) + return True + else: + logger.info("Too few clients to start round.") + return False def execute_validation_round(self, round_config): """ Coordinate validation rounds as specified in config. diff --git a/fedn/fedn/network/controller/controlbase.py b/fedn/fedn/network/controller/controlbase.py index de13ce32a..e825e8e8b 100644 --- a/fedn/fedn/network/controller/controlbase.py +++ b/fedn/fedn/network/controller/controlbase.py @@ -177,7 +177,7 @@ def get_compute_package(self, compute_package=""): else: return None - def create_session(self, config): + def create_session(self, config, status='Initialized'): """ Initialize a new session in backend db. 
""" if "session_id" not in config.keys(): @@ -188,6 +188,7 @@ def create_session(self, config): self.statestore.create_session(id=session_id) self.statestore.set_session_config(session_id, config) + self.statestore.set_session_status(session_id, status) def set_session_status(self, session_id, status): """ Set the round round stats. diff --git a/fedn/fedn/network/grpc/auth.py b/fedn/fedn/network/grpc/auth.py new file mode 100644 index 000000000..d879cd812 --- /dev/null +++ b/fedn/fedn/network/grpc/auth.py @@ -0,0 +1,95 @@ +import grpc +import jwt + +from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_JWT_ALGORITHM, SECRET_KEY +from fedn.common.log_config import logger +from fedn.network.api.auth import check_custom_claims + +ENDPOINT_ROLES_MAPPING = { + '/fedn.Combiner/TaskStream': ['client'], + '/fedn.Combiner/SendModelUpdate': ['client'], + '/fedn.Combiner/SendModelValidation': ['client'], + '/fedn.Connector/SendHeartbeat': ['client'], + '/fedn.Connector/SendStatus': ['client'], + '/fedn.ModelService/Download': ['client'], + '/fedn.ModelService/Upload': ['client'], + '/fedn.Control/Start': ['controller'], + '/fedn.Control/Stop': ['controller'], + '/fedn.Control/FlushAggregationQueue': ['controller'], + '/fedn.Control/SetAggregator': ['controller'], +} + +ENDPOINT_WHITELIST = [ + '/fedn.Connector/AcceptingClients', + '/fedn.Connector/ListActiveClients', + '/fedn.Control/Start', + '/fedn.Control/Stop', + '/fedn.Control/FlushAggregationQueue', + '/fedn.Control/SetAggregator', +] + +USER_AGENT_WHITELIST = [ + 'grpc_health_probe' +] + + +def check_role_claims(payload, endpoint): + user_role = payload.get('role', '') + + # Perform endpoint-specific RBAC check + allowed_roles = ENDPOINT_ROLES_MAPPING.get(endpoint) + if allowed_roles and user_role not in allowed_roles: + return False + return True + + +def _unary_unary_rpc_terminator(code, details): + def terminate(ignored_request, context): + context.abort(code, details) + + return 
grpc.unary_unary_rpc_method_handler(terminate) + + +class JWTInterceptor(grpc.ServerInterceptor): + def __init__(self): + pass + + def intercept_service(self, continuation, handler_call_details): + # Pass if no secret key is set + if not SECRET_KEY: + return continuation(handler_call_details) + metadata = dict(handler_call_details.invocation_metadata) + # Pass whitelisted methods + if handler_call_details.method in ENDPOINT_WHITELIST: + return continuation(handler_call_details) + # Pass if the request comes from whitelisted user agents + user_agent = metadata.get('user-agent').split(' ')[0] + if user_agent in USER_AGENT_WHITELIST: + return continuation(handler_call_details) + + token = metadata.get('authorization') + if token is None: + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token is missing') + + if not token.startswith(FEDN_AUTH_SCHEME): + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}') + + token = token.split(' ')[1] + + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) + + if not check_role_claims(payload, handler_call_details.method): + return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') + + if not check_custom_claims(payload): + return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') + + return continuation(handler_call_details) + except jwt.InvalidTokenError: + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Invalid token') + except jwt.ExpiredSignatureError: + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token expired') + except Exception as e: + logger.error(str(e)) + return _unary_unary_rpc_terminator(grpc.StatusCode.UNKNOWN, str(e)) diff --git a/fedn/fedn/network/grpc/server.py b/fedn/fedn/network/grpc/server.py index 59ed6b1ba..916b4756e 100644 --- a/fedn/fedn/network/grpc/server.py +++ 
b/fedn/fedn/network/grpc/server.py @@ -6,6 +6,7 @@ import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.log_config import (logger, set_log_level_from_string, set_log_stream) +from fedn.network.grpc.auth import JWTInterceptor class Server: @@ -16,7 +17,7 @@ def __init__(self, servicer, modelservicer, config): set_log_level_from_string(config.get('verbosity', "INFO")) set_log_stream(config.get('logfile', None)) - self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=350)) + self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=350), interceptors=[JWTInterceptor()]) self.certificate = None self.health_servicer = health.HealthServicer() diff --git a/fedn/fedn/tests/__init__.py b/fedn/fedn/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/fedn/setup.py b/fedn/setup.py index cd604773d..dd3d9df5e 100644 --- a/fedn/setup.py +++ b/fedn/setup.py @@ -10,30 +10,22 @@ py_modules=['fedn'], python_requires='>=3.8,<3.11', install_requires=[ - "PyYAML>=5.4", "requests", "urllib3>=1.26.4", "minio", - "python-slugify", "grpcio~=1.57.0", "grpcio-tools~=1.57.0", "numpy>=1.21.6", "protobuf", "pymongo", "Flask", - "Flask-WTF", "pyjwt", "pyopenssl", - "ttictoc", "psutil", "click==8.0.1", - "jinja2", - "plotly", - "pandas", - "bokeh<3.0.0", - "networkx", "grpcio-health-checking~=1.57.0", - "flasgger==0.9.5" + "flasgger==0.9.5", + "plotly", ], license='Apache 2.0', zip_safe=False,