Skip to content

Commit

Permalink
handle refresh token
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Mar 19, 2024
1 parent 6a40493 commit fe9029d
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 26 deletions.
2 changes: 2 additions & 0 deletions fedn/fedn/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
FEDN_AUTH_WHITELIST_URL_PREFIX = os.environ.get('FEDN_AUTH_WHITELIST_URL_PREFIX', False)
# Algorithm passed to jwt.decode when validating tokens.
FEDN_JWT_ALGORITHM = os.environ.get('FEDN_JWT_ALGORITHM', 'HS256')
# Scheme prefix placed before the token in the 'authorization' header/metadata.
FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Token')
# URI the client POSTs to when its access token has expired.
# Defaults to False (unset), in which case token refresh is not attempted.
FEDN_AUTH_REFRESH_TOKEN_URI = os.environ.get('FEDN_AUTH_REFRESH_TOKEN_URI', False)
# Refresh token sent as the 'refresh' field of the refresh request body.
# Defaults to False (unset), in which case token refresh is not attempted.
FEDN_AUTH_REFRESH_TOKEN = os.environ.get('FEDN_AUTH_REFRESH_TOKEN', False)
FEDN_CUSTOM_URL_PREFIX = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '')


Expand Down
84 changes: 60 additions & 24 deletions fedn/fedn/network/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.config import FEDN_AUTH_SCHEME
from fedn.common.log_config import (logger, set_log_level_from_string,
set_log_stream)
from fedn.network.clients.connect import ConnectorClient, Status
Expand Down Expand Up @@ -112,7 +113,7 @@ def _assign(self):
while True:
status, response = self.connector.assign()
if status == Status.TryAgain:
logger.info(response)
logger.info("Assignment request failed. Retrying in 5 seconds.")
time.sleep(5)
continue
if status == Status.Assigned:
Expand All @@ -125,7 +126,10 @@ def _assign(self):
logger.critical(response)
sys.exit("Exiting: UnMatchedConfig")
time.sleep(5)

# If token was refreshed, update the config
if self.config['token'] != self.connector.token:
self.config['token'] = self.connector.token
self._add_grpc_metadata('authorization', f"{FEDN_AUTH_SCHEME} {self.config['token']}")
logger.info("Assignment successfully received.")
logger.info("Received combiner configuration: {}".format(client_config))
return client_config
Expand Down Expand Up @@ -178,10 +182,9 @@ def _connect(self, client_config):
host = client_config['host']
# Add host to gRPC metadata
self._add_grpc_metadata('grpc-server', host)
auth_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Token')
if self.config['token']:
self._add_grpc_metadata('authorization', f"{auth_scheme} {self.config['token']}")
logger.info("Client using metadata: {}.".format(self.metadata))
self._add_grpc_metadata('authorization', f"{FEDN_AUTH_SCHEME} {self.config['token']}")
logger.debug("Client using metadata: {}.".format(self.metadata))
port = client_config['port']
secure = False
if client_config['fqdn'] is not None:
Expand Down Expand Up @@ -373,21 +376,24 @@ def get_model_from_combiner(self, id, timeout=20):
request.sender.name = self.name
request.sender.role = fedn.WORKER

for part in self.modelStub.Download(request, metadata=self.metadata):

if part.status == fedn.ModelStatus.IN_PROGRESS:
data.write(part.data)
try:
for part in self.modelStub.Download(request, metadata=self.metadata):

if part.status == fedn.ModelStatus.OK:
return data
if part.status == fedn.ModelStatus.IN_PROGRESS:
data.write(part.data)

if part.status == fedn.ModelStatus.FAILED:
return None
if part.status == fedn.ModelStatus.OK:
return data

if part.status == fedn.ModelStatus.UNKNOWN:
if time.time() - time_start >= timeout:
if part.status == fedn.ModelStatus.FAILED:
return None
continue

if part.status == fedn.ModelStatus.UNKNOWN:
if time.time() - time_start >= timeout:
return None
continue
except grpc.RpcError as e:
logger.critical(f"GRPC: An error occurred during model download: {e}")

return data

Expand All @@ -412,7 +418,10 @@ def send_model_to_combiner(self, model, id):

bt.seek(0, 0)

result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata)
try:
result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata)
except grpc.RpcError as e:
logger.critical(f"GRPC: An error occurred during model upload: {e}")

return result

Expand Down Expand Up @@ -451,16 +460,26 @@ def _listen_to_task_stream(self):
# Handle gRPC errors
status_code = e.code()
if status_code == grpc.StatusCode.UNAVAILABLE:
logger.warning("GRPC server unavailable during model update request stream. Retrying.")
logger.warning("GRPC TaskStream: server unavailable during model update request stream. Retrying.")
# Retry after a delay
time.sleep(5)
if status_code == grpc.StatusCode.UNAUTHENTICATED:
details = e.details()
if details == 'Token expired':
logger.warning("GRPC TaskStream: Token expired. Reconnecting.")
self.detach()

if status_code == grpc.StatusCode.CANCELLED:
# Expected if the client is detached
logger.critical("GRPC TaskStream: Client detached from combiner. Atempting to reconnect.")

else:
# Log the error and continue
logger.error(f"An error occurred during model update request stream: {e}")
logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {e}")

except Exception as ex:
# Handle other exceptions
logger.error(f"An error occurred during model update request stream: {ex}")
logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {ex}")

# Detach if not attached
if not self._attached:
Expand Down Expand Up @@ -657,12 +676,15 @@ def process_request(self):
self.inbox.task_done()
except queue.Empty:
pass
except grpc.RpcError as e:
status_code = e.code()
logger.critical(f"GRPC process_request: An error occurred during process request: {e}")

def _handle_combiner_failure(self):
    """ Register a failed combiner connection.

    Increments the missed-heartbeat counter and detaches the client once the
    counter exceeds ``config['reconnect_after_missed_heartbeat']``.
    """
    self._missed_heartbeat += 1
    if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']:
        # Call detach() exactly once; the previous `self.detach()()` also tried
        # to call detach's return value.
        self.detach()

def _send_heartbeat(self, update_frequency=2.0):
"""Send a heartbeat to the combiner.
Expand All @@ -680,8 +702,13 @@ def _send_heartbeat(self, update_frequency=2.0):
self._missed_heartbeat = 0
except grpc.RpcError as e:
status_code = e.code()
logger.warning("Client heartbeat: GRPC error, {}. Retrying.".format(
status_code.name))
if status_code == grpc.StatusCode.UNAVAILABLE:
logger.warning("GRPC hearbeat: server unavailable during send heartbeat. Retrying.")
if status_code == grpc.StatusCode.UNAUTHENTICATED:
details = e.details()
if details == 'Token expired':
logger.warning("GRPC hearbeat: Token expired. Reconnecting.")
self.detach()
logger.debug(e)
self._handle_combiner_failure()

Expand Down Expand Up @@ -717,7 +744,16 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None,
self.logs.append(
"{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level,
status.status))
_ = self.connectorStub.SendStatus(status, metadata=self.metadata)
try:
_ = self.connectorStub.SendStatus(status, metadata=self.metadata)
except grpc.RpcError as e:
status_code = e.code()
if status_code == grpc.StatusCode.UNAVAILABLE:
logger.warning("GRPC SendStatus: server unavailable during send status.")
if status_code == grpc.StatusCode.UNAUTHENTICATED:
details = e.details()
if details == 'Token expired':
logger.warning("GRPC SendStatus: Token expired.")

def run(self):
""" Run the client. """
Expand Down
32 changes: 31 additions & 1 deletion fedn/fedn/network/clients/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import requests

from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_CUSTOM_URL_PREFIX
from fedn.common.config import (FEDN_AUTH_REFRESH_TOKEN,
FEDN_AUTH_REFRESH_TOKEN_URI, FEDN_AUTH_SCHEME,
FEDN_CUSTOM_URL_PREFIX)
from fedn.common.log_config import logger


Expand Down Expand Up @@ -94,6 +96,16 @@ def assign(self):
return Status.UnMatchedConfig, reason

if retval.status_code == 401:
if 'message' in retval.json():
reason = retval.json()['message']
logger.warning(reason)
if reason == 'Token expired':
status_code = self.refresh_token()
if status_code >= 200 and status_code < 204:
logger.info("Token refreshed.")
return Status.TryAgain, reason
else:
return Status.UnAuthorized, "Could not refresh token"
reason = "Unauthorized connection to reducer, make sure the correct token is set"
return Status.UnAuthorized, reason

Expand All @@ -116,3 +128,21 @@ def assign(self):
return Status.Assigned, retval.json()

return Status.Unassigned, None

def refresh_token(self):
    """
    Refresh the client access token.

    Posts FEDN_AUTH_REFRESH_TOKEN to FEDN_AUTH_REFRESH_TOKEN_URI and, when the
    request succeeds, stores the returned access token in ``self.token``.
    ``self.token`` is left untouched if the refresh fails.

    :return: HTTP status code of the refresh request. 401 is returned when no
        refresh URI/token is configured or when a successful response does not
        contain an access token.
    :rtype: int
    """
    if not FEDN_AUTH_REFRESH_TOKEN_URI or not FEDN_AUTH_REFRESH_TOKEN:
        logger.error("No refresh token URI/Token set, cannot refresh token.")
        return 401

    payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI,
                            verify=self.verify,
                            allow_redirects=True,
                            json={'refresh': FEDN_AUTH_REFRESH_TOKEN})

    # Same success window that assign() applies to the returned status code.
    # An error response is not expected to carry an 'access' field, so report
    # the status instead of trying to parse the body.
    if not 200 <= payload.status_code < 204:
        logger.error("Token refresh request failed with status code {}.".format(payload.status_code))
        return payload.status_code

    try:
        access_token = payload.json()['access']
    except (ValueError, KeyError, TypeError):
        # Body was not JSON, or was JSON without an 'access' entry.
        logger.error("Token refresh response did not contain an access token.")
        return 401

    self.token = access_token
    return payload.status_code
3 changes: 2 additions & 1 deletion fedn/fedn/network/grpc/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def intercept_service(self, continuation, handler_call_details):
return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}')

token = token.split(' ')[1]
print(f"HANDLER: {handler_call_details}", flush=True)

try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM])
Expand All @@ -87,6 +86,8 @@ def intercept_service(self, continuation, handler_call_details):
return continuation(handler_call_details)
except jwt.InvalidTokenError:
return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Invalid token')
except jwt.ExpiredSignatureError:
return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token expired')
except Exception as e:
logger.error(str(e))
return _unary_unary_rpc_terminator(grpc.StatusCode.UNKNOWN, str(e))

0 comments on commit fe9029d

Please sign in to comment.