From 5d9a2e3ada729f6ce834ab4471bf30998c9ec388 Mon Sep 17 00:00:00 2001 From: Niklas Date: Wed, 15 May 2024 15:53:32 +0200 Subject: [PATCH] complement method added to set date and id if not passed --- .../statestore/stores/session_store.py | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/fedn/network/storage/statestore/stores/session_store.py b/fedn/network/storage/statestore/stores/session_store.py index d69d426bd..b25a34319 100644 --- a/fedn/network/storage/statestore/stores/session_store.py +++ b/fedn/network/storage/statestore/stores/session_store.py @@ -1,3 +1,5 @@ +import datetime +import uuid from typing import Any, Dict, List, Tuple import pymongo @@ -30,9 +32,6 @@ def __init__(self, database: Database, collection: str): super().__init__(database, collection) def _validate_session_config(self, session_config: dict) -> Tuple[bool, str]: - if "session_id" not in session_config or session_config["session_id"] == "": - return False, "session_config.session_id is required" - if "aggregator" not in session_config: return False, "session_config.aggregator is required" @@ -81,29 +80,21 @@ def _validate_session_config(self, session_config: dict) -> Tuple[bool, str]: return True, "" def _validate(self, item: Session) -> Tuple[bool, str]: - if "session_id" not in item or item["session_id"] == "": - return False, "session_id is required" - - if not isinstance(item["session_id"], str): - return False, "session_id must be a string" - if "session_config" not in item or item["session_config"] is None: return False, "session_config is required" elif not isinstance(item["session_config"], dict): return False, "session_config must be a dict" session_config = item["session_config"] - session_id = item["session_id"] - success, result = self._validate_session_config(session_config) + return self._validate_session_config(session_config) - if not success: - return False, result + def _complement(self, item: Session): + item["status"] = "Created" + item["committed_at"] = datetime.datetime.now() - if session_id != session_config["session_id"]: - return False, "session_id must match session_config.session_id" - - return True, "" + if "session_id" not in item or item["session_id"] == "" or not isinstance(item["session_id"], str): + item["session_id"] = str(uuid.uuid4()) def get(self, id: str, use_typing: bool = False) -> Session: @@ -140,6 +131,8 @@ def add(self, item: Session)-> Tuple[bool, Any]: if not valid: return False, message + self._complement(item) + return super().add(item) def delete(self, id: str) -> bool: