diff --git a/fedn/network/api/network.py b/fedn/network/api/network.py index bd350703f..4f2ea835e 100644 --- a/fedn/network/api/network.py +++ b/fedn/network/api/network.py @@ -48,10 +48,10 @@ def get_combiners(self): :return: list of combiners objects :rtype: list(:class:`fedn.network.combiner.interfaces.CombinerInterface`) """ - data = self.combiner_store.list(limit=0, skip=0, sort_key=None) + result = self.combiner_store.select(limit=0, skip=0, sort_key=None) combiners = [] - for c in data["result"]: - name = c["name"].upper() + for combiner in result: + name = combiner.name.upper() # General certificate handling, same for all combiners. if os.environ.get("FEDN_GRPC_CERT_PATH"): with open(os.environ.get("FEDN_GRPC_CERT_PATH"), "rb") as f: @@ -63,7 +63,9 @@ def get_combiners(self): cert = f.read() else: cert = None - combiners.append(CombinerInterface(c["parent"], c["name"], c["address"], c["fqdn"], c["port"], certificate=cert, ip=c["ip"])) + combiners.append( + CombinerInterface(combiner.parent, combiner.name, combiner.address, combiner.fqdn, combiner.port, certificate=cert, ip=combiner.ip) + ) return combiners diff --git a/fedn/network/api/v1/combiner_routes.py b/fedn/network/api/v1/combiner_routes.py index 4f64b59db..e79bc11ef 100644 --- a/fedn/network/api/v1/combiner_routes.py +++ b/fedn/network/api/v1/combiner_routes.py @@ -104,7 +104,9 @@ def get_combiners(): kwargs = request.args.to_dict() - response = combiner_store.list(limit, skip, sort_key, sort_order, **kwargs) + combiners = combiner_store.select(limit, skip, sort_key, sort_order, **kwargs) + count = combiner_store.count(**kwargs) + response = {"count": count, "result": [combiner.to_dict() for combiner in combiners]} return jsonify(response), 200 except Exception as e: @@ -184,7 +186,9 @@ def list_combiners(): kwargs = get_post_data_to_kwargs(request) - response = combiner_store.list(limit, skip, sort_key, sort_order, **kwargs) + combiners = combiner_store.select(limit, skip, sort_key, sort_order, 
**kwargs) + count = combiner_store.count(**kwargs) + response = {"count": count, "result": [combiner.to_dict() for combiner in combiners]} return jsonify(response), 200 except Exception as e: @@ -327,8 +331,8 @@ def get_combiner(id: str): try: response = combiner_store.get(id) if response is None: - return jsonify({"message": f"Entity with id: {id} not found"}), 404 - return jsonify(response), 200 + return jsonify({"message": f"Entity with id: {id} not found"}), 404 + return jsonify(response.to_dict()), 200 except Exception as e: logger.error(f"An unexpected error occurred: {e}") return jsonify({"message": "An unexpected error occurred"}), 500 @@ -369,7 +373,7 @@ def delete_combiner(id: str): try: result: bool = combiner_store.delete(id) if not result: - return jsonify({"message": f"Entity with id: {id} not found"}), 404 + return jsonify({"message": f"Entity with id: {id} not found"}), 404 msg = "Combiner deleted" if result else "Combiner not deleted" return jsonify({"message": msg}), 200 except Exception as e: diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index f964f4609..8898ece05 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -20,6 +20,7 @@ from fedn.network.combiner.shared import client_store, combiner_store, prediction_store, repository, round_store, status_store, validation_store from fedn.network.grpc.server import Server, ServerConfig from fedn.network.storage.statestore.stores.dto import ClientDTO +from fedn.network.storage.statestore.stores.dto.combiner import CombinerDTO VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$" @@ -109,19 +110,17 @@ def __init__(self, config): self.round_store = round_store - # Add combiner to statestore - interface_config = { - "port": config["port"], - "fqdn": config["fqdn"], - "name": config["name"], - "address": config["host"], - "parent": "localhost", - "ip": "", - "updated_at": str(datetime.now()), - } # Check if combiner already exists in statestore - if 
combiner_store.get(config["name"]) is None: - combiner_store.add(interface_config) + if combiner_store.get_by_name(config["name"]) is None: + new_combiner = CombinerDTO() + new_combiner.port = config["port"] + new_combiner.fqdn = config["fqdn"] + new_combiner.name = config["name"] + new_combiner.address = config["host"] + new_combiner.parent = "localhost" + new_combiner.ip = "" + new_combiner.updated_at = str(datetime.now()) + combiner_store.add(new_combiner) # Fetch all clients previously connected to the combiner # If a client and a combiner goes down at the same time, diff --git a/fedn/network/storage/statestore/stores/combiner_store.py b/fedn/network/storage/statestore/stores/combiner_store.py index 7719868ec..cfa1a7564 100644 --- a/fedn/network/storage/statestore/stores/combiner_store.py +++ b/fedn/network/storage/statestore/stores/combiner_store.py @@ -1,217 +1,107 @@ -from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple +from abc import abstractmethod +from typing import Dict, List import pymongo -from bson import ObjectId from pymongo.database import Database -from sqlalchemy import String, func, or_, select -from sqlalchemy.orm import Mapped, mapped_column - -from fedn.network.storage.statestore.stores.sql.shared import MyAbstractBase -from fedn.network.storage.statestore.stores.store import MongoDBStore, SQLStore, Store - -from .shared import from_document - - -class Combiner: - def __init__( - self, - id: str, - name: str, - address: str, - certificate: str, - config: dict, - fqdn: str, - ip: str, - key: str, - parent: dict, - port: int, - status: str, - updated_at: str, - ): - self.id = id - self.name = name - self.address = address - self.certificate = certificate - self.config = config - self.fqdn = fqdn - self.ip = ip - self.key = key - self.parent = parent - self.port = port - self.status = status - self.updated_at = updated_at - - -class CombinerStore(Store[Combiner]): - pass - - -class 
MongoDBCombinerStore(MongoDBStore[Combiner]): - def __init__(self, database: Database, collection: str): - super().__init__(database, collection) - - def get(self, id: str) -> Combiner: - """Get an entity by id - param id: The id of the entity - type: str - description: The id of the entity, can be either the id or the name (property) - return: The entity - """ - if ObjectId.is_valid(id): - id_obj = ObjectId(id) - document = self.database[self.collection].find_one({"_id": id_obj}) - else: - document = self.database[self.collection].find_one({"name": id}) - if document is None: - return None +from fedn.network.storage.statestore.stores.dto import CombinerDTO +from fedn.network.storage.statestore.stores.new_store import MongoDBStore, SQLStore, Store, from_document +from fedn.network.storage.statestore.stores.sql.shared import CombinerModel, from_orm_model - return from_document(document) - def update(self, id: str, item: Combiner) -> bool: - raise NotImplementedError("Update not implemented for CombinerStore") +class CombinerStore(Store[CombinerDTO]): + @abstractmethod + def get_by_name(self, name: str) -> CombinerDTO: + pass - def add(self, item: Combiner) -> Tuple[bool, Any]: - return super().add(item) - def delete(self, id: str) -> bool: - if ObjectId.is_valid(id): - kwargs = {"_id": ObjectId(id)} - else: - return False +class MongoDBCombinerStore(CombinerStore, MongoDBStore): + def __init__(self, database: Database, collection: str): + super().__init__(database, collection, "id") - document = self.database[self.collection].find_one(kwargs) + def get(self, id: str) -> CombinerDTO: + obj = self.mongo_get(id) + if obj is None: + return None + return self._dto_from_document(obj) - if document is None: - return False - - return super().delete(document["_id"]) - - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Combiner]]: - """List entities - param limit: The maximum number of entities to return - type: int - 
param skip: The number of entities to skip - type: int - param sort_key: The key to sort by - type: str - param sort_order: The order to sort by - type: pymongo.DESCENDING | pymongo.ASCENDING - param kwargs: Additional query parameters - type: dict - example: {"key": "models"} - return: A dictionary with the count and the result - """ - response = super().list(limit, skip, sort_key or "updated_at", sort_order, **kwargs) - - return response + def update(self, item: CombinerDTO): + raise NotImplementedError("Update not implemented for CombinerStore") - def count(self, **kwargs) -> int: - return super().count(**kwargs) + def add(self, item: CombinerDTO): + item_dict = item.to_db(exclude_unset=False) + success, obj = self.mongo_add(item_dict) + if success: + return success, self._dto_from_document(obj) + return success, obj + def delete(self, id: str) -> bool: + return self.mongo_delete(id) -class CombinerModel(MyAbstractBase): - __tablename__ = "combiners" + def select(self, limit: int = 0, skip: int = 0, sort_key: str = None, sort_order=pymongo.DESCENDING, **filter_kwargs) -> List[CombinerDTO]: + entities = self.mongo_select(limit, skip, sort_key, sort_order, **filter_kwargs) + result = [] + for entity in entities: + result.append(self._dto_from_document(entity)) + return result - address: Mapped[str] = mapped_column(String(255)) - fqdn: Mapped[Optional[str]] = mapped_column(String(255)) - ip: Mapped[Optional[str]] = mapped_column(String(255)) - name: Mapped[str] = mapped_column(String(255)) - parent: Mapped[Optional[str]] = mapped_column(String(255)) - port: Mapped[int] - updated_at: Mapped[datetime] = mapped_column(default=datetime.now()) + def count(self, **kwargs) -> int: + return self.mongo_count(**kwargs) + def get_by_name(self, name: str) -> CombinerDTO: + document = self.database[self.collection].find_one({"name": name}) + if document is None: + return None + return self._dto_from_document(document) -def from_row(row: CombinerModel) -> Combiner: - return { - 
"id": row.id, - "committed_at": row.committed_at, - "address": row.address, - "ip": row.ip, - "name": row.name, - "parent": row.parent, - "fqdn": row.fqdn, - "port": row.port, - "updated_at": row.updated_at, - } + def _dto_from_document(self, document: Dict) -> CombinerDTO: + return CombinerDTO().populate_with(from_document(document)) -class SQLCombinerStore(CombinerStore, SQLStore[Combiner]): +class SQLCombinerStore(CombinerStore, SQLStore[CombinerDTO]): def __init__(self, Session): - super().__init__(Session) + super().__init__(Session, CombinerModel) - def get(self, id: str) -> Combiner: + def get(self, id: str) -> CombinerDTO: with self.Session() as session: - stmt = select(CombinerModel).where(or_(CombinerModel.id == id, CombinerModel.name == id)) - item = session.scalars(stmt).first() - if item is None: + entity = self.sql_get(session, id) + if entity is None: return None - return from_row(item) + return self._dto_from_orm_model(entity) - def update(self, id, item): + def update(self, item): raise NotImplementedError def add(self, item): with self.Session() as session: - entity = CombinerModel( - address=item["address"], - fqdn=item["fqdn"], - ip=item["ip"], - name=item["name"], - parent=item["parent"], - port=item["port"], - ) - session.add(entity) - session.commit() - return True, from_row(entity) + item_dict = item.to_db(exclude_unset=False) + item_dict = self._to_orm_dict(item_dict) + entity = CombinerModel(**item_dict) + success, obj = self.sql_add(session, entity) + if success: + return success, self._dto_from_orm_model(obj) + return success, obj def delete(self, id: str) -> bool: - with self.Session() as session: - stmt = select(CombinerModel).where(CombinerModel.id == id) - item = session.scalars(stmt).first() - if item is None: - return False - session.delete(item) - return True - - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): - with self.Session() as session: - stmt = select(CombinerModel) - - for 
key, value in kwargs.items(): - stmt = stmt.where(getattr(CombinerModel, key) == value) - - _sort_order: str = "DESC" if sort_order == pymongo.DESCENDING else "ASC" - _sort_key: str = sort_key or "committed_at" - - if _sort_key in CombinerModel.__table__.columns: - sort_obj = CombinerModel.__table__.columns.get(_sort_key) if _sort_order == "ASC" else CombinerModel.__table__.columns.get(_sort_key).desc() - - stmt = stmt.order_by(sort_obj) - - if limit: - stmt = stmt.offset(skip or 0).limit(limit) - elif skip: - stmt = stmt.offset(skip) + return self.sql_delete(id) - items = session.scalars(stmt).all() - - result = [] - for i in items: - result.append(from_row(i)) - - count = session.scalar(select(func.count()).select_from(CombinerModel)) - - return {"count": count, "result": result} + def select(self, limit=0, skip=0, sort_key=None, sort_order=pymongo.DESCENDING, **kwargs): + with self.Session() as session: + entities = self.sql_select(session, limit, skip, sort_key, sort_order, **kwargs) + return [self._dto_from_orm_model(item) for item in entities] def count(self, **kwargs): - with self.Session() as session: - stmt = select(func.count()).select_from(CombinerModel) + return self.sql_count(**kwargs) - for key, value in kwargs.items(): - stmt = stmt.where(getattr(CombinerModel, key) == value) + def get_by_name(self, name: str) -> CombinerDTO: + with self.Session() as session: + entity = session.query(CombinerModel).filter(CombinerModel.name == name).first() + if entity is None: + return None + return self._dto_from_orm_model(entity) - count = session.scalar(stmt) + def _to_orm_dict(self, item_dict: Dict) -> Dict: + return item_dict - return count + def _dto_from_orm_model(self, item: CombinerModel) -> CombinerDTO: + return CombinerDTO().populate_with(from_orm_model(item, CombinerModel)) diff --git a/fedn/network/storage/statestore/stores/dto/__init__.py b/fedn/network/storage/statestore/stores/dto/__init__.py index 8b11a4660..79d41088a 100644 --- 
a/fedn/network/storage/statestore/stores/dto/__init__.py +++ b/fedn/network/storage/statestore/stores/dto/__init__.py @@ -1,6 +1,8 @@ """DTOs for the StateStore.""" from fedn.network.storage.statestore.stores.dto.client import ClientDTO +from fedn.network.storage.statestore.stores.dto.combiner import CombinerDTO from fedn.network.storage.statestore.stores.dto.model import ModelDTO +from fedn.network.storage.statestore.stores.dto.session import SessionConfigDTO, SessionDTO -__all__ = ["ClientDTO", "ModelDTO"] +__all__ = ["ClientDTO", "ModelDTO", "SessionConfigDTO", "SessionDTO", "CombinerDTO"] diff --git a/fedn/network/storage/statestore/stores/dto/combiner.py b/fedn/network/storage/statestore/stores/dto/combiner.py new file mode 100644 index 000000000..28b7f09e5 --- /dev/null +++ b/fedn/network/storage/statestore/stores/dto/combiner.py @@ -0,0 +1,17 @@ +from datetime import datetime +from typing import Optional + +from fedn.network.storage.statestore.stores.dto.shared import BaseDTO, Field + + +class CombinerDTO(BaseDTO): + """Combiner data transfer object.""" + + id: Optional[str] = Field(None) + name: str = Field(None) + address: str = Field(None) + fqdn: str = Field(None) + ip: str = Field(None) + parent: str = Field(None) + port: int = Field(None) + updated_at: datetime = Field(None) diff --git a/fedn/network/storage/statestore/stores/dto/shared.py b/fedn/network/storage/statestore/stores/dto/shared.py index 8994500cd..a85da3020 100644 --- a/fedn/network/storage/statestore/stores/dto/shared.py +++ b/fedn/network/storage/statestore/stores/dto/shared.py @@ -54,7 +54,7 @@ def to_dict(self) -> Dict[str, Any]: """ return self.model_dump(exclude_unset=False) - def to_db(self, exclude_unset: bool = True) -> Dict[str, Any]: + def to_db(self, exclude_unset: bool = False) -> Dict[str, Any]: """Return dict representation of BaseModel for database storage.""" return self.model_dump(exclude_unset=exclude_unset) diff --git 
a/fedn/network/storage/statestore/stores/session_store.py b/fedn/network/storage/statestore/stores/session_store.py index 0bcc90458..f0aac2530 100644 --- a/fedn/network/storage/statestore/stores/session_store.py +++ b/fedn/network/storage/statestore/stores/session_store.py @@ -85,12 +85,6 @@ def __init__(self, database: Database, collection: str): self.database[self.collection].create_index([("session_id", pymongo.DESCENDING)]) def get(self, id: str) -> SessionDTO: - """Get an entity by id - param id: The id of the entity - type: str - description: The id of the entity, can be either the id or the session_id (property) - return: The entity - """ entity = self.mongo_get(id) if entity is None: return None diff --git a/fedn/network/storage/statestore/stores/sql/shared.py b/fedn/network/storage/statestore/stores/sql/shared.py index ab5c19298..b065eb9c5 100644 --- a/fedn/network/storage/statestore/stores/sql/shared.py +++ b/fedn/network/storage/statestore/stores/sql/shared.py @@ -29,7 +29,7 @@ class MyAbstractBase(Base): __abstract__ = True id: Mapped[str] = mapped_column(primary_key=True, default=lambda: str(uuid.uuid4())) - committed_at: Mapped[datetime] = mapped_column(default=datetime.now()) + committed_at: Mapped[datetime] = mapped_column(default=datetime.now) class SessionConfigModel(MyAbstractBase): @@ -155,3 +155,15 @@ class ClientModel(MyAbstractBase): package: Mapped[Optional[str]] = mapped_column(String(255)) status: Mapped[str] = mapped_column(String(255)) last_seen: Mapped[datetime] = mapped_column(default=datetime.now()) + + +class CombinerModel(MyAbstractBase): + __tablename__ = "combiners" + + address: Mapped[str] = mapped_column(String(255)) + fqdn: Mapped[Optional[str]] = mapped_column(String(255)) + ip: Mapped[Optional[str]] = mapped_column(String(255)) + name: Mapped[str] = mapped_column(String(255)) + parent: Mapped[Optional[str]] = mapped_column(String(255)) + port: Mapped[int] + updated_at: Mapped[datetime] = mapped_column(default=datetime.now) 
diff --git a/fedn/tests/stores/test_combiner_store.py b/fedn/tests/stores/test_combiner_store.py index 0396d4def..057691b06 100644 --- a/fedn/tests/stores/test_combiner_store.py +++ b/fedn/tests/stores/test_combiner_store.py @@ -1,3 +1,4 @@ +import time import pytest import pymongo @@ -6,84 +7,69 @@ import uuid from fedn.network.storage.dbconnection import DatabaseConnection - +from fedn.network.storage.statestore.stores.dto import CombinerDTO @pytest.fixture def test_combiners(): start_date = datetime.datetime(2021, 1, 4, 1, 2, 4) - # NOTE: The SQL version does not support the dicts as parameters parent and config - # TODO: Creat using Combiner class - return [{"id":str(uuid.uuid4()), "committed_at":start_date - datetime.timedelta(days=1), "name":"test_combiner1", - "parent":"localhost", - # "config":{}, - "ip":"123:13:12:2", "fqdn":"", "port":8080, - "updated_at":start_date - datetime.timedelta(days=52), "address":"test_address"}, - {"id":str(uuid.uuid4()), "committed_at":start_date - datetime.timedelta(days=2), "name":"test_combiner2", - "parent":"localhost", - # "config":{}, - "ip":"123:13:12:2", "fqdn":"", "port":8080, - "updated_at":start_date - datetime.timedelta(days=12), "address":"test_address"}, - {"id":str(uuid.uuid4()), "committed_at":start_date - datetime.timedelta(days=8), "name":"test_combiner3", - "parent":"localhost", - # "config":{}, - "ip":"123:13:12:5", "fqdn":"", "port":8080, - "updated_at":start_date - datetime.timedelta(days=322), "address":"test_address"}, - {"id":str(uuid.uuid4()), "committed_at":start_date - datetime.timedelta(days=4), "name":"test_combiner4", - "parent":"localhost", - # "config":{}, - "ip":"123:13:12:4", "fqdn":"", "port":8080, - "updated_at":start_date - datetime.timedelta(days=23), "address":"test_address"}, - {"id":str(uuid.uuid4()), "committed_at":start_date - datetime.timedelta(days=51), "name":"test_combiner5", - "parent":"localhost", - # "config":{}, - "ip":"123:13:12:3", "fqdn":"", "port":8080, - 
"updated_at":start_date - datetime.timedelta(days=22), "address":"test_address"}, - {"id":str(uuid.uuid4()), "committed_at":start_date - datetime.timedelta(days=9), "name":"test_combiner6", - "parent":"localhost", - # "config":{}, - "ip":"123:13:12:3", "fqdn":"", "port":8080, - "updated_at":start_date - datetime.timedelta(days=24), "address":"test_address"}, - {"id":str(uuid.uuid4()), "committed_at":start_date - datetime.timedelta(days=3), "name":"test_combiner8", - "parent":"localhost", - # "config":{}, - "ip":"123:13:12:3", "fqdn":"", "port":8080, - "updated_at":start_date - datetime.timedelta(days=42), "address":"test_address"}, - {"id":str(uuid.uuid4()), "committed_at":start_date - datetime.timedelta(days=14), "name":"test_combiner7", - "parent":"localhost", - # "config":{}, - "ip":"123:13:12:2", "fqdn":"", "port":8080, - "updated_at":start_date - datetime.timedelta(days=12), "address":"test_address1"}] + combiner1 = CombinerDTO(id=str(uuid.uuid4()), name="test_combiner1", + parent="localhost", ip="123:13:12:2", fqdn="", port=8080, + updated_at=start_date - datetime.timedelta(days=52), address="test_address") + combiner2 = CombinerDTO(id=str(uuid.uuid4()), name="test_combiner2", + parent="localhost", ip="123:13:12:2", fqdn="", port=8080, + updated_at=start_date - datetime.timedelta(days=12), address="test_address") + combiner3 = CombinerDTO(id=str(uuid.uuid4()), name="test_combiner3", + parent="localhost", ip="123:13:12:5", fqdn="", port=8080, + updated_at=start_date - datetime.timedelta(days=322), address="test_address") + combiner4 = CombinerDTO(id=str(uuid.uuid4()), name="test_combiner4", + parent="localhost", ip="123:13:12:4", fqdn="", port=8080, + updated_at=start_date - datetime.timedelta(days=23), address="test_address") + combiner5 = CombinerDTO(id=str(uuid.uuid4()), name="test_combiner5", + parent="localhost", ip="123:13:12:3", fqdn="", port=8080, + updated_at=start_date - datetime.timedelta(days=22), address="test_address") + combiner6 = 
CombinerDTO(id=str(uuid.uuid4()), name="test_combiner6", + parent="localhost", ip="123:13:12:3", fqdn="", port=8080, + updated_at=start_date - datetime.timedelta(days=24), address="test_address") + combiner7 = CombinerDTO(id=str(uuid.uuid4()), name="test_combiner8", + parent="localhost", ip="123:13:12:3", fqdn="", port=8080, + updated_at=start_date - datetime.timedelta(days=42), address="test_address") + combiner8 = CombinerDTO(id=str(uuid.uuid4()), name="test_combiner7", + parent="localhost", ip="123:13:12:2", fqdn="", port=8080, + updated_at=start_date - datetime.timedelta(days=12), address="test_address1") + return [combiner1, combiner2, combiner3, combiner4, combiner5, combiner6, combiner7, combiner8] +@pytest.fixture +def test_combiner(): + start_date = datetime.datetime(2021, 1, 4, 1, 2, 4) + combiner = CombinerDTO(name="test_combiner", + parent="localhost", ip="123:13:12:2", fqdn="", port=8080, + updated_at=start_date - datetime.timedelta(days=52), address="test_address") + return combiner @pytest.fixture def db_connections_with_data(postgres_connection:DatabaseConnection, sql_connection: DatabaseConnection, mongo_connection:DatabaseConnection, test_combiners): for c in test_combiners: - res, _ = mongo_connection.combiner_store.add(c) - assert res == True - - for c in test_combiners: - res, _ = postgres_connection.combiner_store.add(c) - assert res == True - - for c in test_combiners: - res, _ = sql_connection.combiner_store.add(c) - assert res == True + mongo_connection.combiner_store.add(c) + postgres_connection.combiner_store.add(c) + sql_connection.combiner_store.add(c) + time.sleep(0.01) yield [("postgres", postgres_connection), ("sqlite", sql_connection), ("mongo", mongo_connection)] - # TODO:Clean up - - + for c in test_combiners: + mongo_connection.combiner_store.delete(c.id) + postgres_connection.combiner_store.delete(c.id) + sql_connection.combiner_store.delete(c.id) + @pytest.fixture def options(): - sorting_keys = (#None, - "name", - 
#"committed_at", - #"updated_at", - #"invalid_key" - ) + sorting_keys = (None, + "name", + "committed_at", + "updated_at", + "invalid_key") limits = (None, 0, 1, 2, 99) skips = (None, 0, 1, 2, 99) desc = (None, pymongo.DESCENDING, pymongo.ASCENDING) @@ -93,36 +79,62 @@ def options(): class TestCombinerStore: - def test_add_update_delete(self, postgres_connection:DatabaseConnection, sql_connection: DatabaseConnection, mongo_connection:DatabaseConnection): - pass + def test_add_update_delete_postgres(self, postgres_connection:DatabaseConnection, test_combiner): + self.helper_add_update_delete(postgres_connection, test_combiner) + + def test_add_update_delete_sql(self, sql_connection: DatabaseConnection, test_combiner): + self.helper_add_update_delete(sql_connection, test_combiner) + + def test_add_update_delete_mongo(self, mongo_connection:DatabaseConnection, test_combiner): + self.helper_add_update_delete(mongo_connection, test_combiner) + + def helper_add_update_delete(self, db: DatabaseConnection, test_combiner:CombinerDTO): + # Add a combiner and check that we get the added combiner back + name = test_combiner.name + + success, read_combiner1 = db.combiner_store.add(test_combiner) + assert success == True + assert isinstance(read_combiner1.id, str) + read_combiner1_dict = read_combiner1.to_dict() + combiner_id = read_combiner1_dict["id"] + del read_combiner1_dict["id"] + del read_combiner1_dict["committed_at"] + + test_combiner_dict = test_combiner.to_dict() + del test_combiner_dict["id"] + del test_combiner_dict["committed_at"] + + assert read_combiner1_dict == test_combiner_dict + + # Assert we get the same combiner back + read_combiner2 = db.combiner_store.get(combiner_id) + assert read_combiner2 is not None + assert read_combiner2.to_dict() == read_combiner1.to_dict() + + # Assert we get the same combiner back by name + read_combiner3 = db.combiner_store.get_by_name(name) + assert read_combiner3 is not None + assert read_combiner3.to_dict() == 
read_combiner1.to_dict() + + # Delete the combiner and check that it is deleted + success = db.combiner_store.delete(combiner_id) + assert success == True def test_list(self, db_connections_with_data: list[tuple[str, DatabaseConnection]], options: list[tuple]): for (name1, db_1), (name2, db_2) in zip(db_connections_with_data[1:], db_connections_with_data[:-1]): print("Running tests between databases {} and {}".format(name1, name2)) for *opt,kwargs in options: - res = db_1.combiner_store.list(*opt, **kwargs) - count, gathered_combiners = res["count"], res["result"] - - res = db_2.combiner_store.list(*opt, **kwargs) - count2, gathered_combiners2 = res["count"], res["result"] - #TODO: The count is not equal to the number of clients in the list, but the number of clients returned by the query before skip and limit - #It is not clear what is the intended behavior - # assert(count == len(gathered_clients)) - # assert count == count2 + gathered_combiners = db_1.combiner_store.select(*opt, **kwargs) + count = db_1.combiner_store.count(**kwargs) + + gathered_combiners2 = db_2.combiner_store.select(*opt, **kwargs) + count2 = db_2.combiner_store.count(**kwargs) + + assert(count == count2) assert len(gathered_combiners) == len(gathered_combiners2) for i in range(len(gathered_combiners)): - #NOTE: id are not equal between the two databases, I think it is due to id being overwritten in the _id field - #assert gathered_combiners2[i]["id"] == gathered_combiners[i]["id"] - #TODO: committed_at is not equal between the two databases, one reades from init the other uses the current time - #assert gathered_combiners2[i]["committed_at"] == gathered_combiners[i]["committed_at"] - assert gathered_combiners2[i]["name"] == gathered_combiners[i]["name"] - assert gathered_combiners2[i]["parent"] == gathered_combiners[i]["parent"] - assert gathered_combiners2[i]["ip"] == gathered_combiners[i]["ip"] - assert gathered_combiners2[i]["fqdn"] == gathered_combiners[i]["fqdn"] - assert 
gathered_combiners2[i]["port"] == gathered_combiners[i]["port"] - #TODO: updated_at is not equal between the two databases, one reades from init the other uses the current time - #assert gathered_combiners2[i]["updated_at"] == gathered_combiners[i]["updated_atssert gathered_combiners2[i]["address"] == gathered_combiners[i]["address"] + assert gathered_combiners2[i].id == gathered_combiners[i].id