Skip to content

Commit

Permalink
Feature/SK-1462 | Add Combiner DTO (#832)
Browse files Browse the repository at this point in the history
  • Loading branch information
carl-andersson authored Mar 4, 2025
1 parent bf8f3c4 commit a6a1df0
Show file tree
Hide file tree
Showing 10 changed files with 228 additions and 296 deletions.
10 changes: 6 additions & 4 deletions fedn/network/api/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def get_combiners(self):
:return: list of combiners objects
:rtype: list(:class:`fedn.network.combiner.interfaces.CombinerInterface`)
"""
data = self.combiner_store.list(limit=0, skip=0, sort_key=None)
result = self.combiner_store.select(limit=0, skip=0, sort_key=None)
combiners = []
for c in data["result"]:
name = c["name"].upper()
for combiner in result:
name = combiner.name.upper()
# General certificate handling, same for all combiners.
if os.environ.get("FEDN_GRPC_CERT_PATH"):
with open(os.environ.get("FEDN_GRPC_CERT_PATH"), "rb") as f:
Expand All @@ -63,7 +63,9 @@ def get_combiners(self):
cert = f.read()
else:
cert = None
combiners.append(CombinerInterface(c["parent"], c["name"], c["address"], c["fqdn"], c["port"], certificate=cert, ip=c["ip"]))
combiners.append(
CombinerInterface(combiner.parent, combiner.name, combiner.address, combiner.fqdn, combiner.port, certificate=cert, ip=combiner.ip)
)

return combiners

Expand Down
14 changes: 9 additions & 5 deletions fedn/network/api/v1/combiner_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def get_combiners():

kwargs = request.args.to_dict()

response = combiner_store.list(limit, skip, sort_key, sort_order, **kwargs)
combiners = combiner_store.select(limit, skip, sort_key, sort_order, **kwargs)
count = combiner_store.count(**kwargs)
response = {"count": count, "result": [combiner.to_dict() for combiner in combiners]}

return jsonify(response), 200
except Exception as e:
Expand Down Expand Up @@ -184,7 +186,9 @@ def list_combiners():

kwargs = get_post_data_to_kwargs(request)

response = combiner_store.list(limit, skip, sort_key, sort_order, **kwargs)
combiners = combiner_store.select(limit, skip, sort_key, sort_order, **kwargs)
count = combiner_store.count(**kwargs)
response = {"count": count, "result": [combiner.to_dict() for combiner in combiners]}

return jsonify(response), 200
except Exception as e:
Expand Down Expand Up @@ -327,8 +331,8 @@ def get_combiner(id: str):
try:
response = combiner_store.get(id)
if response is None:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
return jsonify(response), 200
return jsonify({"message": f"Entity with id: {id} not found"}), 404
return jsonify(response.to_dict()), 200
except Exception as e:
logger.error(f"An unexpected error occurred: {e}")
return jsonify({"message": "An unexpected error occurred"}), 500
Expand Down Expand Up @@ -369,7 +373,7 @@ def delete_combiner(id: str):
try:
result: bool = combiner_store.delete(id)
if not result:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
return jsonify({"message": f"Entity with id: {id} not found"}), 404
msg = "Combiner deleted" if result else "Combiner not deleted"
return jsonify({"message": msg}), 200
except Exception as e:
Expand Down
23 changes: 11 additions & 12 deletions fedn/network/combiner/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from fedn.network.combiner.shared import client_store, combiner_store, prediction_store, repository, round_store, status_store, validation_store
from fedn.network.grpc.server import Server, ServerConfig
from fedn.network.storage.statestore.stores.dto import ClientDTO
from fedn.network.storage.statestore.stores.dto.combiner import CombinerDTO

VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$"

Expand Down Expand Up @@ -109,19 +110,17 @@ def __init__(self, config):

self.round_store = round_store

# Add combiner to statestore
interface_config = {
"port": config["port"],
"fqdn": config["fqdn"],
"name": config["name"],
"address": config["host"],
"parent": "localhost",
"ip": "",
"updated_at": str(datetime.now()),
}
# Check if combiner already exists in statestore
if combiner_store.get(config["name"]) is None:
combiner_store.add(interface_config)
if combiner_store.get_by_name(config["name"]) is None:
new_combiner = CombinerDTO()
new_combiner.port = config["port"]
new_combiner.fqdn = config["fqdn"]
new_combiner.name = config["name"]
new_combiner.address = config["host"]
new_combiner.parent = "localhost"
new_combiner.ip = ""
new_combiner.updated_at = str(datetime.now())
combiner_store.add(new_combiner)

# Fetch all clients previously connected to the combiner
# If a client and a combiner goes down at the same time,
Expand Down
254 changes: 72 additions & 182 deletions fedn/network/storage/statestore/stores/combiner_store.py
Original file line number Diff line number Diff line change
@@ -1,217 +1,107 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from abc import abstractmethod
from typing import Dict, List

import pymongo
from bson import ObjectId
from pymongo.database import Database
from sqlalchemy import String, func, or_, select
from sqlalchemy.orm import Mapped, mapped_column

from fedn.network.storage.statestore.stores.sql.shared import MyAbstractBase
from fedn.network.storage.statestore.stores.store import MongoDBStore, SQLStore, Store

from .shared import from_document


class Combiner:
def __init__(
self,
id: str,
name: str,
address: str,
certificate: str,
config: dict,
fqdn: str,
ip: str,
key: str,
parent: dict,
port: int,
status: str,
updated_at: str,
):
self.id = id
self.name = name
self.address = address
self.certificate = certificate
self.config = config
self.fqdn = fqdn
self.ip = ip
self.key = key
self.parent = parent
self.port = port
self.status = status
self.updated_at = updated_at


class CombinerStore(Store[Combiner]):
pass


class MongoDBCombinerStore(MongoDBStore[Combiner]):
def __init__(self, database: Database, collection: str):
super().__init__(database, collection)

def get(self, id: str) -> Combiner:
"""Get an entity by id
param id: The id of the entity
type: str
description: The id of the entity, can be either the id or the name (property)
return: The entity
"""
if ObjectId.is_valid(id):
id_obj = ObjectId(id)
document = self.database[self.collection].find_one({"_id": id_obj})
else:
document = self.database[self.collection].find_one({"name": id})

if document is None:
return None
from fedn.network.storage.statestore.stores.dto import CombinerDTO
from fedn.network.storage.statestore.stores.new_store import MongoDBStore, SQLStore, Store, from_document
from fedn.network.storage.statestore.stores.sql.shared import CombinerModel, from_orm_model

return from_document(document)

def update(self, id: str, item: Combiner) -> bool:
raise NotImplementedError("Update not implemented for CombinerStore")
class CombinerStore(Store[CombinerDTO]):
@abstractmethod
def get_by_name(name: str) -> CombinerDTO:
pass

def add(self, item: Combiner) -> Tuple[bool, Any]:
return super().add(item)

def delete(self, id: str) -> bool:
if ObjectId.is_valid(id):
kwargs = {"_id": ObjectId(id)}
else:
return False
class MongoDBCombinerStore(CombinerStore, MongoDBStore):
def __init__(self, database: Database, collection: str):
super().__init__(database, collection, "id")

document = self.database[self.collection].find_one(kwargs)
def get(self, id: str) -> CombinerDTO:
obj = self.mongo_get(id)
if obj is None:
return None
return self._dto_from_document(obj)

if document is None:
return False

return super().delete(document["_id"])

def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Combiner]]:
"""List entities
param limit: The maximum number of entities to return
type: int
param skip: The number of entities to skip
type: int
param sort_key: The key to sort by
type: str
param sort_order: The order to sort by
type: pymongo.DESCENDING | pymongo.ASCENDING
param kwargs: Additional query parameters
type: dict
example: {"key": "models"}
return: A dictionary with the count and the result
"""
response = super().list(limit, skip, sort_key or "updated_at", sort_order, **kwargs)

return response
def update(self, item: CombinerDTO):
raise NotImplementedError("Update not implemented for CombinerStore")

def count(self, **kwargs) -> int:
return super().count(**kwargs)
def add(self, item: CombinerDTO):
item_dict = item.to_db(exclude_unset=False)
success, obj = self.mongo_add(item_dict)
if success:
return success, self._dto_from_document(obj)
return success, obj

def delete(self, id: str) -> bool:
return self.mongo_delete(id)

class CombinerModel(MyAbstractBase):
__tablename__ = "combiners"
def select(self, limit: int = 0, skip: int = 0, sort_key: str = None, sort_order=pymongo.DESCENDING, **filter_kwargs) -> List[CombinerDTO]:
entities = self.mongo_select(limit, skip, sort_key, sort_order, **filter_kwargs)
result = []
for entity in entities:
result.append(self._dto_from_document(entity))
return result

address: Mapped[str] = mapped_column(String(255))
fqdn: Mapped[Optional[str]] = mapped_column(String(255))
ip: Mapped[Optional[str]] = mapped_column(String(255))
name: Mapped[str] = mapped_column(String(255))
parent: Mapped[Optional[str]] = mapped_column(String(255))
port: Mapped[int]
updated_at: Mapped[datetime] = mapped_column(default=datetime.now())
def count(self, **kwargs) -> int:
return self.mongo_count(**kwargs)

def get_by_name(self, name: str) -> CombinerDTO:
document = self.database[self.collection].find_one({"name": name})
if document is None:
return None
return self._dto_from_document(document)

def from_row(row: CombinerModel) -> Combiner:
return {
"id": row.id,
"committed_at": row.committed_at,
"address": row.address,
"ip": row.ip,
"name": row.name,
"parent": row.parent,
"fqdn": row.fqdn,
"port": row.port,
"updated_at": row.updated_at,
}
def _dto_from_document(self, document: Dict) -> CombinerDTO:
return CombinerDTO().populate_with(from_document(document))


class SQLCombinerStore(CombinerStore, SQLStore[Combiner]):
class SQLCombinerStore(CombinerStore, SQLStore[CombinerDTO]):
def __init__(self, Session):
super().__init__(Session)
super().__init__(Session, CombinerModel)

def get(self, id: str) -> Combiner:
def get(self, id: str) -> CombinerDTO:
with self.Session() as session:
stmt = select(CombinerModel).where(or_(CombinerModel.id == id, CombinerModel.name == id))
item = session.scalars(stmt).first()
if item is None:
entity = self.sql_get(session, id)
if entity is None:
return None
return from_row(item)
return self._dto_from_orm_model(entity)

def update(self, id, item):
def update(self, item):
raise NotImplementedError

def add(self, item):
with self.Session() as session:
entity = CombinerModel(
address=item["address"],
fqdn=item["fqdn"],
ip=item["ip"],
name=item["name"],
parent=item["parent"],
port=item["port"],
)
session.add(entity)
session.commit()
return True, from_row(entity)
item_dict = item.to_db(exclude_unset=False)
item_dict = self._to_orm_dict(item_dict)
entity = CombinerModel(**item_dict)
success, obj = self.sql_add(session, entity)
if success:
return success, self._dto_from_orm_model(obj)
return success, obj

def delete(self, id: str) -> bool:
with self.Session() as session:
stmt = select(CombinerModel).where(CombinerModel.id == id)
item = session.scalars(stmt).first()
if item is None:
return False
session.delete(item)
return True

def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs):
with self.Session() as session:
stmt = select(CombinerModel)

for key, value in kwargs.items():
stmt = stmt.where(getattr(CombinerModel, key) == value)

_sort_order: str = "DESC" if sort_order == pymongo.DESCENDING else "ASC"
_sort_key: str = sort_key or "committed_at"

if _sort_key in CombinerModel.__table__.columns:
sort_obj = CombinerModel.__table__.columns.get(_sort_key) if _sort_order == "ASC" else CombinerModel.__table__.columns.get(_sort_key).desc()

stmt = stmt.order_by(sort_obj)

if limit:
stmt = stmt.offset(skip or 0).limit(limit)
elif skip:
stmt = stmt.offset(skip)
return self.sql_delete(id)

items = session.scalars(stmt).all()

result = []
for i in items:
result.append(from_row(i))

count = session.scalar(select(func.count()).select_from(CombinerModel))

return {"count": count, "result": result}
def select(self, limit=0, skip=0, sort_key=None, sort_order=pymongo.DESCENDING, **kwargs):
with self.Session() as session:
entities = self.sql_select(session, limit, skip, sort_key, sort_order, **kwargs)
return [self._dto_from_orm_model(item) for item in entities]

def count(self, **kwargs):
with self.Session() as session:
stmt = select(func.count()).select_from(CombinerModel)
return self.sql_count(**kwargs)

for key, value in kwargs.items():
stmt = stmt.where(getattr(CombinerModel, key) == value)
def get_by_name(self, name: str) -> CombinerDTO:
with self.Session() as session:
entity = session.query(CombinerModel).filter(CombinerModel.name == name).first()
if entity is None:
return None
return self._dto_from_orm_model(entity)

count = session.scalar(stmt)
def _to_orm_dict(self, item_dict: Dict) -> Dict:
return item_dict

return count
def _dto_from_orm_model(self, item: CombinerModel) -> CombinerDTO:
return CombinerDTO().populate_with(from_orm_model(item, CombinerModel))
Loading

0 comments on commit a6a1df0

Please sign in to comment.