Skip to content

Commit

Permalink
complement method added to set date and id if not passed
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman committed May 15, 2024
1 parent 79175d6 commit 5d9a2e3
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions fedn/network/storage/statestore/stores/session_store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import datetime
import uuid
from typing import Any, Dict, List, Tuple

import pymongo
Expand Down Expand Up @@ -30,9 +32,6 @@ def __init__(self, database: Database, collection: str):
super().__init__(database, collection)

def _validate_session_config(self, session_config: dict) -> Tuple[bool, str]:
if "session_id" not in session_config or session_config["session_id"] == "":
return False, "session_config.session_id is required"

if "aggregator" not in session_config:
return False, "session_config.aggregator is required"

Expand Down Expand Up @@ -81,29 +80,21 @@ def _validate_session_config(self, session_config: dict) -> Tuple[bool, str]:
return True, ""

def _validate(self, item: Session) -> Tuple[bool, str]:
if "session_id" not in item or item["session_id"] == "":
return False, "session_id is required"

if not isinstance(item["session_id"], str):
return False, "session_id must be a string"

if "session_config" not in item or item["session_config"] is None:
return False, "session_config is required"
elif not isinstance(item["session_config"], dict):
return False, "session_config must be a dict"

session_config = item["session_config"]
session_id = item["session_id"]

success, result = self._validate_session_config(session_config)
return self._validate_session_config(session_config)

if not success:
return False, result
def _complement(self, item: Session):
item["status"] = "Created"
item["committed_at"] = datetime.datetime.now()

if session_id != session_config["session_id"]:
return False, "session_id must match session_config.session_id"

return True, ""
if "session_id" not in item or item["session_id"] == "" or not isinstance(item["session_id"], str):
item["session_id"] = str(uuid.uuid4())


def get(self, id: str, use_typing: bool = False) -> Session:
Expand Down Expand Up @@ -140,6 +131,8 @@ def add(self, item: Session)-> Tuple[bool, Any]:
if not valid:
return False, message

self._complement(item)

return super().add(item)

def delete(self, id: str) -> bool:
Expand Down

0 comments on commit 5d9a2e3

Please sign in to comment.