From b3310ac877f3c454129bc59881d45340d35b4639 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Mon, 10 Jun 2024 14:35:39 +0000 Subject: [PATCH] format --- fedn/network/combiner/combiner.py | 8 +++----- fedn/network/combiner/roundhandler.py | 14 ++++++++------ fedn/network/controller/control.py | 17 +++++++---------- fedn/network/controller/controlbase.py | 11 ++++------- 4 files changed, 22 insertions(+), 28 deletions(-) diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index a053cc3ca..70755ac6b 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -12,8 +12,7 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc -from fedn.common.log_config import (logger, set_log_level_from_string, - set_log_stream) +from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream from fedn.network.combiner.connect import ConnectorCombiner, Status from fedn.network.combiner.modelservice import ModelService from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler @@ -66,7 +65,6 @@ def __init__(self, config): # Client queues self.clients = {} - # Validate combiner name match = re.search(VALID_NAME_REGEX, config["name"]) if not match: @@ -196,7 +194,7 @@ def request_model_validation(self, session_id, model_id, clients=[]): else: logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients))) - def request_model_inference(self, session_id: str, model_id: str, clients: list=[]) -> None: + def request_model_inference(self, session_id: str, model_id: str, clients: list = []) -> None: """Ask clients to perform inference on the model. :param model_id: the model id to perform inference on @@ -250,7 +248,7 @@ def _send_request_type(self, request_type, session_id, model_id, config=None, cl if len(clients) == 0: # TODO: add inference clients type clients = self.get_active_validators() - + # TODO: if inference, request.data should be user-defined data/parameters for client in clients: diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index ef0f6076f..ef9029de9 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -8,15 +8,14 @@ from fedn.common.log_config import logger from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator -from fedn.network.combiner.modelservice import (load_model_from_BytesIO, - serialize_model_to_BytesIO) +from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO from fedn.utils.helpers.helpers import get_helper from fedn.utils.parameters import Parameters class RoundConfig(TypedDict): """Round configuration. - + :param _job_id: A universally unique identifier for the round. Set by Combiner. :type _job_id: str :param committed_at: The time the round was committed. Set by Controller. @@ -47,6 +46,7 @@ class RoundConfig(TypedDict): :param aggregator: The aggregator type. :type aggregator: str """ + _job_id: str committed_at: str task: str @@ -62,6 +62,8 @@ class RoundConfig(TypedDict): session_id: str helper_type: str aggregator: str + + class ModelUpdateError(Exception): pass @@ -246,7 +248,7 @@ def _validation_round(self, session_id, model_id, clients): :type model_id: str """ self.server.request_model_validation(session_id, model_id, clients=clients) - + def _inference_round(self, session_id: str, model_id: str, clients: list): """Send model inference requests to clients. @@ -346,7 +348,7 @@ def execute_validation_round(self, session_id, model_id): self.stage_model(model_id) validators = self._assign_round_clients(self.server.max_clients, type="validators") self._validation_round(session_id, model_id, validators) - + def execute_inference_round(self, session_id: str, model_id: str) -> None: """Coordinate inference rounds as specified in config. @@ -423,7 +425,7 @@ def run(self, polling_interval=1.0): self.server.statestore.set_round_combiner_data(round_meta) elif round_config["task"] == "validation": self.execute_validation_round(session_id, model_id) - elif round_config["task"] == "inference": + elif round_config["task"] == "inference": self.execute_inference_round(session_id, model_id) else: logger.warning("config contains unkown task type.") diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index 235ae78ca..c7a6d1c26 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -2,10 +2,8 @@ import datetime import time import uuid -from typing import TypedDict -from tenacity import (retry, retry_if_exception_type, stop_after_delay, - wait_random) +from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_random from fedn.common.log_config import logger from fedn.network.combiner.interfaces import CombinerUnavailableError @@ -185,7 +183,7 @@ def session(self, config: RoundConfig) -> None: # TODO: Report completion of session self.set_session_status(config["session_id"], "Finished") self._state = ReducerState.idle - + def inference_session(self, config: RoundConfig) -> None: """Execute a new inference session. @@ -193,7 +191,6 @@ def inference_session(self, config: RoundConfig) -> None: :type config: InferenceConfig :return: None """ - if self._state == ReducerState.instructing: logger.info("Controller already in INSTRUCTING state. A session is in progress.") return @@ -201,10 +198,10 @@ def inference_session(self, config: RoundConfig) -> None: if len(self.network.get_combiners()) < 1: logger.warning("Inference round cannot start, no combiners connected!") return - - if not "model_id" in config.keys(): - config["model_id"]= self.statestore.get_latest_model() - + + if "model_id" not in config.keys(): + config["model_id"] = self.statestore.get_latest_model() + config["committed_at"] = datetime.datetime.now() config["task"] = "inference" config["rounds"] = str(1) @@ -216,7 +213,7 @@ def inference_session(self, config: RoundConfig) -> None: round_start = self.evaluate_round_start_policy(participating_combiners) if round_start: - logger.info("Inference round start policy met, {} participating combiners.".format(len(participating_combiners))) + logger.info("Inference round start policy met, {} participating combiners.".format(len(participating_combiners))) for combiner, _ in participating_combiners: combiner.submit(config) logger.info("Inference round submitted to combiner {}".format(combiner)) diff --git a/fedn/network/controller/controlbase.py b/fedn/network/controller/controlbase.py index ba5c72276..141848b78 100644 --- a/fedn/network/controller/controlbase.py +++ b/fedn/network/controller/controlbase.py @@ -114,14 +114,12 @@ def idle(self): return False def get_model_info(self): - """:return: - """ + """:return:""" return self.statestore.get_model_trail() # TODO: remove use statestore.get_events() instead def get_events(self): - """:return: - """ + """:return:""" return self.statestore.get_events() def get_latest_round_id(self): @@ -136,8 +134,7 @@ def get_latest_round(self): return round def get_compute_package_name(self): - """:return: - """ + """:return:""" definition = self.statestore.get_compute_package() if definition: try: @@ -164,7 +161,7 @@ def get_compute_package(self, compute_package=""): else: return None - def create_session(self, config: RoundConfig, status: str="Initialized") -> None: + def create_session(self, config: RoundConfig, status: str = "Initialized") -> None: """Initialize a new session in backend db.""" if "session_id" not in config.keys(): session_id = uuid.uuid4()