Skip to content

Commit

Permalink
started on prediction store
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman committed Jan 8, 2025
1 parent dbf3f5a commit fd1c34b
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 5 deletions.
2 changes: 2 additions & 0 deletions fedn/network/api/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
status_store: StatusStore = SQLStatusStore()
# validation_store: ValidationStore = MongoDBValidationStore(mdb, "control.validations")
validation_store: ValidationStore = SQLValidationStore()
# prediction_store: PredictionStore = MongoDBPredictionStore(mdb, "control.predictions")
prediction_store: PredictionStore = SQLPredictionStore()

repository = Repository(modelstorage_config["storage_config"])

Expand Down
5 changes: 3 additions & 2 deletions fedn/network/combiner/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fedn.network.storage.s3.repository import Repository
from fedn.network.storage.statestore.stores.client_store import ClientStore, MongoDBClientStore, SQLClientStore
from fedn.network.storage.statestore.stores.combiner_store import CombinerStore, MongoDBCombinerStore, SQLCombinerStore
from fedn.network.storage.statestore.stores.prediction_store import PredictionStore
from fedn.network.storage.statestore.stores.prediction_store import MongoDBPredictionStore, PredictionStore, SQLPredictionStore
from fedn.network.storage.statestore.stores.round_store import MongoDBRoundStore, RoundStore, SQLRoundStore
from fedn.network.storage.statestore.stores.status_store import MongoDBStatusStore, SQLStatusStore, StatusStore
from fedn.network.storage.statestore.stores.store import Base, MyAbstractBase, engine
Expand All @@ -31,7 +31,8 @@
combiner_store: CombinerStore = SQLCombinerStore()
# status_store: StatusStore = MongoDBStatusStore(mdb, "control.status")
status_store: StatusStore = SQLStatusStore()
prediction_store = PredictionStore(mdb, "control.predictions")
# prediction_store: PredictionStore = MongoDBPredictionStore(mdb, "control.predictions")
prediction_store: PredictionStore = SQLPredictionStore()
# round_store: RoundStore = MongoDBRoundStore(mdb, "control.rounds")
round_store: RoundStore = SQLRoundStore()

Expand Down
164 changes: 161 additions & 3 deletions fedn/network/storage/statestore/stores/prediction_store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import pymongo
from pymongo.database import Database
from sqlalchemy import ForeignKey, Integer, String, and_, func, or_, select
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import text

from fedn.network.storage.statestore.stores.store import MongoDBStore
from fedn.network.storage.statestore.stores.shared import EntityNotFound
from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, Session, SQLStore, Store


class Prediction:
Expand All @@ -21,7 +25,11 @@ def __init__(
self.receiver = receiver


class PredictionStore(MongoDBStore[Prediction]):
class PredictionStore(Store[Prediction]):
pass


class MongoDBPredictionStore(MongoDBStore[Prediction]):
def __init__(self, database: Database, collection: str):
super().__init__(database, collection)

Expand Down Expand Up @@ -61,3 +69,153 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI
return: A dictionary with the count and a list of entities
"""
return super().list(limit, skip, sort_key or "timestamp", sort_order, **kwargs)


class PredictionModel(MyAbstractBase):
__tablename__ = "predictions"

correlation_id: Mapped[str]
data: Mapped[Optional[str]]
model_id: Mapped[Optional[str]] = mapped_column(ForeignKey("models.id"))
receiver_name: Mapped[Optional[str]] = mapped_column(String(255))
receiver_role: Mapped[Optional[str]] = mapped_column(String(255))
sender_name: Mapped[Optional[str]] = mapped_column(String(255))
sender_role: Mapped[Optional[str]] = mapped_column(String(255))
timestamp: Mapped[str] = mapped_column(String(255))
prediction_id: Mapped[str] = mapped_column(String(255))


def from_row(row: PredictionModel) -> Prediction:
return {
"id": row.id,
"model_id": row.model_id,
"data": row.data,
"correlation_id": row.correlation_id,
"timestamp": row.timestamp,
"prediction_id": row.prediction_id,
"sender": {"name": row.sender_name, "role": row.sender_role},
"receiver": {"name": row.receiver_name, "role": row.receiver_role},
}


class SQLPredictionStore(PredictionStore, SQLStore[Prediction]):
def get(self, id: str) -> Prediction:
with Session() as session:
stmt = select(Prediction).where(Prediction.id == id)
item = session.scalars(stmt).first()

if item is None:
raise EntityNotFound(f"Entity with (id | round_id) {id} not found")

return from_row(item)

def update(self, id: str, item: Prediction) -> bool:
raise NotImplementedError("Update not implemented for PredictionStore")

def add(self, item: Prediction) -> Tuple[bool, Any]:
with Session() as session:
sender = item["sender"] if "sender" in item else None
receiver = item["receiver"] if "receiver" in item else None

validation = PredictionModel(
correlation_id=item.get("correlationId") or item.get("correlation_id"),
data=item.get("data"),
model_id=item.get("modelId") or item.get("model_id"),
receiver_name=receiver.get("name"),
receiver_role=receiver.get("role"),
sender_name=sender.get("name"),
sender_role=sender.get("role"),
prediction_id=item.get("predictionId") or item.get("prediction_id"),
timestamp=item.get("timestamp"),
)

session.add(validation)
session.commit()

return True, validation

def delete(self, id: str) -> bool:
raise NotImplementedError("Delete not implemented for PredictionStore")

def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs):
with Session() as session:
stmt = select(PredictionModel)

for key, value in kwargs.items():
if key == "_id":
key = "id"
elif key == "sender.name":
key = "sender_name"
elif key == "sender.role":
key = "sender_role"
elif key == "receiver.name":
key = "receiver_name"
elif key == "receiver.role":
key = "receiver_role"
elif key == "correlationId":
key = "correlation_id"
elif key == "modelId":
key = "model_id"

stmt = stmt.where(getattr(PredictionModel, key) == value)

if sort_key:
_sort_order: str = "DESC" if sort_order == pymongo.DESCENDING else "ASC"
_sort_key: str = sort_key

if _sort_key == "_id":
_sort_key = "id"
elif _sort_key == "sender.name":
_sort_key = "sender_name"
elif _sort_key == "sender.role":
_sort_key = "sender_role"
elif _sort_key == "receiver.name":
_sort_key = "receiver_name"
elif _sort_key == "receiver.role":
_sort_key = "receiver_role"
elif _sort_key == "correlationId":
_sort_key = "correlation_id"
elif _sort_key == "modelId":
_sort_key = "model_id"

sort_obj = text(f"{_sort_key} {_sort_order}")

stmt = stmt.order_by(sort_obj)

if limit != 0:
stmt = stmt.offset(skip or 0).limit(limit)

items = session.execute(stmt)

result = []

for item in items:
(r,) = item

result.append(from_row(r))

return {"count": len(result), "result": result}

def count(self, **kwargs):
with Session() as session:
stmt = select(func.count()).select_from(PredictionModel)

for key, value in kwargs.items():
if key == "sender.name":
key = "sender_name"
elif key == "sender.role":
key = "sender_role"
elif key == "receiver.name":
key = "receiver_name"
elif key == "receiver.role":
key = "receiver_role"
elif key == "correlationId":
key = "correlation_id"
elif key == "modelId":
key = "model_id"

stmt = stmt.where(getattr(PredictionModel, key) == value)

count = session.scalar(stmt)

return count

0 comments on commit fd1c34b

Please sign in to comment.