From fd1c34b1d36f2ea7fd953ec3187a03cb0e5e4ace Mon Sep 17 00:00:00 2001 From: niklastheman Date: Wed, 8 Jan 2025 16:46:57 +0100 Subject: [PATCH] started on prediction store --- fedn/network/api/shared.py | 2 + fedn/network/combiner/shared.py | 5 +- .../statestore/stores/prediction_store.py | 164 +++++++++++++++++- 3 files changed, 166 insertions(+), 5 deletions(-) diff --git a/fedn/network/api/shared.py b/fedn/network/api/shared.py index 45de37e3a..4c27e1bbd 100644 --- a/fedn/network/api/shared.py +++ b/fedn/network/api/shared.py @@ -50,6 +50,8 @@ status_store: StatusStore = SQLStatusStore() # validation_store: ValidationStore = MongoDBValidationStore(mdb, "control.validations") validation_store: ValidationStore = SQLValidationStore() +# prediction_store: PredictionStore = MongoDBPredictionStore(mdb, "control.predictions") +prediction_store: PredictionStore = SQLPredictionStore() repository = Repository(modelstorage_config["storage_config"]) diff --git a/fedn/network/combiner/shared.py b/fedn/network/combiner/shared.py index 7e2aa6607..2980e7822 100644 --- a/fedn/network/combiner/shared.py +++ b/fedn/network/combiner/shared.py @@ -6,7 +6,7 @@ from fedn.network.storage.s3.repository import Repository from fedn.network.storage.statestore.stores.client_store import ClientStore, MongoDBClientStore, SQLClientStore from fedn.network.storage.statestore.stores.combiner_store import CombinerStore, MongoDBCombinerStore, SQLCombinerStore -from fedn.network.storage.statestore.stores.prediction_store import PredictionStore +from fedn.network.storage.statestore.stores.prediction_store import MongoDBPredictionStore, PredictionStore, SQLPredictionStore from fedn.network.storage.statestore.stores.round_store import MongoDBRoundStore, RoundStore, SQLRoundStore from fedn.network.storage.statestore.stores.status_store import MongoDBStatusStore, SQLStatusStore, StatusStore from fedn.network.storage.statestore.stores.store import Base, MyAbstractBase, engine @@ -31,7 +31,8 @@ combiner_store: CombinerStore = SQLCombinerStore() # status_store: StatusStore = MongoDBStatusStore(mdb, "control.status") status_store: StatusStore = SQLStatusStore() -prediction_store = PredictionStore(mdb, "control.predictions") +# prediction_store: PredictionStore = MongoDBPredictionStore(mdb, "control.predictions") +prediction_store: PredictionStore = SQLPredictionStore() # round_store: RoundStore = MongoDBRoundStore(mdb, "control.rounds") round_store: RoundStore = SQLRoundStore() diff --git a/fedn/network/storage/statestore/stores/prediction_store.py b/fedn/network/storage/statestore/stores/prediction_store.py index 5b918c41e..fcad17e1d 100644 --- a/fedn/network/storage/statestore/stores/prediction_store.py +++ b/fedn/network/storage/statestore/stores/prediction_store.py @@ -1,9 +1,13 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import pymongo from pymongo.database import Database +from sqlalchemy import ForeignKey, Integer, String, and_, func, or_, select +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.sql import text -from fedn.network.storage.statestore.stores.store import MongoDBStore +from fedn.network.storage.statestore.stores.shared import EntityNotFound +from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, Session, SQLStore, Store class Prediction: @@ -21,7 +25,11 @@ def __init__( self.receiver = receiver -class PredictionStore(MongoDBStore[Prediction]): +class PredictionStore(Store[Prediction]): + pass + + +class MongoDBPredictionStore(MongoDBStore[Prediction]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) @@ -61,3 +69,153 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return: A dictionary with the count and a list of entities """ return super().list(limit, skip, sort_key or "timestamp", sort_order, **kwargs) + + +class PredictionModel(MyAbstractBase): + __tablename__ = "predictions" + + correlation_id: Mapped[str] + data: Mapped[Optional[str]] + model_id: Mapped[Optional[str]] = mapped_column(ForeignKey("models.id")) + receiver_name: Mapped[Optional[str]] = mapped_column(String(255)) + receiver_role: Mapped[Optional[str]] = mapped_column(String(255)) + sender_name: Mapped[Optional[str]] = mapped_column(String(255)) + sender_role: Mapped[Optional[str]] = mapped_column(String(255)) + timestamp: Mapped[str] = mapped_column(String(255)) + prediction_id: Mapped[str] = mapped_column(String(255)) + + +def from_row(row: PredictionModel) -> Prediction: + return { + "id": row.id, + "model_id": row.model_id, + "data": row.data, + "correlation_id": row.correlation_id, + "timestamp": row.timestamp, + "prediction_id": row.prediction_id, + "sender": {"name": row.sender_name, "role": row.sender_role}, + "receiver": {"name": row.receiver_name, "role": row.receiver_role}, + } + + +class SQLPredictionStore(PredictionStore, SQLStore[Prediction]): + def get(self, id: str) -> Prediction: + with Session() as session: + stmt = select(Prediction).where(Prediction.id == id) + item = session.scalars(stmt).first() + + if item is None: + raise EntityNotFound(f"Entity with (id | round_id) {id} not found") + + return from_row(item) + + def update(self, id: str, item: Prediction) -> bool: + raise NotImplementedError("Update not implemented for PredictionStore") + + def add(self, item: Prediction) -> Tuple[bool, Any]: + with Session() as session: + sender = item["sender"] if "sender" in item else None + receiver = item["receiver"] if "receiver" in item else None + + validation = PredictionModel( + correlation_id=item.get("correlationId") or item.get("correlation_id"), + data=item.get("data"), + model_id=item.get("modelId") or item.get("model_id"), + receiver_name=receiver.get("name"), + receiver_role=receiver.get("role"), + sender_name=sender.get("name"), + sender_role=sender.get("role"), + prediction_id=item.get("predictionId") or item.get("prediction_id"), + timestamp=item.get("timestamp"), + ) + + session.add(validation) + session.commit() + + return True, validation + + def delete(self, id: str) -> bool: + raise NotImplementedError("Delete not implemented for PredictionStore") + + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): + with Session() as session: + stmt = select(PredictionModel) + + for key, value in kwargs.items(): + if key == "_id": + key = "id" + elif key == "sender.name": + key = "sender_name" + elif key == "sender.role": + key = "sender_role" + elif key == "receiver.name": + key = "receiver_name" + elif key == "receiver.role": + key = "receiver_role" + elif key == "correlationId": + key = "correlation_id" + elif key == "modelId": + key = "model_id" + + stmt = stmt.where(getattr(PredictionModel, key) == value) + + if sort_key: + _sort_order: str = "DESC" if sort_order == pymongo.DESCENDING else "ASC" + _sort_key: str = sort_key + + if _sort_key == "_id": + _sort_key = "id" + elif _sort_key == "sender.name": + _sort_key = "sender_name" + elif _sort_key == "sender.role": + _sort_key = "sender_role" + elif _sort_key == "receiver.name": + _sort_key = "receiver_name" + elif _sort_key == "receiver.role": + _sort_key = "receiver_role" + elif _sort_key == "correlationId": + _sort_key = "correlation_id" + elif _sort_key == "modelId": + _sort_key = "model_id" + + sort_obj = text(f"{_sort_key} {_sort_order}") + + stmt = stmt.order_by(sort_obj) + + if limit != 0: + stmt = stmt.offset(skip or 0).limit(limit) + + items = session.execute(stmt) + + result = [] + + for item in items: + (r,) = item + + result.append(from_row(r)) + + return {"count": len(result), "result": result} + + def count(self, **kwargs): + with Session() as session: + stmt = select(func.count()).select_from(PredictionModel) + + for key, value in kwargs.items(): + if key == "sender.name": + key = "sender_name" + elif key == "sender.role": + key = "sender_role" + elif key == "receiver.name": + key = "receiver_name" + elif key == "receiver.role": + key = "receiver_role" + elif key == "correlationId": + key = "correlation_id" + elif key == "modelId": + key = "model_id" + + stmt = stmt.where(getattr(PredictionModel, key) == value) + + count = session.scalar(stmt) + + return count