diff --git a/docker-compose.yaml b/docker-compose.yaml
index 9fe395c3a..2f05aba8e 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -34,7 +34,7 @@ services:
- 9001:9001
mongo:
- image: mongo:7.0
+ image: mongo:7.0
restart: always
environment:
- MONGO_INITDB_ROOT_USERNAME=fedn_admin
diff --git a/docs/tutorial.rst b/docs/tutorial.rst
index f48d06017..b0ded9777 100644
--- a/docs/tutorial.rst
+++ b/docs/tutorial.rst
@@ -308,7 +308,7 @@ For the compute package we need to compress the *client* folder as .tar.gz file.
.. code-block:: bash
- tar -czvf package.tar.gz client
+ tar -czvf package.tgz client
This file can then be uploaded to the FEDn network using the FEDn UI or the :py:mod:`fedn.network.api.client`.
diff --git a/examples/async-simulation/.gitignore b/examples/async-simulation/.gitignore
new file mode 100644
index 000000000..4ab9fa59f
--- /dev/null
+++ b/examples/async-simulation/.gitignore
@@ -0,0 +1,6 @@
+data
+*.npz
+*.tgz
+*.tar.gz
+.async-simulation
+client.yaml
\ No newline at end of file
diff --git a/examples/async-simulation/Experiment.ipynb b/examples/async-simulation/Experiment.ipynb
new file mode 100644
index 000000000..12a9aee42
--- /dev/null
+++ b/examples/async-simulation/Experiment.ipynb
@@ -0,0 +1,178 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "622f7047",
+ "metadata": {},
+ "source": [
+ "## FEDn API Example\n",
+ "\n",
+ "This notebook provides an example of how to use the FEDn API to organize experiments and to analyze validation results. We will here run one training session using FedAvg and one session using FedAdam and compare the results.\n",
+ "\n",
+ "When you start this tutorial you should have a deployed FEDn Network up and running, and you should have created the compute package and the initial model, see the README for instructions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "743dfe47",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fedn import APIClient\n",
+ "from fedn.dashboard.plots import Plot\n",
+ "from fedn.network.clients.client import Client\n",
+ "import uuid\n",
+ "import json\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import collections\n",
+ "import copy"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1046a4e5",
+ "metadata": {},
+ "source": [
+ "We make a client connection to the FEDn API service. Here we assume that FEDn is deployed locally in pseudo-distributed mode with default ports."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "1061722d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "DISCOVER_HOST = '127.0.0.1'\n",
+ "DISCOVER_PORT = 8092\n",
+ "client = APIClient(DISCOVER_HOST, DISCOVER_PORT)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "07f69f5f",
+ "metadata": {},
+ "source": [
+ "Initialize FEDn with the compute package and seed model. Note that these files need to be created separately by following the instructions in the README."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "5107f6f9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "client.set_package('package.tgz', 'numpyhelper')\n",
+ "client.set_initial_model('seed.npz')\n",
+ "seed_model = client.get_initial_model()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4e26c50b",
+ "metadata": {},
+ "source": [
+ "Next we start a training session using FedAvg:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 70,
+ "id": "f0380d35",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "session_config_fedavg = {\n",
+ " \"helper\": \"numpyhelper\",\n",
+ " \"session_id\": \"experiment_fedavg4\",\n",
+ " \"aggregator\": \"fedavg\",\n",
+ " \"model_id\": seed_model['model_id'],\n",
+ " \"rounds\": 1,\n",
+ " }\n",
+ "\n",
+ "result_fedavg = client.start_session(**session_config_fedavg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "29552af9",
+ "metadata": {},
+ "source": [
+ "Next, we retrieve all model validations from all clients, extract the reported validation metric, and compute its mean value across all clients."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "11fd17ef",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+    "# Use the same session id that was passed to start_session above\n",
+    "models = client.list_models(session_id = \"experiment_fedavg4\")\n",
+    "\n",
+    "acc = collections.OrderedDict()\n",
+    "for model in models[\"result\"]:\n",
+    "    model_id = model[\"model\"]\n",
+    "    validations = client.list_validations(modelId=model_id)\n",
+    "\n",
+    "    for validation in validations.values():\n",
+    "        metrics = json.loads(validation['data'])\n",
+    "        # The async-simulation client reports a single metric, 'mean' (see client/entrypoint)\n",
+    "        acc.setdefault(model_id, []).append(metrics['mean'])\n",
+    "\n",
+    "# Average over clients for each model, then reverse the order of the models\n",
+    "mean_acc_fedavg = [np.mean(data) for data in acc.values()]\n",
+    "mean_acc_fedavg.reverse()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "40db4542",
+ "metadata": {},
+ "source": [
+ "Finally, plot the result."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d064aaf9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = range(1,len(mean_acc_fedavg)+1)\n",
+ "plt.plot(x, mean_acc_fedavg)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/async-simulation/README.md b/examples/async-simulation/README.md
new file mode 100644
index 000000000..b5cbfe2ed
--- /dev/null
+++ b/examples/async-simulation/README.md
@@ -0,0 +1,53 @@
+# ASYNC SIMULATION
+This example is intended as a test for asynchronous clients.
+
+## Prerequisites
+- [Python 3.8, 3.9 or 3.10](https://www.python.org/downloads)
+- [Docker](https://docs.docker.com/get-docker)
+- [Docker Compose](https://docs.docker.com/compose/install)
+
+## Running the example (pseudo-distributed, single host)
+
+Clone FEDn and navigate to this directory.
+```sh
+git clone https://github.com/scaleoutsystems/fedn.git
+cd fedn/examples/async-simulation
+```
+
+### Preparing the environment, the local data, the compute package and seed model
+
+Install FEDn and dependencies (we recommend using a virtual environment):
+
+Standing in the folder 'fedn/fedn'
+
+```
+pip install -e .
+```
+
+From examples/async-simulation
+```
+pip install -r requirements.txt
+```
+
+Create the compute package and a seed model that you will be asked to upload in the next step.
+```
+tar -czvf package.tgz client
+```
+
+```
+python client/entrypoint init_seed
+```
+
+### Deploy FEDn and two clients
+```sh
+docker-compose -f ../../docker-compose.yaml -f docker-compose.override.yaml up
+```
+
+### Initialize the federated model
+See 'Experiment.ipynb' or 'init_fedn.py' to set the package and seed model.
+
+> **Note**: run with `--scale client=N` to start *N* clients.
+
+### Run federated training
+See 'Experiment.ipynb'.
+
+## Clean up
+You can clean up by running `docker-compose down -v`.
diff --git a/examples/async-simulation/client/entrypoint b/examples/async-simulation/client/entrypoint
new file mode 100644
index 000000000..dd2216fc0
--- /dev/null
+++ b/examples/async-simulation/client/entrypoint
@@ -0,0 +1,98 @@
+# /bin/python
+import time
+
+import fire
+import numpy as np
+
+from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics
+
+HELPER_MODULE = 'numpyhelper'
+ARRAY_SIZE = 1000000
+
+
+def save_model(weights, out_path):
+    """ Write model weights to disk using the numpy helper.
+
+    :param weights: The model weights to save (e.g. the list of numpy arrays built by init_seed).
+    :type weights: list
+    :param out_path: The path to save to.
+    :type out_path: str
+    """
+    get_helper(HELPER_MODULE).save(weights, out_path)
+
+
+def load_model(model_path):
+    """ Read model weights from disk using the numpy helper.
+
+    :param model_path: The path to load from.
+    :type model_path: str
+    :return: The weights, as returned by the helper's load method.
+    """
+    return get_helper(HELPER_MODULE).load(model_path)
+
+
+def init_seed(out_path='seed.npz'):
+    """ Create a random seed model and store it on disk.
+
+    :param out_path: The path to save the seed model to.
+    :type out_path: str
+    """
+    # A single random array of shape (1, ARRAY_SIZE) acts as the "model"
+    seed_weights = [np.random.rand(1, ARRAY_SIZE)]
+    save_model(seed_weights, out_path)
+
+
+def train(in_model_path, out_model_path):
+    """ Simulate a training round of random duration.
+
+    The incoming weights are written back unchanged after sleeping for a
+    random number of seconds; no actual training takes place.
+
+    :param in_model_path: The path to the input model.
+    :type in_model_path: str
+    :param out_model_path: The path to save the model update and its metadata to.
+    :type out_model_path: str
+    """
+    current_weights = load_model(in_model_path)
+
+    # Stand-in for real training: block for a random whole number of seconds in [4, 15)
+    simulated_duration = np.random.randint(4, 15)
+    time.sleep(simulated_duration)
+
+    # Save JSON metadata file; 'num_examples' is needed for aggregation server side
+    save_metadata({'num_examples': ARRAY_SIZE}, out_model_path)
+
+    # Save model update
+    save_model(current_weights, out_model_path)
+
+
+def validate(in_model_path, out_json_path):
+    """ Validate a model by reporting the mean of its weights.
+
+    :param in_model_path: The path to the input model.
+    :type in_model_path: str
+    :param out_json_path: The path to save the output JSON to.
+    :type out_json_path: str
+    """
+    model_weights = load_model(in_model_path)
+
+    # The report holds a single metric, "mean", which is saved as JSON
+    save_metrics({"mean": np.mean(model_weights)}, out_json_path)
+
+
+if __name__ == '__main__':
+    # Expose the entry points as sub-commands of the command line interface
+    commands = {
+        'init_seed': init_seed,
+        'train': train,
+        'validate': validate,
+    }
+    fire.Fire(commands)
diff --git a/examples/async-simulation/client/fedn.yaml b/examples/async-simulation/client/fedn.yaml
new file mode 100644
index 000000000..68cb70cef
--- /dev/null
+++ b/examples/async-simulation/client/fedn.yaml
@@ -0,0 +1,5 @@
+entry_points:
+ train:
+ command: /venv/bin/python entrypoint train $ENTRYPOINT_OPTS
+ validate:
+ command: /venv/bin/python entrypoint validate $ENTRYPOINT_OPTS
\ No newline at end of file
diff --git a/examples/async-simulation/init_fedn.py b/examples/async-simulation/init_fedn.py
new file mode 100644
index 000000000..23078fcd9
--- /dev/null
+++ b/examples/async-simulation/init_fedn.py
@@ -0,0 +1,8 @@
+from fedn import APIClient
+
+# Host and port used to reach the FEDn API service
+DISCOVER_HOST = '127.0.0.1'
+DISCOVER_PORT = 8092
+
+# Upload the compute package and the seed model; both files must already exist
+# in the working directory (see the README for how to create them)
+client = APIClient(DISCOVER_HOST, DISCOVER_PORT)
+client.set_package('package.tgz', 'numpyhelper')
+client.set_initial_model('seed.npz')
diff --git a/examples/async-simulation/launch_clients.py b/examples/async-simulation/launch_clients.py
new file mode 100644
index 000000000..6cffbedd3
--- /dev/null
+++ b/examples/async-simulation/launch_clients.py
@@ -0,0 +1,41 @@
+"""This script starts N_CLIENTS clients using the SDK.
+
+If you are running with a local deployment of FEDn
+using docker compose, you need to make sure that clients
+are able to resolve the name "combiner" to 127.0.0.1
+
+One way to accomplish this is to edit your /etc/hosts,
+adding the line:
+
+127.0.0.1 combiner
+
+"""
+
+
+import copy
+import time
+
+from fedn.network.clients.client import Client
+
+DISCOVER_HOST = '127.0.0.1'
+DISCOVER_PORT = 8092
+N_CLIENTS = 5
+# Number of seconds the clients stay attached before they are detached again
+CLIENTS_AVAILABLE_TIME = 120
+
+# Base configuration; every client gets its own copy with an individual name
+config = {'discover_host': DISCOVER_HOST, 'discover_port': DISCOVER_PORT, 'token': None, 'name': 'testclient',
+          'client_id': 1, 'remote_compute_context': True, 'force_ssl': False, 'dry_run': False, 'secure': False,
+          'preshared_cert': False, 'verify': False, 'preferred_combiner': False,
+          'validator': True, 'trainer': True, 'init': None, 'logfile': 'test.log', 'heartbeat_interval': 2,
+          'reconnect_after_missed_heartbeat': 30}
+
+# Start up N_CLIENTS clients
+clients = []
+for i in range(N_CLIENTS):
+    # Modify and pass the private copy, not the shared base config, so that
+    # clients do not end up sharing one dict (and thereby one name)
+    config_i = copy.deepcopy(config)
+    config_i['name'] = 'client{}'.format(i)
+    clients.append(Client(config_i))
+
+# Disconnect clients after some time
+time.sleep(CLIENTS_AVAILABLE_TIME)
+for client in clients:
+    client.detach()
diff --git a/examples/async-simulation/requirements.txt b/examples/async-simulation/requirements.txt
new file mode 100644
index 000000000..c6bceff1d
--- /dev/null
+++ b/examples/async-simulation/requirements.txt
@@ -0,0 +1 @@
+fire==0.3.1
\ No newline at end of file
diff --git a/examples/mnist-keras/bin/init_venv_macm1.sh b/examples/mnist-keras/bin/init_venv_macm1.sh
new file mode 100755
index 000000000..d60f602d5
--- /dev/null
+++ b/examples/mnist-keras/bin/init_venv_macm1.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# Create the .mnist-keras virtual environment and install FEDn together with
+# the dependencies listed in requirements-macos.txt.
+# Abort on the first failing command.
+set -e
+
+# Init venv
+python3 -m venv .mnist-keras
+
+# Pip deps
+.mnist-keras/bin/pip install --upgrade pip
+.mnist-keras/bin/pip install -e ../../fedn
+.mnist-keras/bin/pip install -r requirements-macos.txt
diff --git a/examples/mnist-keras/requirements-macos.txt b/examples/mnist-keras/requirements-macos.txt
new file mode 100644
index 000000000..4770f97cd
--- /dev/null
+++ b/examples/mnist-keras/requirements-macos.txt
@@ -0,0 +1,4 @@
+tensorflow-macos
+tensorflow-metal
+fire==0.3.1
+docker==5.0.2
diff --git a/fedn/fedn/dashboard/restservice.py b/fedn/fedn/dashboard/restservice.py
index 64814a5de..27ee6f13b 100644
--- a/fedn/fedn/dashboard/restservice.py
+++ b/fedn/fedn/dashboard/restservice.py
@@ -790,7 +790,7 @@ def context():
return redirect(url_for("context"))
file = request.files["file"]
- helper_type = request.form.get("helper", "kerashelper")
+ helper_type = request.form.get("helper", "numpyhelper")
# if user does not select file, browser also
# submit an empty part without filename
if file.filename == "":
diff --git a/fedn/fedn/dashboard/templates/context.html b/fedn/fedn/dashboard/templates/context.html
index 8f392082a..9ad050984 100644
--- a/fedn/fedn/dashboard/templates/context.html
+++ b/fedn/fedn/dashboard/templates/context.html
@@ -12,8 +12,7 @@
{{ message }}
diff --git a/fedn/fedn/network/clients/client.py b/fedn/fedn/network/clients/client.py
index 633b2078a..68eae0845 100644
--- a/fedn/fedn/network/clients/client.py
+++ b/fedn/fedn/network/clients/client.py
@@ -26,6 +26,7 @@
from fedn.network.clients.connect import ConnectorClient, Status
from fedn.network.clients.package import PackageRuntime
from fedn.network.clients.state import ClientState, ClientStateToString
+from fedn.network.combiner.modelservice import upload_request_generator
from fedn.utils.dispatcher import Dispatcher
from fedn.utils.helpers.helpers import get_helper
@@ -236,7 +237,7 @@ def _disconnect(self):
"""Disconnect from the combiner."""
self.channel.close()
- def _detach(self):
+ def detach(self):
"""Detach from the FEDn network (disconnect from combiner)"""
# Setting _attached to False will make all processing threads return
if not self._attached:
@@ -328,7 +329,7 @@ def _initialize_dispatcher(self, config):
if retval:
if 'checksum' not in config:
- logger.warning("Bypassing security validation for local package. Ensure the package source is trusted.")
+ logger.warning("Bypassing validation of package checksum. Ensure the package source is trusted.")
else:
checks_out = pr.validate(config['checksum'])
if not checks_out:
@@ -358,7 +359,7 @@ def _initialize_dispatcher(self, config):
copy_tree(from_path, self.run_path)
self.dispatcher = Dispatcher(dispatch_config, self.run_path)
- def get_model(self, id):
+ def get_model_from_combiner(self, id):
"""Fetch a model from the assigned combiner.
Downloads the model update object via a gRPC streaming channel.
@@ -382,7 +383,7 @@ def get_model(self, id):
return data
- def set_model(self, model, id):
+ def send_model_to_combiner(self, model, id):
"""Send a model update to the assigned combiner.
Uploads the model updated object via a gRPC streaming channel, Upload.
@@ -403,28 +404,7 @@ def set_model(self, model, id):
bt.seek(0, 0)
- def upload_request_generator(mdl):
- """Generator function for model upload requests.
-
- :param mdl: The model update object.
- :type mdl: BytesIO
- :return: A model update request.
- :rtype: fedn.ModelRequest
- """
- while True:
- b = mdl.read(CHUNK_SIZE)
- if b:
- result = fedn.ModelRequest(
- data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS)
- else:
- result = fedn.ModelRequest(
- id=id, status=fedn.ModelStatus.OK)
-
- yield result
- if not b:
- break
-
- result = self.modelStub.Upload(upload_request_generator(bt), metadata=self.metadata)
+ result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata)
return result
@@ -528,7 +508,7 @@ def _process_training_request(self, model_id):
try:
meta = {}
tic = time.time()
- mdl = self.get_model(str(model_id))
+ mdl = self.get_model_from_combiner(str(model_id))
meta['fetch_model'] = time.time() - tic
inpath = self.helper.get_tmp_path()
@@ -549,9 +529,9 @@ def _process_training_request(self, model_id):
with open(outpath, "rb") as fr:
out_model = io.BytesIO(fr.read())
- # Push model update to combiner server
+ # Stream model update to combiner server
updated_model_id = uuid.uuid4()
- self.set_model(out_model, str(updated_model_id))
+ self.send_model_to_combiner(out_model, str(updated_model_id))
meta['upload_model'] = time.time() - tic
# Read the metadata file
@@ -592,7 +572,7 @@ def _process_validation_request(self, model_id, is_inference):
f"Processing {cmd} request for model_id {model_id}")
self.state = ClientState.validating
try:
- model = self.get_model(str(model_id))
+ model = self.get_model_from_combiner(str(model_id))
inpath = self.helper.get_tmp_path()
with open(inpath, "wb") as fh:
@@ -609,7 +589,6 @@ def _process_validation_request(self, model_id, is_inference):
except Exception as e:
logger.warning("Validation failed with exception {}".format(e))
- raise
self.state = ClientState.idle
return None
@@ -700,7 +679,7 @@ def _handle_combiner_failure(self):
""" Register failed combiner connection."""
self._missed_heartbeat += 1
if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']:
- self._detach()
+ self.detach()
def _send_heartbeat(self, update_frequency=2.0):
"""Send a heartbeat to the combiner.
diff --git a/fedn/fedn/network/combiner/aggregators/aggregator.py b/fedn/fedn/network/combiner/aggregators/aggregator.py
new file mode 100644
index 000000000..69e0c4fcf
--- /dev/null
+++ b/fedn/fedn/network/combiner/aggregators/aggregator.py
@@ -0,0 +1,119 @@
+import json
+import queue
+from abc import ABC, abstractmethod
+
+import fedn.common.net.grpc.fedn_pb2 as fedn
+
+
+class Aggregator(ABC):
+    """ Abstract base class that all aggregators derive from. """
+
+    @abstractmethod
+    def __init__(self, id, storage, server, modelservice, control):
+        """ Set up the state shared by all aggregators.
+
+        :param id: A reference to id of :class: `fedn.network.combiner.Combiner`
+        :type id: str
+        :param storage: Model repository for :class: `fedn.network.combiner.Combiner`
+        :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository`
+        :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner`
+        :type server: class: `fedn.network.combiner.Combiner`
+        :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService`
+        :type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
+        :param control: A handle to the :class: `fedn.network.combiner.round.RoundController`
+        :type control: class: `fedn.network.combiner.round.RoundController`
+        """
+        self.id = id
+        self.name = self.__class__.__name__
+        self.storage = storage
+        self.server = server
+        self.modelservice = modelservice
+        self.control = control
+        # Validated model updates wait here until next_model_update fetches them
+        self.model_updates = queue.Queue()
+
+    @abstractmethod
+    def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=None, timeout=180):
+        """Routine for combining model updates. Implemented in subclass.
+
+        :param nr_expected_models: Number of expected models. If None, wait for all models.
+        :type nr_expected_models: int
+        :param nr_required_models: Number of required models to combine.
+        :type nr_required_models: int
+        :param helper: A helper object.
+        :type helper: :class: `fedn.utils.plugins.helperbase.HelperBase`
+        :param timeout: Timeout in seconds to wait for models to be combined.
+        :type timeout: int
+        :return: A combined model.
+        """
+
+    def on_model_update(self, model_update):
+        """Callback when a new client model update is received.
+
+        Validates the update and, if it is valid, puts it on the
+        aggregation queue. Any exception is reported as a warning and
+        not propagated. Override in subclass as needed.
+
+        :param model_update: A ModelUpdate message.
+        :type model_update: object
+        """
+        try:
+            received_msg = "AGGREGATOR({}): callback received model update {}".format(
+                self.name, model_update.model_update_id)
+            self.server.report_status(received_msg, log_level=fedn.Status.INFO)
+
+            # Validate the update and metadata; invalid updates are reported and dropped
+            if not self._validate_model_update(model_update):
+                self.server.report_status("AGGREGATOR({}): Invalid model update, skipping.".format(self.name))
+                return
+
+            # Push the model update to the processing queue
+            self.model_updates.put(model_update)
+        except Exception as e:
+            failure_msg = "AGGREGATOR({}): Failed to receive candidate model! {}".format(self.name, e)
+            self.server.report_status(failure_msg, log_level=fedn.Status.WARNING)
+
+    def on_model_validation(self, model_validation):
+        """ Callback when a new client model validation is received.
+
+        Currently only reports that the validation was processed.
+        Override in subclass as needed.
+
+        :param model_validation: The validation message sent by the client.
+        :type model_validation: object
+        """
+        processed_msg = "AGGREGATOR({}): callback processed validation {}".format(
+            self.name, model_validation.model_id)
+        self.server.report_status(processed_msg, log_level=fedn.Status.INFO)
+
+    def _validate_model_update(self, model_update):
+        """ Validate the model update.
+
+        :param model_update: A ModelUpdate message.
+        :type model_update: object
+        :return: True if the model update is valid, False otherwise.
+        :rtype: bool
+        """
+        # TODO: Validate the metadata to check that it contains all variables assumed by the aggregator.
+        training_metadata = json.loads(model_update.meta)['training_metadata']
+        if 'num_examples' in training_metadata:
+            return True
+
+        self.server.report_status(
+            "AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name))
+        return False
+
+    def next_model_update(self, helper):
+        """ Get the next model update from the queue.
+
+        Does not block; queue.Empty is raised when no update is waiting.
+
+        :param helper: A helper object.
+        :type helper: object
+        :return: A tuple containing the model update, metadata and model id.
+        :rtype: tuple
+        """
+        update = self.model_updates.get(block=False)
+        update_id = update.model_update_id
+        loaded_model = self.control.load_model_update(helper, update_id)
+
+        # Get relevant metadata; the config is stored as a JSON string inside the meta document
+        meta = json.loads(update.meta)
+        metadata = meta['training_metadata']
+        metadata['round_id'] = json.loads(meta['config'])['round_id']
+
+        return loaded_model, metadata, update_id
diff --git a/fedn/fedn/network/combiner/aggregators/fedavg.py b/fedn/fedn/network/combiner/aggregators/fedavg.py
index e8541f326..50d45e641 100644
--- a/fedn/fedn/network/combiner/aggregators/fedavg.py
+++ b/fedn/fedn/network/combiner/aggregators/fedavg.py
@@ -76,7 +76,7 @@ def combine_models(self, helper=None, delete_models=True):
nr_aggregated_models += 1
# Delete model from storage
if delete_models:
- self.modelservice.models.delete(model_update.model_update_id)
+ self.modelservice.temp_model_storage.delete(model_update.model_update_id)
logger.info(
"AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id))
self.model_updates.task_done()
diff --git a/fedn/fedn/network/combiner/aggregators/fedopt.py b/fedn/fedn/network/combiner/aggregators/fedopt.py
index 2298190af..d3152c957 100644
--- a/fedn/fedn/network/combiner/aggregators/fedopt.py
+++ b/fedn/fedn/network/combiner/aggregators/fedopt.py
@@ -93,7 +93,7 @@ def combine_models(self, helper=None, delete_models=True):
nr_aggregated_models += 1
# Delete model from storage
if delete_models:
- self.modelservice.models.delete(model_update.model_update_id)
+ self.modelservice.temp_model_storage.delete(model_update.model_update_id)
logger.info(
"AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id))
self.model_updates.task_done()
diff --git a/fedn/fedn/network/combiner/modelservice.py b/fedn/fedn/network/combiner/modelservice.py
index 78555e61e..909d5936a 100644
--- a/fedn/fedn/network/combiner/modelservice.py
+++ b/fedn/fedn/network/combiner/modelservice.py
@@ -10,13 +10,95 @@
CHUNK_SIZE = 1024 * 1024
+def upload_request_generator(mdl, id):
+    """Generator function for model upload requests.
+
+    Yields one IN_PROGRESS request per chunk of at most CHUNK_SIZE bytes
+    read from mdl, followed by a final request without data and status OK.
+
+    :param mdl: The model update object.
+    :type mdl: BytesIO
+    :param id: The model id set on every request.
+    :type id: str
+    :return: A model update request.
+    :rtype: fedn.ModelRequest
+    """
+    chunk = mdl.read(CHUNK_SIZE)
+    while chunk:
+        yield fedn.ModelRequest(
+            data=chunk, id=id, status=fedn.ModelStatus.IN_PROGRESS)
+        chunk = mdl.read(CHUNK_SIZE)
+
+    # An empty read means the stream is exhausted: close with a data-less OK request
+    yield fedn.ModelRequest(
+        id=id, data=None, status=fedn.ModelStatus.OK)
+
+
+def model_as_bytesIO(model):
+    """ Return the model as a BytesIO object positioned at the start.
+
+    :param model: Either a BytesIO object, which is returned as is after
+        rewinding, or an object with a stream(chunk_size) method yielding
+        bytes, whose chunks are collected into a new BytesIO object.
+    :type model: :class:`io.BytesIO` or streamable object
+    :return: A BytesIO object containing the model, positioned at offset 0.
+    :rtype: :class:`io.BytesIO`
+    """
+    if isinstance(model, BytesIO):
+        bt = model
+    else:
+        bt = BytesIO()
+        # Read the stream in chunks of 32 KiB
+        for chunk in model.stream(32 * 1024):
+            bt.write(chunk)
+
+    bt.seek(0, 0)
+    return bt
+
+
+def get_tmp_path():
+    """ Create an empty temporary file and return its path.
+
+    The path is compatible with save_model, load_model. The file is not
+    removed automatically; the caller is responsible for deleting it.
+    """
+    handle, tmp_path = tempfile.mkstemp()
+    # Only the path is needed, so release the open file descriptor right away
+    os.close(handle)
+    return tmp_path
+
+
+def load_model_from_BytesIO(model_bytesio, helper):
+    """ Load a model from its serialized bytes via a temporary file.
+
+    :param model_bytesio: The serialized model, either as a BytesIO object or
+        as a bytes-like object (e.g. the result of BytesIO.getbuffer()).
+    :type model_bytesio: :class:`io.BytesIO` or bytes-like object
+    :param helper: The helper object for the model.
+    :type helper: :class:`fedn.utils.helperbase.HelperBase`
+    :return: The model object.
+    :rtype: return type of helper.load
+    """
+    # File objects only accept bytes-like data, so unwrap a BytesIO first
+    if isinstance(model_bytesio, BytesIO):
+        model_bytesio = model_bytesio.getvalue()
+
+    path = get_tmp_path()
+    try:
+        with open(path, 'wb') as fh:
+            fh.write(model_bytesio)
+        model = helper.load(path)
+    finally:
+        # Remove the temporary file even if writing or loading fails
+        os.unlink(path)
+    return model
+
+
+def serialize_model_to_BytesIO(model, helper):
+    """ Serialize a model to a BytesIO object.
+
+    The helper writes the model to a file, whose content is read into memory;
+    the file is deleted afterwards.
+
+    :param model: The model object.
+    :type model: return type of helper.load
+    :param helper: The helper object for the model.
+    :type helper: :class:`fedn.utils.helperbase.HelperBase`
+    :return: A BytesIO object containing the model, positioned at offset 0.
+    :rtype: :class:`io.BytesIO`
+    """
+    outfile_name = helper.save(model)
+    try:
+        with open(outfile_name, 'rb') as f:
+            # A BytesIO created from initial bytes starts at position 0
+            serialized = BytesIO(f.read())
+    finally:
+        # Remove the file written by the helper even if reading fails
+        os.unlink(outfile_name)
+    return serialized
+
+
class ModelService(rpc.ModelServiceServicer):
""" Service for handling download and upload of models to the server.
"""
def __init__(self):
- self.models = TempModelStorage()
+ self.temp_model_storage = TempModelStorage()
def exist(self, model_id):
""" Check if a model exists on the server.
@@ -24,50 +106,7 @@ def exist(self, model_id):
:param model_id: The model id.
:return: True if the model exists, else False.
"""
- return self.models.exist(model_id)
-
- def get_tmp_path(self):
- """ Return a temporary output path compatible with save_model, load_model. """
- fd, path = tempfile.mkstemp()
- os.close(fd)
- return path
-
- def load_model_from_BytesIO(self, model_bytesio, helper):
- """ Load a model from a BytesIO object.
-
- :param model_bytesio: A BytesIO object containing the model.
- :type model_bytesio: :class:`io.BytesIO`
- :param helper: The helper object for the model.
- :type helper: :class:`fedn.utils.helperbase.HelperBase`
- :return: The model object.
- :rtype: return type of helper.load
- """
- path = self.get_tmp_path()
- with open(path, 'wb') as fh:
- fh.write(model_bytesio)
- fh.flush()
- model = helper.load(path)
- os.unlink(path)
- return model
-
- def serialize_model_to_BytesIO(self, model, helper):
- """ Serialize a model to a BytesIO object.
-
- :param model: The model object.
- :type model: return type of helper.load
- :param helper: The helper object for the model.
- :type helper: :class:`fedn.utils.helperbase.HelperBase`
- :return: A BytesIO object containing the model.
- :rtype: :class:`io.BytesIO`
- """
- outfile_name = helper.save(model)
-
- a = BytesIO()
- a.seek(0, 0)
- with open(outfile_name, 'rb') as f:
- a.write(f.read())
- os.unlink(outfile_name)
- return a
+ return self.temp_model_storage.exist(model_id)
def get_model(self, id):
""" Download model with id 'id' from server.
@@ -99,37 +138,9 @@ def set_model(self, model, id):
:param id: The model id.
:type id: str
"""
- if not isinstance(model, BytesIO):
- bt = BytesIO()
-
- written_total = 0
- for d in model.stream(32 * 1024):
- written = bt.write(d)
- written_total += written
- else:
- bt = model
-
- bt.seek(0, 0)
-
- def upload_request_generator(mdl):
- """
-
- :param mdl:
- """
- while True:
- b = mdl.read(CHUNK_SIZE)
- if b:
- result = fedn.ModelRequest(
- data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS)
- else:
- result = fedn.ModelRequest(
- id=id, data=None, status=fedn.ModelStatus.OK)
- yield result
- if not b:
- break
-
+ bt = model_as_bytesIO(model)
# TODO: Check result
- _ = self.Upload(upload_request_generator(bt), self)
+ _ = self.Upload(upload_request_generator(bt, id), self)
# Model Service
def Upload(self, request_iterator, context):
@@ -146,16 +157,16 @@ def Upload(self, request_iterator, context):
result = None
for request in request_iterator:
if request.status == fedn.ModelStatus.IN_PROGRESS:
- self.models.get_ptr(request.id).write(request.data)
- self.models.set_model_metadata(request.id, fedn.ModelStatus.IN_PROGRESS)
+ self.temp_model_storage.get_ptr(request.id).write(request.data)
+ self.temp_model_storage.set_model_metadata(request.id, fedn.ModelStatus.IN_PROGRESS)
if request.status == fedn.ModelStatus.OK and not request.data:
result = fedn.ModelResponse(id=request.id, status=fedn.ModelStatus.OK,
message="Got model successfully.")
- # self.models_metadata.update({request.id: fedn.ModelStatus.OK})
- self.models.set_model_metadata(request.id, fedn.ModelStatus.OK)
- self.models.get_ptr(request.id).flush()
- self.models.get_ptr(request.id).close()
+ # self.temp_model_storage_metadata.update({request.id: fedn.ModelStatus.OK})
+ self.temp_model_storage.set_model_metadata(request.id, fedn.ModelStatus.OK)
+ self.temp_model_storage.get_ptr(request.id).flush()
+ self.temp_model_storage.get_ptr(request.id).close()
return result
def Download(self, request, context):
@@ -170,7 +181,7 @@ def Download(self, request, context):
"""
logger.debug("grpc.ModelService.Download: Called")
try:
- if self.models.get_model_metadata(request.id) != fedn.ModelStatus.OK:
+ if self.temp_model_storage.get_model_metadata(request.id) != fedn.ModelStatus.OK:
logger.error("Error file is not ready")
yield fedn.ModelResponse(id=request.id, data=None, status=fedn.ModelStatus.FAILED)
except Exception:
@@ -178,7 +189,7 @@ def Download(self, request, context):
yield fedn.ModelResponse(id=request.id, data=None, status=fedn.ModelStatus.FAILED)
try:
- obj = self.models.get(request.id)
+ obj = self.temp_model_storage.get(request.id)
with obj as f:
while True:
piece = f.read(CHUNK_SIZE)
diff --git a/fedn/fedn/network/combiner/roundhandler.py b/fedn/fedn/network/combiner/roundhandler.py
index a3acaf20e..416f12188 100644
--- a/fedn/fedn/network/combiner/roundhandler.py
+++ b/fedn/fedn/network/combiner/roundhandler.py
@@ -6,6 +6,8 @@
from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator
+from fedn.network.combiner.modelservice import (load_model_from_BytesIO,
+ serialize_model_to_BytesIO)
from fedn.utils.helpers.helpers import get_helper
@@ -69,7 +71,7 @@ def load_model_update(self, helper, model_id):
model_str = self.load_model_update_str(model_id)
if model_str:
try:
- model = self.modelservice.load_model_from_BytesIO(model_str.getbuffer(), helper)
+ model = load_model_from_BytesIO(model_str.getbuffer(), helper)
except IOError:
logger.warning(
"AGGREGATOR({}): Failed to load model!".format(self.name))
@@ -89,7 +91,7 @@ def load_model_update_str(self, model_id, retry=3):
:rtype: class: `io.BytesIO`
"""
# Try reading model update from local disk/combiner memory
- model_str = self.modelservice.models.get(model_id)
+ model_str = self.modelservice.temp_model_storage.get(model_id)
# And if we cannot access that, try downloading from the server
if model_str is None:
model_str = self.modelservice.get_model(model_id)
@@ -206,7 +208,7 @@ def stage_model(self, model_id, timeout_retry=3, retry=2):
"""
# If the model is already in memory at the server we do not need to do anything.
- if self.modelservice.models.exist(model_id):
+ if self.modelservice.temp_model_storage.exist(model_id):
logger.info("ROUNDCONTROL: Model already exists in memory, skipping model staging.")
return
logger.info("ROUNDCONTROL: Model Staging, fetching model from storage...")
@@ -320,7 +322,7 @@ def execute_training_round(self, config):
data['config'] = config
data['round_id'] = config['round_id']
- # Make sure the model to update is available on this combiner.
+ # Download model to update and set in temp storage.
self.stage_model(config['model_id'])
clients = self._assign_round_clients(self.server.max_clients)
@@ -333,17 +335,16 @@ def execute_training_round(self, config):
if model is not None:
helper = get_helper(config['helper_type'])
- a = self.modelservice.serialize_model_to_BytesIO(model, helper)
- # Send aggregated model to server
- model_id = str(uuid.uuid4())
- self.modelservice.set_model(a, model_id)
+ a = serialize_model_to_BytesIO(model, helper)
+ model_id = self.storage.set_model(a.read(), is_file=False)
a.close()
data['model_id'] = model_id
logger.info(
"ROUNDCONTROL: TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config['_job_id']))
- self.modelservice.models.delete(config['model_id'])
+ # Delete temp model
+ self.modelservice.temp_model_storage.delete(config['model_id'])
return data
def run(self, polling_interval=1.0):
diff --git a/fedn/fedn/network/controller/control.py b/fedn/fedn/network/controller/control.py
index 842760625..a2ff6f4cc 100644
--- a/fedn/fedn/network/controller/control.py
+++ b/fedn/fedn/network/controller/control.py
@@ -6,7 +6,9 @@
from tenacity import (retry, retry_if_exception_type, stop_after_delay,
wait_random)
+from fedn.common.log_config import logger
from fedn.network.combiner.interfaces import CombinerUnavailableError
+from fedn.network.combiner.modelservice import load_model_from_BytesIO
from fedn.network.controller.controlbase import ControlBase
from fedn.network.state import ReducerState
@@ -328,44 +330,34 @@ def reduce(self, combiners):
for combiner in combiners:
name = combiner['name']
model_id = combiner['model_id']
- # TODO: Handle inactive RPC error in get_model and raise specific error
- print(
- "REDUCER: Fetching model ({model_id}) from combiner {name}".format(
- model_id=model_id, name=name
- ),
- flush=True,
- )
+
+ logger.info("Fetching model ({}) from model repository".format(model_id))
+
try:
tic = time.time()
- combiner_interface = self.get_combiner(name)
- data = combiner_interface.get_model(model_id)
+ data = self.model_repository.get_model(model_id)
meta['time_fetch_model'] += (time.time() - tic)
except Exception as e:
- print(
- "REDUCER: Failed to fetch model from combiner {}: {}".format(
- name, e
- ),
- flush=True,
- )
+ logger.error("Failed to fetch model from model repository {}: {}".format(name, e))
data = None
if data is not None:
try:
tic = time.time()
helper = self.get_helper()
- data.seek(0)
- model_next = helper.load(data)
+ model_next = load_model_from_BytesIO(data, helper)
meta["time_load_model"] += time.time() - tic
tic = time.time()
model = helper.increment_average(model, model_next, i, i)
meta["time_aggregate_model"] += time.time() - tic
except Exception:
tic = time.time()
- data.seek(0)
- model = helper.load(data)
+ model = load_model_from_BytesIO(data, helper)
meta["time_aggregate_model"] += time.time() - tic
i = i + 1
+ self.model_repository.delete_model(model_id)
+
return model, meta
def infer_instruct(self, config):
diff --git a/fedn/fedn/network/storage/s3/miniorepository.py b/fedn/fedn/network/storage/s3/miniorepository.py
index b8873188a..dcdf5c1f6 100644
--- a/fedn/fedn/network/storage/s3/miniorepository.py
+++ b/fedn/fedn/network/storage/s3/miniorepository.py
@@ -58,6 +58,9 @@ def get_artifact(self, instance_name, bucket):
return data.read()
except Exception as e:
raise Exception("Could not fetch data from bucket, {}".format(e))
+ finally:
+ data.close()
+ data.release_conn()
def get_artifact_stream(self, instance_name, bucket):
@@ -84,12 +87,12 @@ def list_artifacts(self, bucket):
"Could not list models in bucket {}".format(bucket))
return objects
- def delete_artifact(self, instance_name, bucket=[]):
+ def delete_artifact(self, instance_name, bucket):
""" Delete object with name instance_name from buckets.
:param instance_name: The object name
- :param bucket: List of buckets to delete from
- :type bucket: list
+ :param bucket: Bucket to delete from
+ :type bucket: str
"""
try:
diff --git a/fedn/fedn/network/storage/s3/repository.py b/fedn/fedn/network/storage/s3/repository.py
index e2f057821..d7d455341 100644
--- a/fedn/fedn/network/storage/s3/repository.py
+++ b/fedn/fedn/network/storage/s3/repository.py
@@ -56,6 +56,18 @@ def set_model(self, model, is_file=True):
raise
return str(model_id)
+    def delete_model(self, model_id):
+        """ Delete a model from the model bucket; failures are logged and re-raised.
+
+        :param model_id: The id of the model to delete
+        :type model_id: str
+        """
+        try:
+            self.client.delete_artifact(model_id, bucket=self.model_bucket)
+        except Exception:
+            logger.error("Failed to delete model {} from repository.".format(model_id))
+            raise
+
def set_compute_package(self, name, compute_package, is_file=True):
""" Upload compute package.