From e018d1449ab3fb8110f539a86b9924ed28eca080 Mon Sep 17 00:00:00 2001 From: viktorvaladi Date: Fri, 31 Jan 2025 11:50:55 +0100 Subject: [PATCH] fix so server functions code resets by sessions --- fedn/network/combiner/hooks/hooks.py | 2 -- fedn/network/combiner/roundhandler.py | 6 +++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fedn/network/combiner/hooks/hooks.py b/fedn/network/combiner/hooks/hooks.py index a20ddb27f..24153ac81 100644 --- a/fedn/network/combiner/hooks/hooks.py +++ b/fedn/network/combiner/hooks/hooks.py @@ -122,8 +122,6 @@ def HandleProvidedFunctions(self, request: fedn.ProvidedFunctionsResponse, conte :rtype: :class:`fedn.network.grpc.fedn_pb2.ProvidedFunctionsResponse` """ logger.info("Receieved provided functions request.") - if self.implemented_functions is not None: - return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions) server_functions_code = request.function_code self.server_functions_code = server_functions_code self.implemented_functions = {} diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index 20ed336c9..23a135aec 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -98,6 +98,7 @@ def __init__(self, server): self.server_functions = inspect.getsource(ServerFunctions) self.update_handler = UpdateHandler(modelservice=modelservice) self.hook_interface = CombinerHookInterface() + self.session_id = "" def set_aggregator(self, aggregator): self.aggregator = get_aggregator(aggregator, self.update_handler) @@ -324,7 +325,10 @@ def execute_training_round(self, config): # Download model to update and set in temp storage. self.stage_model(config["model_id"]) - provided_functions = self.hook_interface.provided_functions(self.server_functions) + # If new session, update server function code and check which functions are provided + if self.session_id != config["session_id"]: + self.session_id = config["session_id"] + provided_functions = self.hook_interface.provided_functions(self.server_functions) if provided_functions.get("client_selection", False): clients = self.hook_interface.client_selection(clients=self.server.get_active_trainers())