Feature/SK-866 | Inference TaskType workflow #614

Merged
5 commits merged on Jun 11, 2024
Changes from 2 commits
3 changes: 2 additions & 1 deletion fedn/network/api/v1/__init__.py
@@ -1,10 +1,11 @@
from fedn.network.api.v1.client_routes import bp as client_bp
from fedn.network.api.v1.combiner_routes import bp as combiner_bp
from fedn.network.api.v1.inference_routes import bp as inference_bp
from fedn.network.api.v1.model_routes import bp as model_bp
from fedn.network.api.v1.package_routes import bp as package_bp
from fedn.network.api.v1.round_routes import bp as round_bp
from fedn.network.api.v1.session_routes import bp as session_bp
from fedn.network.api.v1.status_routes import bp as status_bp
from fedn.network.api.v1.validation_routes import bp as validation_bp

_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp]
_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, inference_bp]
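
For context, the blueprints collected in _routes are registered on the Flask app by the API server. A minimal sketch of that pattern, assuming a plain Flask app (the actual FEDn wiring lives elsewhere in the package and may differ):

from flask import Flask

from fedn.network.api.v1 import _routes  # now includes inference_bp

app = Flask(__name__)
for bp in _routes:
    # Each blueprint carries its own url_prefix, e.g. /api/v1/infer for inference_bp.
    app.register_blueprint(bp)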
39 changes: 39 additions & 0 deletions fedn/network/api/v1/inference_routes.py
@@ -0,0 +1,39 @@
import threading

from flask import Blueprint, jsonify, request

from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.shared import control
from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs,
                                        get_typed_list_headers, mdb)
from fedn.network.storage.statestore.stores.session_store import SessionStore
from fedn.network.storage.statestore.stores.shared import EntityNotFound

from .model_routes import model_store

bp = Blueprint("inference", __name__, url_prefix=f"/api/{api_version}/infer")


@bp.route("/start", methods=["POST"])
@jwt_auth_required(role="admin")
def start_session():
"""Start a new inference session.
param: session_id: The session id to start.
type: session_id: str
param: rounds: The number of rounds to run.
type: rounds: int
"""
try:
data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
session_id: str = data.get("session_id")

if not session_id or session_id == "":
return jsonify({"message": "Session ID is required"}), 400

session_config = {"session_id": session_id}

threading.Thread(target=control.inference_session, kwargs={"config":session_config}).start()

return jsonify({"message": "Session started"}), 200
except Exception as e:
return jsonify({"message": str(e)}), 500
8 changes: 5 additions & 3 deletions fedn/network/combiner/aggregators/aggregatorbase.py
@@ -86,9 +86,11 @@ def _validate_model_update(self, model_update):
        :return: True if the model update is valid, False otherwise.
        :rtype: bool
        """
        data = json.loads(model_update.meta)["training_metadata"]
        if "num_examples" not in data.keys():
            logger.error("AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name))
        try:
            data = json.loads(model_update.meta)["training_metadata"]
            num_examples = data["num_examples"]
        except KeyError as e:
            logger.error("AGGREGATOR({}): Invalid model update, missing metadata.".format(self.name))
            return False
        return True
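
For reference, after this change a model update passes validation only when its meta field is a JSON string whose training_metadata object contains num_examples. A small illustration of a passing payload (the value is made up):

import json

# Metadata shape the aggregator expects; a missing "num_examples" now raises
# KeyError inside _validate_model_update and the update is rejected.
meta = json.dumps({"training_metadata": {"num_examples": 1024}})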

97 changes: 63 additions & 34 deletions fedn/network/combiner/combiner.py
@@ -12,10 +12,11 @@

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream
from fedn.common.log_config import (logger, set_log_level_from_string,
                                    set_log_stream)
from fedn.network.combiner.connect import ConnectorCombiner, Status
from fedn.network.combiner.modelservice import ModelService
from fedn.network.combiner.roundhandler import RoundHandler
from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler
from fedn.network.grpc.server import Server
from fedn.network.storage.s3.repository import Repository
from fedn.network.storage.statestore.mongostatestore import MongoStateStore
@@ -161,7 +162,7 @@ def __whoami(self, client, instance):
        client.role = role_to_proto_role(instance.role)
        return client

    def request_model_update(self, config, clients=[]):
    def request_model_update(self, session_id, model_id, config, clients=[]):
        """Ask clients to update the current global model.

        :param config: the model configuration to send to clients
@@ -170,32 +171,14 @@ def request_model_update(self, config, clients=[]):
        :type clients: list

        """
        # The request to be added to the client queue
        request = fedn.TaskRequest()
        request.model_id = config["model_id"]
        request.correlation_id = str(uuid.uuid4())
        request.timestamp = str(datetime.now())
        request.data = json.dumps(config)
        request.type = fedn.StatusType.MODEL_UPDATE
        request.session_id = config["session_id"]

        request.sender.name = self.id
        request.sender.role = fedn.COMBINER

        if len(clients) == 0:
            clients = self.get_active_trainers()

        for client in clients:
            request.receiver.name = client
            request.receiver.role = fedn.WORKER
            self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)
        request, clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients)

        if len(clients) < 20:
            logger.info("Sent model update request for model {} to clients {}".format(request.model_id, clients))
        else:
            logger.info("Sent model update request for model {} to {} clients".format(request.model_id, len(clients)))

    def request_model_validation(self, model_id, config, clients=[]):
    def request_model_validation(self, session_id, model_id, clients=[]):
        """Ask clients to validate the current global model.

        :param model_id: the model id to validate
@@ -206,30 +189,76 @@ def request_model_validation(self, model_id, config, clients=[]):
        :type clients: list

        """
        # The request to be added to the client queue
        request, clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients)

        if len(clients) < 20:
            logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients))
        else:
            logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients)))

    def request_model_inference(self, session_id: str, model_id: str, clients: list=[]) -> None:
        """Ask clients to perform inference on the model.

        :param model_id: the model id to perform inference on
        :type model_id: str
        :param config: the model configuration to send to clients
        :type config: dict
        :param clients: the clients to send the request to
        :type clients: list

        """
        request, clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients)

        if len(clients) < 20:
            logger.info("Sent model inference request for model {} to clients {}".format(request.model_id, clients))
        else:
            logger.info("Sent model inference request for model {} to {} clients".format(request.model_id, len(clients)))

    def _send_request_type(self, request_type, session_id, model_id, config=None, clients=[]):
        """Send a request of a specific type to clients.

        :param request_type: the type of request
        :type request_type: :class:`fedn.network.grpc.fedn_pb2.StatusType`
        :param model_id: the model id to send in the request
        :type model_id: str
        :param config: the model configuration to send to clients
        :type config: dict
        :param clients: the clients to send the request to
        :type clients: list
        :return: the request and the clients
        :rtype: tuple
        """
        request = fedn.TaskRequest()
        request.model_id = model_id
        request.correlation_id = str(uuid.uuid4())
        request.timestamp = str(datetime.now())
        # request.is_inference = (config['task'] == 'inference')
        request.type = fedn.StatusType.MODEL_VALIDATION
        request.type = request_type
        request.session_id = session_id

        request.sender.name = self.id
        request.sender.role = fedn.COMBINER
        request.session_id = config["session_id"]

        if len(clients) == 0:
            clients = self.get_active_validators()
        if request_type == fedn.StatusType.MODEL_UPDATE:
            request.data = json.dumps(config)
            if len(clients) == 0:
                clients = self.get_active_trainers()
        elif request_type == fedn.StatusType.MODEL_VALIDATION:
            if len(clients) == 0:
                clients = self.get_active_validators()
        elif request_type == fedn.StatusType.INFERENCE:
            request.data = json.dumps(config)
            if len(clients) == 0:
                # TODO: add inference clients type
                clients = self.get_active_validators()

        # TODO: if inference, request.data should be user-defined data/parameters

        for client in clients:
            request.receiver.name = client
            request.receiver.role = fedn.WORKER
            self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)

        if len(clients) < 20:
            logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients))
        else:
            logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients)))
        return request, clients

    def get_active_trainers(self):
        """Get a list of active trainers.
@@ -410,7 +439,7 @@ def Start(self, control: fedn.ControlRequest, context):
"""
logger.info("grpc.Combiner.Start: Starting round")

config = {}
config = RoundConfig()
for parameter in control.parameter:
config.update({parameter.key: parameter.value})
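
To summarize the refactor, the three public request methods now differ mainly in the StatusType they hand to _send_request_type. A hedged usage sketch, assuming an already-constructed Combiner instance and placeholder ids:

# "combiner" is an instantiated Combiner; the ids below are placeholders.
session_id = "sess-1"
model_id = "model-abc"

# Training task: the config payload is serialized into request.data.
combiner.request_model_update(session_id, model_id, config={"session_id": session_id, "model_id": model_id})

# Validation and inference tasks carry no config payload in this revision.
combiner.request_model_validation(session_id, model_id)
combiner.request_model_inference(session_id, model_id)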

3 changes: 2 additions & 1 deletion fedn/network/combiner/interfaces.py
@@ -8,6 +8,7 @@

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.network.combiner.roundhandler import RoundConfig


class CombinerUnavailableError(Exception):
@@ -202,7 +203,7 @@ def set_aggregator(self, aggregator):
        else:
            raise

    def submit(self, config):
    def submit(self, config: RoundConfig):
        """Submit a compute plan to the combiner.

        :param config: The job configuration.