From 600b67e7d71bc2e1705e97141d43c443bfd130be Mon Sep 17 00:00:00 2001 From: Andreas Hellander Date: Wed, 15 May 2024 22:23:45 +0200 Subject: [PATCH 1/2] detach replaced by disconnect --- fedn/network/clients/client.py | 186 ++++++++++++++++++++++----------- 1 file changed, 123 insertions(+), 63 deletions(-) diff --git a/fedn/network/clients/client.py b/fedn/network/clients/client.py index 70fe005ff..45d416445 100644 --- a/fedn/network/clients/client.py +++ b/fedn/network/clients/client.py @@ -22,11 +22,13 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_PACKAGE_EXTRACT_DIR -from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream +from fedn.common.log_config import (logger, set_log_level_from_string, + set_log_stream) from fedn.network.clients.connect import ConnectorClient, Status from fedn.network.clients.package import PackageRuntime from fedn.network.clients.state import ClientState, ClientStateToString -from fedn.network.combiner.modelservice import get_tmp_path, upload_request_generator +from fedn.network.combiner.modelservice import (get_tmp_path, + upload_request_generator) from fedn.utils.dispatcher import Dispatcher from fedn.utils.helpers.helpers import get_helper @@ -77,7 +79,8 @@ def __init__(self, config): # Validate client name match = re.search(VALID_NAME_REGEX, config["name"]) if not match: - raise ValueError("Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.") + raise ValueError( + "Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.") # Folder where the client will store downloaded compute package and logs self.name = config["name"] @@ -102,7 +105,8 @@ def __init__(self, config): self._initialize_helper(combiner_config) if not self.helper: - logger.warning("Failed to retrieve helper class settings: {}".format(combiner_config)) + logger.warning( + "Failed to retrieve helper class settings: {}".format(combiner_config)) self._subscribe_to_combiner(self.config) @@ -119,7 +123,8 @@ def assign(self): status, response = self.connector.assign() if status == Status.TryAgain: logger.warning(response) - logger.info("Assignment request failed. Retrying in 5 seconds.") + logger.info( + "Assignment request failed. Retrying in 5 seconds.") time.sleep(5) continue if status == Status.Assigned: @@ -133,7 +138,8 @@ def assign(self): sys.exit("Exiting: UnMatchedConfig") time.sleep(5) logger.info("Assignment successfully received.") - logger.info("Received combiner configuration: {}".format(combiner_config)) + logger.info( + "Received combiner configuration: {}".format(combiner_config)) return combiner_config def _add_grpc_metadata(self, key, value): @@ -152,7 +158,8 @@ def _add_grpc_metadata(self, key, value): for i, (k, v) in enumerate(self.metadata): if k == key: # Replace value - self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1 :] + self.metadata = self.metadata[:i] + \ + ((key, value),) + self.metadata[i + 1:] return # Set metadata using tuple concatenation @@ -193,20 +200,26 @@ def connect(self, combiner_config): host = combiner_config["fqdn"] # assuming https if fqdn is used port = 443 - logger.info(f"Initiating connection to combiner host at: {host}:{port}") + logger.info( + f"Initiating connection to combiner host at: {host}:{port}") if combiner_config["certificate"]: - logger.info("Utilizing CA certificate for GRPC channel authentication.") + logger.info( + "Utilizing CA certificate for GRPC channel authentication.") secure = True - cert = base64.b64decode(combiner_config["certificate"]) # .decode('utf-8') + cert = base64.b64decode( + combiner_config["certificate"]) # .decode('utf-8') credentials = grpc.ssl_channel_credentials(root_certificates=cert) - channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) + channel = grpc.secure_channel( + "{}:{}".format(host, str(port)), credentials) elif os.getenv("FEDN_GRPC_ROOT_CERT_PATH"): secure = True - logger.info("Using root certificate from environment variable for GRPC channel.") + logger.info( + "Using root certificate from environment variable for GRPC channel.") with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], "rb") as f: credentials = grpc.ssl_channel_credentials(f.read()) - channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) + channel = grpc.secure_channel( + "{}:{}".format(host, str(port)), credentials) elif self.config["secure"]: secure = True logger.info("Using CA certificate for GRPC channel.") @@ -216,9 +229,11 @@ def connect(self, combiner_config): if self.config["token"]: token = self.config["token"] auth_creds = grpc.metadata_call_credentials(GrpcAuth(token)) - channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) + channel = grpc.secure_channel("{}:{}".format( + host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) else: - channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) + channel = grpc.secure_channel( + "{}:{}".format(host, str(port)), credentials) else: logger.info("Using insecure GRPC channel.") if port == 443: @@ -231,9 +246,11 @@ def connect(self, combiner_config): self.combinerStub = rpc.CombinerStub(channel) self.modelStub = rpc.ModelServiceStub(channel) - logger.info("Successfully established {} connection to {}:{}".format("secure" if secure else "insecure", host, port)) + logger.info("Successfully established {} connection to {}:{}".format( + "secure" if secure else "insecure", host, port)) - logger.info("Using {} compute package.".format(combiner_config["package"])) + logger.info("Using {} compute package.".format( + combiner_config["package"])) self._connected = True @@ -266,10 +283,12 @@ def _subscribe_to_combiner(self, config): | client-combiner assignment behavior. """ # Start sending heartbeats to the combiner. - threading.Thread(target=self._send_heartbeat, kwargs={"update_frequency": config["heartbeat_interval"]}, daemon=True).start() + threading.Thread(target=self._send_heartbeat, kwargs={ + "update_frequency": config["heartbeat_interval"]}, daemon=True).start() # Start listening for combiner training and validation messages - threading.Thread(target=self._listen_to_task_stream, daemon=True).start() + threading.Thread(target=self._listen_to_task_stream, + daemon=True).start() self._connected = True # Start processing the client message inbox @@ -297,21 +316,25 @@ def _initialize_dispatcher(self, config): while tries > 0: retval = pr.download( - host=config["discover_host"], port=config["discover_port"], token=config["token"], force_ssl=config["force_ssl"], secure=config["secure"] + host=config["discover_host"], port=config["discover_port"], token=config[ + "token"], force_ssl=config["force_ssl"], secure=config["secure"] ) if retval: break time.sleep(60) - logger.warning("Compute package not available. Retrying in 60 seconds. {} attempts remaining.".format(tries)) + logger.warning( + "Compute package not available. Retrying in 60 seconds. {} attempts remaining.".format(tries)) tries -= 1 if retval: if "checksum" not in config: - logger.warning("Bypassing validation of package checksum. Ensure the package source is trusted.") + logger.warning( + "Bypassing validation of package checksum. Ensure the package source is trusted.") else: checks_out = pr.validate(config["checksum"]) if not checks_out: - logger.critical("Validation of local package failed. Client terminating.") + logger.critical( + "Validation of local package failed. Client terminating.") self.error_state = True return package_runpath = "" @@ -320,7 +343,8 @@ def _initialize_dispatcher(self, config): self.dispatcher = pr.dispatcher(package_runpath) try: - logger.info("Initiating Dispatcher with entrypoint set to: startup") + logger.info( + "Initiating Dispatcher with entrypoint set to: startup") activate_cmd = self.dispatcher._get_or_create_python_env() self.dispatcher.run_cmd("startup") except KeyError: @@ -345,7 +369,8 @@ def _initialize_dispatcher(self, config): # Get or create python environment activate_cmd = self.dispatcher._get_or_create_python_env() if activate_cmd: - logger.info("To activate the virtual environment, run: {}".format(activate_cmd)) + logger.info( + "To activate the virtual environment, run: {}".format(activate_cmd)) def get_model_from_combiner(self, id, timeout=20): """Fetch a model from the assigned combiner. @@ -378,7 +403,8 @@ def get_model_from_combiner(self, id, timeout=20): return None continue except grpc.RpcError as e: - logger.critical(f"GRPC: An error occurred during model download: {e}") + logger.critical( + f"GRPC: An error occurred during model download: {e}") return data @@ -404,9 +430,11 @@ def send_model_to_combiner(self, model, id): bt.seek(0, 0) try: - result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata) + result = self.modelStub.Upload( + upload_request_generator(bt, id), metadata=self.metadata) except grpc.RpcError as e: - logger.critical(f"GRPC: An error occurred during model upload: {e}") + logger.critical( + f"GRPC: An error occurred during model upload: {e}") return result @@ -426,7 +454,8 @@ def _listen_to_task_stream(self): try: for request in self.combinerStub.TaskStream(r, metadata=self.metadata): if request: - logger.debug("Received model update request from combiner: {}.".format(request)) + logger.debug( + "Received model update request from combiner: {}.".format(request)) if request.sender.role == fedn.COMBINER: # Process training request self.send_status( @@ -436,39 +465,46 @@ def _listen_to_task_stream(self): request=request, sesssion_id=request.session_id, ) - logger.info("Received model update request of type {} for model_id {}".format(request.type, request.model_id)) + logger.info("Received model update request of type {} for model_id {}".format( + request.type, request.model_id)) if request.type == fedn.StatusType.MODEL_UPDATE and self.config["trainer"]: self.inbox.put(("train", request)) elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]: self.inbox.put(("validate", request)) else: - logger.error("Unknown request type: {}".format(request.type)) + logger.error( + "Unknown request type: {}".format(request.type)) except grpc.RpcError as e: # Handle gRPC errors status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: - logger.warning("GRPC TaskStream: server unavailable during model update request stream. Retrying.") + logger.warning( + "GRPC TaskStream: server unavailable during model update request stream. Retrying.") # Retry after a delay time.sleep(5) if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() if details == "Token expired": - logger.warning("GRPC TaskStream: Token expired. Reconnecting.") - self.detach() + logger.warning( + "GRPC TaskStream: Token expired. Reconnecting.") + self.disconnect() if status_code == grpc.StatusCode.CANCELLED: - # Expected if the client is detached - logger.critical("GRPC TaskStream: Client detached from combiner. Atempting to reconnect.") + # Expected if the client is disconnected + logger.critical( + "GRPC TaskStream: Client disconnected from combiner. Trying to reconnect.") else: # Log the error and continue - logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {e}") + logger.error( + f"GRPC TaskStream: An error occurred during model update request stream: {e}") except Exception as ex: # Handle other exceptions - logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {ex}") + logger.error( + f"GRPC TaskStream: An error occurred during model update request stream: {ex}") # Detach if not attached if not self._connected: @@ -484,7 +520,8 @@ def _process_training_request(self, model_id: str, session_id: str = None): :return: The model id of the updated model, or None if the update failed. And a dict with metadata. :rtype: tuple """ - self.send_status("\t Starting processing of training request for model_id {}".format(model_id), sesssion_id=session_id) + self.send_status("\t Starting processing of training request for model_id {}".format( + model_id), sesssion_id=session_id) self.state = ClientState.training try: @@ -492,7 +529,8 @@ def _process_training_request(self, model_id: str, session_id: str = None): tic = time.time() mdl = self.get_model_from_combiner(str(model_id)) if mdl is None: - logger.error("Could not retrieve model from combiner. Aborting training request.") + logger.error( + "Could not retrieve model from combiner. Aborting training request.") return None, None meta["fetch_model"] = time.time() - tic @@ -529,7 +567,8 @@ def _process_training_request(self, model_id: str, session_id: str = None): os.unlink(outpath + "-metadata") except Exception as e: - logger.error("Could not process training request due to error: {}".format(e)) + logger.error( + "Could not process training request due to error: {}".format(e)) updated_model_id = None meta = {"status": "failed", "error": str(e)} @@ -555,12 +594,14 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session else: cmd = "validate" - self.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) + self.send_status( + f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) self.state = ClientState.validating try: model = self.get_model_from_combiner(str(model_id)) if model is None: - logger.error("Could not retrieve model from combiner. Aborting validation request.") + logger.error( + "Could not retrieve model from combiner. Aborting validation request.") return None inpath = self.helper.get_tmp_path() @@ -595,7 +636,8 @@ def process_request(self): if task_type == "train": tic = time.time() self.state = ClientState.training - model_id, meta = self._process_training_request(request.model_id, session_id=request.session_id) + model_id, meta = self._process_training_request( + request.model_id, session_id=request.session_id) if meta is not None: processing_time = time.time() - tic @@ -616,7 +658,8 @@ def process_request(self): update.meta = json.dumps(meta) try: - _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) + _ = self.combinerStub.SendModelUpdate( + update, metadata=self.metadata) self.send_status( "Model update completed.", log_level=fedn.Status.AUDIT, @@ -626,10 +669,12 @@ def process_request(self): ) except grpc.RpcError as e: status_code = e.code() - logger.error("GRPC error, {}.".format(status_code.name)) + logger.error( + "GRPC error, {}.".format(status_code.name)) logger.debug(e) except ValueError as e: - logger.error("GRPC error, RPC channel closed. {}".format(e)) + logger.error( + "GRPC error, RPC channel closed. {}".format(e)) logger.debug(e) else: self.send_status( @@ -641,7 +686,8 @@ def process_request(self): elif task_type == "validate": self.state = ClientState.validating - metrics = self._process_validation_request(request.model_id, False, request.session_id) + metrics = self._process_validation_request( + request.model_id, False, request.session_id) if metrics is not None: # Send validation @@ -657,7 +703,8 @@ def process_request(self): validation.session_id = request.session_id try: - _ = self.combinerStub.SendModelValidation(validation, metadata=self.metadata) + _ = self.combinerStub.SendModelValidation( + validation, metadata=self.metadata) status_type = fedn.StatusType.MODEL_VALIDATION self.send_status( @@ -665,14 +712,17 @@ def process_request(self): ) except grpc.RpcError as e: status_code = e.code() - logger.error("GRPC error, {}.".format(status_code.name)) + logger.error( + "GRPC error, {}.".format(status_code.name)) logger.debug(e) except ValueError as e: - logger.error("GRPC error, RPC channel closed. {}".format(e)) + logger.error( + "GRPC error, RPC channel closed. {}".format(e)) logger.debug(e) else: self.send_status( - "Client {} failed to complete model validation.".format(self.name), + "Client {} failed to complete model validation.".format( + self.name), log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id, @@ -683,20 +733,23 @@ def process_request(self): except queue.Empty: pass except grpc.RpcError as e: - logger.critical(f"GRPC process_request: An error occurred during process request: {e}") + logger.critical( + f"GRPC process_request: An error occurred during process request: {e}") def _send_heartbeat(self, update_frequency=2.0): """Send a heartbeat to the combiner. :param update_frequency: The frequency of the heartbeat in seconds. :type update_frequency: float - :return: None if the client is detached. + :return: None if the client is disconnected. :rtype: None """ while True: - heartbeat = fedn.Heartbeat(sender=fedn.Client(name=self.name, role=fedn.WORKER)) + heartbeat = fedn.Heartbeat(sender=fedn.Client( + name=self.name, role=fedn.WORKER)) try: - self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata) + self.connectorStub.SendHeartbeat( + heartbeat, metadata=self.metadata) self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() @@ -712,9 +765,11 @@ def _send_heartbeat(self, update_frequency=2.0): if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() if details == "Token expired": - logger.error("GRPC hearbeat: Token expired. Disconnecting.") + logger.error( + "GRPC hearbeat: Token expired. Disconnecting.") self.disconnect() - sys.exit("Unauthorized. Token expired. Please obtain a new token.") + sys.exit( + "Unauthorized. Token expired. Please obtain a new token.") logger.debug(e) time.sleep(update_frequency) @@ -751,13 +806,15 @@ def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, if request is not None: status.data = MessageToJson(request) - self.logs.append("{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) + self.logs.append("{} {} LOG LEVEL {} MESSAGE {}".format( + str(datetime.now()), status.sender.name, status.log_level, status.status)) try: _ = self.connectorStub.SendStatus(status, metadata=self.metadata) except grpc.RpcError as e: status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: - logger.warning("GRPC SendStatus: server unavailable during send status.") + logger.warning( + "GRPC SendStatus: server unavailable during send status.") if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() if details == "Token expired": @@ -771,12 +828,15 @@ def run(self): while True: time.sleep(1) if cnt == 0: - logger.info("Client is active, waiting for model update requests.") + logger.info( + "Client is active, waiting for model update requests.") cnt = 1 if self.state != old_state: - logger.info("Client in {} state.".format(ClientStateToString(self.state))) + logger.info("Client in {} state.".format( + ClientStateToString(self.state))) if not self._connected: - logger.warning("Client lost connection to combiner. Attempting to reconnect to FEDn network.") + logger.warning( + "Client lost connection to combiner. Attempting to reconnect to FEDn network.") combiner_config = self.assign() self.connect(combiner_config) self._subscribe_to_combiner(self.config) From c02b87316b9b2b79c9f354e317c49dd28ef50353 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Fri, 17 May 2024 12:43:40 +0000 Subject: [PATCH 2/2] format --- fedn/network/clients/client.py | 180 +++++++++++---------------------- 1 file changed, 60 insertions(+), 120 deletions(-) diff --git a/fedn/network/clients/client.py b/fedn/network/clients/client.py index 45d416445..df20a8956 100644 --- a/fedn/network/clients/client.py +++ b/fedn/network/clients/client.py @@ -22,13 +22,11 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_PACKAGE_EXTRACT_DIR -from fedn.common.log_config import (logger, set_log_level_from_string, - set_log_stream) +from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream from fedn.network.clients.connect import ConnectorClient, Status from fedn.network.clients.package import PackageRuntime from fedn.network.clients.state import ClientState, ClientStateToString -from fedn.network.combiner.modelservice import (get_tmp_path, - upload_request_generator) +from fedn.network.combiner.modelservice import get_tmp_path, upload_request_generator from fedn.utils.dispatcher import Dispatcher from fedn.utils.helpers.helpers import get_helper @@ -79,8 +77,7 @@ def __init__(self, config): # Validate client name match = re.search(VALID_NAME_REGEX, config["name"]) if not match: - raise ValueError( - "Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.") + raise ValueError("Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.") # Folder where the client will store downloaded compute package and logs self.name = config["name"] @@ -105,8 +102,7 @@ def __init__(self, config): self._initialize_helper(combiner_config) if not self.helper: - logger.warning( - "Failed to retrieve helper class settings: {}".format(combiner_config)) + logger.warning("Failed to retrieve helper class settings: {}".format(combiner_config)) self._subscribe_to_combiner(self.config) @@ -123,8 +119,7 @@ def assign(self): status, response = self.connector.assign() if status == Status.TryAgain: logger.warning(response) - logger.info( - "Assignment request failed. Retrying in 5 seconds.") + logger.info("Assignment request failed. Retrying in 5 seconds.") time.sleep(5) continue if status == Status.Assigned: @@ -138,8 +133,7 @@ def assign(self): sys.exit("Exiting: UnMatchedConfig") time.sleep(5) logger.info("Assignment successfully received.") - logger.info( - "Received combiner configuration: {}".format(combiner_config)) + logger.info("Received combiner configuration: {}".format(combiner_config)) return combiner_config def _add_grpc_metadata(self, key, value): @@ -158,8 +152,7 @@ def _add_grpc_metadata(self, key, value): for i, (k, v) in enumerate(self.metadata): if k == key: # Replace value - self.metadata = self.metadata[:i] + \ - ((key, value),) + self.metadata[i + 1:] + self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1 :] return # Set metadata using tuple concatenation @@ -200,26 +193,20 @@ def connect(self, combiner_config): host = combiner_config["fqdn"] # assuming https if fqdn is used port = 443 - logger.info( - f"Initiating connection to combiner host at: {host}:{port}") + logger.info(f"Initiating connection to combiner host at: {host}:{port}") if combiner_config["certificate"]: - logger.info( - "Utilizing CA certificate for GRPC channel authentication.") + logger.info("Utilizing CA certificate for GRPC channel authentication.") secure = True - cert = base64.b64decode( - combiner_config["certificate"]) # .decode('utf-8') + cert = base64.b64decode(combiner_config["certificate"]) # .decode('utf-8') credentials = grpc.ssl_channel_credentials(root_certificates=cert) - channel = grpc.secure_channel( - "{}:{}".format(host, str(port)), credentials) + channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) elif os.getenv("FEDN_GRPC_ROOT_CERT_PATH"): secure = True - logger.info( - "Using root certificate from environment variable for GRPC channel.") + logger.info("Using root certificate from environment variable for GRPC channel.") with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], "rb") as f: credentials = grpc.ssl_channel_credentials(f.read()) - channel = grpc.secure_channel( - "{}:{}".format(host, str(port)), credentials) + channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) elif self.config["secure"]: secure = True logger.info("Using CA certificate for GRPC channel.") @@ -229,11 +216,9 @@ def connect(self, combiner_config): if self.config["token"]: token = self.config["token"] auth_creds = grpc.metadata_call_credentials(GrpcAuth(token)) - channel = grpc.secure_channel("{}:{}".format( - host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) + channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) else: - channel = grpc.secure_channel( - "{}:{}".format(host, str(port)), credentials) + channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) else: logger.info("Using insecure GRPC channel.") if port == 443: @@ -246,11 +231,9 @@ def connect(self, combiner_config): self.combinerStub = rpc.CombinerStub(channel) self.modelStub = rpc.ModelServiceStub(channel) - logger.info("Successfully established {} connection to {}:{}".format( - "secure" if secure else "insecure", host, port)) + logger.info("Successfully established {} connection to {}:{}".format("secure" if secure else "insecure", host, port)) - logger.info("Using {} compute package.".format( - combiner_config["package"])) + logger.info("Using {} compute package.".format(combiner_config["package"])) self._connected = True @@ -283,12 +266,10 @@ def _subscribe_to_combiner(self, config): | client-combiner assignment behavior. """ # Start sending heartbeats to the combiner. - threading.Thread(target=self._send_heartbeat, kwargs={ - "update_frequency": config["heartbeat_interval"]}, daemon=True).start() + threading.Thread(target=self._send_heartbeat, kwargs={"update_frequency": config["heartbeat_interval"]}, daemon=True).start() # Start listening for combiner training and validation messages - threading.Thread(target=self._listen_to_task_stream, - daemon=True).start() + threading.Thread(target=self._listen_to_task_stream, daemon=True).start() self._connected = True # Start processing the client message inbox @@ -316,25 +297,21 @@ def _initialize_dispatcher(self, config): while tries > 0: retval = pr.download( - host=config["discover_host"], port=config["discover_port"], token=config[ - "token"], force_ssl=config["force_ssl"], secure=config["secure"] + host=config["discover_host"], port=config["discover_port"], token=config["token"], force_ssl=config["force_ssl"], secure=config["secure"] ) if retval: break time.sleep(60) - logger.warning( - "Compute package not available. Retrying in 60 seconds. {} attempts remaining.".format(tries)) + logger.warning("Compute package not available. Retrying in 60 seconds. {} attempts remaining.".format(tries)) tries -= 1 if retval: if "checksum" not in config: - logger.warning( - "Bypassing validation of package checksum. Ensure the package source is trusted.") + logger.warning("Bypassing validation of package checksum. Ensure the package source is trusted.") else: checks_out = pr.validate(config["checksum"]) if not checks_out: - logger.critical( - "Validation of local package failed. Client terminating.") + logger.critical("Validation of local package failed. Client terminating.") self.error_state = True return package_runpath = "" @@ -343,8 +320,7 @@ def _initialize_dispatcher(self, config): self.dispatcher = pr.dispatcher(package_runpath) try: - logger.info( - "Initiating Dispatcher with entrypoint set to: startup") + logger.info("Initiating Dispatcher with entrypoint set to: startup") activate_cmd = self.dispatcher._get_or_create_python_env() self.dispatcher.run_cmd("startup") except KeyError: @@ -369,8 +345,7 @@ def _initialize_dispatcher(self, config): # Get or create python environment activate_cmd = self.dispatcher._get_or_create_python_env() if activate_cmd: - logger.info( - "To activate the virtual environment, run: {}".format(activate_cmd)) + logger.info("To activate the virtual environment, run: {}".format(activate_cmd)) def get_model_from_combiner(self, id, timeout=20): """Fetch a model from the assigned combiner. @@ -403,8 +378,7 @@ def get_model_from_combiner(self, id, timeout=20): return None continue except grpc.RpcError as e: - logger.critical( - f"GRPC: An error occurred during model download: {e}") + logger.critical(f"GRPC: An error occurred during model download: {e}") return data @@ -430,11 +404,9 @@ def send_model_to_combiner(self, model, id): bt.seek(0, 0) try: - result = self.modelStub.Upload( - upload_request_generator(bt, id), metadata=self.metadata) + result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata) except grpc.RpcError as e: - logger.critical( - f"GRPC: An error occurred during model upload: {e}") + logger.critical(f"GRPC: An error occurred during model upload: {e}") return result @@ -454,8 +426,7 @@ def _listen_to_task_stream(self): try: for request in self.combinerStub.TaskStream(r, metadata=self.metadata): if request: - logger.debug( - "Received model update request from combiner: {}.".format(request)) + logger.debug("Received model update request from combiner: {}.".format(request)) if request.sender.role == fedn.COMBINER: # Process training request self.send_status( @@ -465,46 +436,39 @@ def _listen_to_task_stream(self): request=request, sesssion_id=request.session_id, ) - logger.info("Received model update request of type {} for model_id {}".format( - request.type, request.model_id)) + logger.info("Received model update request of type {} for model_id {}".format(request.type, request.model_id)) if request.type == fedn.StatusType.MODEL_UPDATE and self.config["trainer"]: self.inbox.put(("train", request)) elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]: self.inbox.put(("validate", request)) else: - logger.error( - "Unknown request type: {}".format(request.type)) + logger.error("Unknown request type: {}".format(request.type)) except grpc.RpcError as e: # Handle gRPC errors status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: - logger.warning( - "GRPC TaskStream: server unavailable during model update request stream. Retrying.") + logger.warning("GRPC TaskStream: server unavailable during model update request stream. Retrying.") # Retry after a delay time.sleep(5) if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() if details == "Token expired": - logger.warning( - "GRPC TaskStream: Token expired. Reconnecting.") + logger.warning("GRPC TaskStream: Token expired. Reconnecting.") self.disconnect() if status_code == grpc.StatusCode.CANCELLED: # Expected if the client is disconnected - logger.critical( - "GRPC TaskStream: Client disconnected from combiner. Trying to reconnect.") + logger.critical("GRPC TaskStream: Client disconnected from combiner. Trying to reconnect.") else: # Log the error and continue - logger.error( - f"GRPC TaskStream: An error occurred during model update request stream: {e}") + logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {e}") except Exception as ex: # Handle other exceptions - logger.error( - f"GRPC TaskStream: An error occurred during model update request stream: {ex}") + logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {ex}") # Detach if not attached if not self._connected: @@ -520,8 +484,7 @@ def _process_training_request(self, model_id: str, session_id: str = None): :return: The model id of the updated model, or None if the update failed. And a dict with metadata. :rtype: tuple """ - self.send_status("\t Starting processing of training request for model_id {}".format( - model_id), sesssion_id=session_id) + self.send_status("\t Starting processing of training request for model_id {}".format(model_id), sesssion_id=session_id) self.state = ClientState.training try: @@ -529,8 +492,7 @@ def _process_training_request(self, model_id: str, session_id: str = None): tic = time.time() mdl = self.get_model_from_combiner(str(model_id)) if mdl is None: - logger.error( - "Could not retrieve model from combiner. Aborting training request.") + logger.error("Could not retrieve model from combiner. Aborting training request.") return None, None meta["fetch_model"] = time.time() - tic @@ -567,8 +529,7 @@ def _process_training_request(self, model_id: str, session_id: str = None): os.unlink(outpath + "-metadata") except Exception as e: - logger.error( - "Could not process training request due to error: {}".format(e)) + logger.error("Could not process training request due to error: {}".format(e)) updated_model_id = None meta = {"status": "failed", "error": str(e)} @@ -594,14 +555,12 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session else: cmd = "validate" - self.send_status( - f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) + self.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) self.state = ClientState.validating try: model = self.get_model_from_combiner(str(model_id)) if model is None: - logger.error( - "Could not retrieve model from combiner. Aborting validation request.") + logger.error("Could not retrieve model from combiner. Aborting validation request.") return None inpath = self.helper.get_tmp_path() @@ -636,8 +595,7 @@ def process_request(self): if task_type == "train": tic = time.time() self.state = ClientState.training - model_id, meta = self._process_training_request( - request.model_id, session_id=request.session_id) + model_id, meta = self._process_training_request(request.model_id, session_id=request.session_id) if meta is not None: processing_time = time.time() - tic @@ -658,8 +616,7 @@ def process_request(self): update.meta = json.dumps(meta) try: - _ = self.combinerStub.SendModelUpdate( - update, metadata=self.metadata) + _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) self.send_status( "Model update completed.", log_level=fedn.Status.AUDIT, @@ -669,12 +626,10 @@ def process_request(self): ) except grpc.RpcError as e: status_code = e.code() - logger.error( - "GRPC error, {}.".format(status_code.name)) + logger.error("GRPC error, {}.".format(status_code.name)) logger.debug(e) except ValueError as e: - logger.error( - "GRPC error, RPC channel closed. {}".format(e)) + logger.error("GRPC error, RPC channel closed. {}".format(e)) logger.debug(e) else: self.send_status( @@ -686,8 +641,7 @@ def process_request(self): elif task_type == "validate": self.state = ClientState.validating - metrics = self._process_validation_request( - request.model_id, False, request.session_id) + metrics = self._process_validation_request(request.model_id, False, request.session_id) if metrics is not None: # Send validation @@ -703,8 +657,7 @@ def process_request(self): validation.session_id = request.session_id try: - _ = self.combinerStub.SendModelValidation( - validation, metadata=self.metadata) + _ = self.combinerStub.SendModelValidation(validation, metadata=self.metadata) status_type = fedn.StatusType.MODEL_VALIDATION self.send_status( @@ -712,17 +665,14 @@ def process_request(self): ) except grpc.RpcError as e: status_code = e.code() - logger.error( - "GRPC error, {}.".format(status_code.name)) + logger.error("GRPC error, {}.".format(status_code.name)) logger.debug(e) except ValueError as e: - logger.error( - "GRPC error, RPC channel closed. {}".format(e)) + logger.error("GRPC error, RPC channel closed. {}".format(e)) logger.debug(e) else: self.send_status( - "Client {} failed to complete model validation.".format( - self.name), + "Client {} failed to complete model validation.".format(self.name), log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id, @@ -733,8 +683,7 @@ def process_request(self): except queue.Empty: pass except grpc.RpcError as e: - logger.critical( - f"GRPC process_request: An error occurred during process request: {e}") + logger.critical(f"GRPC process_request: An error occurred during process request: {e}") def _send_heartbeat(self, update_frequency=2.0): """Send a heartbeat to the combiner. @@ -745,11 +694,9 @@ def _send_heartbeat(self, update_frequency=2.0): :rtype: None """ while True: - heartbeat = fedn.Heartbeat(sender=fedn.Client( - name=self.name, role=fedn.WORKER)) + heartbeat = fedn.Heartbeat(sender=fedn.Client(name=self.name, role=fedn.WORKER)) try: - self.connectorStub.SendHeartbeat( - heartbeat, metadata=self.metadata) + self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata) self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() @@ -765,11 +712,9 @@ def _send_heartbeat(self, update_frequency=2.0): if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() if details == "Token expired": - logger.error( - "GRPC hearbeat: Token expired. Disconnecting.") + logger.error("GRPC hearbeat: Token expired. Disconnecting.") self.disconnect() - sys.exit( - "Unauthorized. Token expired. Please obtain a new token.") + sys.exit("Unauthorized. Token expired. Please obtain a new token.") logger.debug(e) time.sleep(update_frequency) @@ -806,15 +751,13 @@ def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, if request is not None: status.data = MessageToJson(request) - self.logs.append("{} {} LOG LEVEL {} MESSAGE {}".format( - str(datetime.now()), status.sender.name, status.log_level, status.status)) + self.logs.append("{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) try: _ = self.connectorStub.SendStatus(status, metadata=self.metadata) except grpc.RpcError as e: status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: - logger.warning( - "GRPC SendStatus: server unavailable during send status.") + logger.warning("GRPC SendStatus: server unavailable during send status.") if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() if details == "Token expired": @@ -828,15 +771,12 @@ def run(self): while True: time.sleep(1) if cnt == 0: - logger.info( - "Client is active, waiting for model update requests.") + logger.info("Client is active, waiting for model update requests.") cnt = 1 if self.state != old_state: - logger.info("Client in {} state.".format( - ClientStateToString(self.state))) + logger.info("Client in {} state.".format(ClientStateToString(self.state))) if not self._connected: - logger.warning( - "Client lost connection to combiner. Attempting to reconnect to FEDn network.") + logger.warning("Client lost connection to combiner. Attempting to reconnect to FEDn network.") combiner_config = self.assign() self.connect(combiner_config) self._subscribe_to_combiner(self.config)