From ab88fe54f23a815b0ed6a019d4ce5dee6d6e79c9 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Fri, 24 May 2024 11:36:10 +0200 Subject: [PATCH] add inference endpoint api --- fedn/network/api/v1/__init__.py | 3 +- fedn/network/api/v1/inference_routes.py | 39 +++++++++++++++++++++++++ fedn/network/combiner/roundhandler.py | 2 +- fedn/network/controller/control.py | 3 +- 4 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 fedn/network/api/v1/inference_routes.py diff --git a/fedn/network/api/v1/__init__.py b/fedn/network/api/v1/__init__.py index bb8d8d33c..0e05dd249 100644 --- a/fedn/network/api/v1/__init__.py +++ b/fedn/network/api/v1/__init__.py @@ -1,5 +1,6 @@ from fedn.network.api.v1.client_routes import bp as client_bp from fedn.network.api.v1.combiner_routes import bp as combiner_bp +from fedn.network.api.v1.inference_routes import bp as inference_bp from fedn.network.api.v1.model_routes import bp as model_bp from fedn.network.api.v1.package_routes import bp as package_bp from fedn.network.api.v1.round_routes import bp as round_bp @@ -7,4 +8,4 @@ from fedn.network.api.v1.status_routes import bp as status_bp from fedn.network.api.v1.validation_routes import bp as validation_bp -_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp] +_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, inference_bp] diff --git a/fedn/network/api/v1/inference_routes.py b/fedn/network/api/v1/inference_routes.py new file mode 100644 index 000000000..41ffb92b7 --- /dev/null +++ b/fedn/network/api/v1/inference_routes.py @@ -0,0 +1,39 @@ +import threading + +from flask import Blueprint, jsonify, request + +from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import control +from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, + get_typed_list_headers, mdb) +from fedn.network.storage.statestore.stores.session_store import SessionStore +from fedn.network.storage.statestore.stores.shared import EntityNotFound + +from .model_routes import model_store + +bp = Blueprint("inference", __name__, url_prefix=f"/api/{api_version}/infer") + + +@bp.route("/start", methods=["POST"]) +@jwt_auth_required(role="admin") +def start_session(): + """Start a new inference session. + param: session_id: The session id to start. + type: session_id: str + param: rounds: The number of rounds to run. + type: rounds: int + """ + try: + data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict() + session_id: str = data.get("session_id") + + if not session_id or session_id == "": + return jsonify({"message": "Session ID is required"}), 400 + + session_config = {"session_id": session_id} + + threading.Thread(target=control.inference_session, kwargs={"config":session_config}).start() + + return jsonify({"message": "Session started"}), 200 + except Exception as e: + return jsonify({"message": str(e)}), 500 diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index 6af9366e4..ef0f6076f 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -424,7 +424,7 @@ def run(self, polling_interval=1.0): elif round_config["task"] == "validation": self.execute_validation_round(session_id, model_id) elif round_config["task"] == "inference": - logger.info("Inference task not yet implemented.") + self.execute_inference_round(session_id, model_id) else: logger.warning("config contains unkown task type.") else: diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index b634e3a9f..235ae78ca 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -202,12 +202,13 @@ def inference_session(self, config: RoundConfig) -> None: logger.warning("Inference round cannot start, no combiners connected!") return - if not config["model_id"]: + if not "model_id" in config.keys(): config["model_id"]= self.statestore.get_latest_model() config["committed_at"] = datetime.datetime.now() config["task"] = "inference" config["rounds"] = str(1) + config["clients_required"] = 1 participating_combiners = self.get_participating_combiners(config)