Skip to content

Commit

Permalink
add inference endpoint api
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed May 24, 2024
1 parent 512a4ca commit ab88fe5
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
3 changes: 2 additions & 1 deletion fedn/network/api/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from fedn.network.api.v1.client_routes import bp as client_bp
from fedn.network.api.v1.combiner_routes import bp as combiner_bp
from fedn.network.api.v1.inference_routes import bp as inference_bp
from fedn.network.api.v1.model_routes import bp as model_bp
from fedn.network.api.v1.package_routes import bp as package_bp
from fedn.network.api.v1.round_routes import bp as round_bp
from fedn.network.api.v1.session_routes import bp as session_bp
from fedn.network.api.v1.status_routes import bp as status_bp
from fedn.network.api.v1.validation_routes import bp as validation_bp

_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp]
_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, inference_bp]
39 changes: 39 additions & 0 deletions fedn/network/api/v1/inference_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import threading

from flask import Blueprint, jsonify, request

from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.shared import control
from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs,
get_typed_list_headers, mdb)
from fedn.network.storage.statestore.stores.session_store import SessionStore
from fedn.network.storage.statestore.stores.shared import EntityNotFound

from .model_routes import model_store

bp = Blueprint("inference", __name__, url_prefix=f"/api/{api_version}/infer")


@bp.route("/start", methods=["POST"])
@jwt_auth_required(role="admin")
def start_session():
"""Start a new inference session.
param: session_id: The session id to start.
type: session_id: str
param: rounds: The number of rounds to run.
type: rounds: int
"""
try:
data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
session_id: str = data.get("session_id")

if not session_id or session_id == "":
return jsonify({"message": "Session ID is required"}), 400

session_config = {"session_id": session_id}

threading.Thread(target=control.inference_session, kwargs={"config":session_config}).start()

return jsonify({"message": "Session started"}), 200
except Exception as e:
return jsonify({"message": str(e)}), 500
2 changes: 1 addition & 1 deletion fedn/network/combiner/roundhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def run(self, polling_interval=1.0):
elif round_config["task"] == "validation":
self.execute_validation_round(session_id, model_id)
elif round_config["task"] == "inference":
logger.info("Inference task not yet implemented.")
self.execute_inference_round(session_id, model_id)
else:
logger.warning("config contains unkown task type.")
else:
Expand Down
3 changes: 2 additions & 1 deletion fedn/network/controller/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,13 @@ def inference_session(self, config: RoundConfig) -> None:
logger.warning("Inference round cannot start, no combiners connected!")
return

if not config["model_id"]:
if not "model_id" in config.keys():
config["model_id"]= self.statestore.get_latest_model()

config["committed_at"] = datetime.datetime.now()
config["task"] = "inference"
config["rounds"] = str(1)
config["clients_required"] = 1

participating_combiners = self.get_participating_combiners(config)

Expand Down

0 comments on commit ab88fe5

Please sign in to comment.