Skip to content

Commit

Permalink
started on rounds sql store
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman committed Jan 6, 2025
1 parent 9ea7304 commit 2df0655
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 7 deletions.
5 changes: 3 additions & 2 deletions fedn/network/api/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fedn.network.storage.statestore.stores.combiner_store import CombinerStore, MongoDBCombinerStore, SQLCombinerStore
from fedn.network.storage.statestore.stores.model_store import MongoDBModelStore, SQLModelStore
from fedn.network.storage.statestore.stores.package_store import MongoDBPackageStore, PackageStore, SQLPackageStore
from fedn.network.storage.statestore.stores.round_store import RoundStore
from fedn.network.storage.statestore.stores.round_store import MongoDBRoundStore, RoundStore, SQLRoundStore
from fedn.network.storage.statestore.stores.session_store import MongoDBSessionStore, SQLSessionStore
from fedn.network.storage.statestore.stores.shared import EntityNotFound
from fedn.network.storage.statestore.stores.status_store import StatusStore
Expand Down Expand Up @@ -43,7 +43,8 @@
model_store = SQLModelStore()
# combiner_store: CombinerStore = MongoDBCombinerStore(mdb, "network.combiners")
combiner_store: CombinerStore = SQLCombinerStore()
round_store = RoundStore(mdb, "control.rounds")
# round_store: RoundStore = MongoDBRoundStore(mdb, "control.rounds")
round_store: RoundStore = SQLRoundStore()
status_store = StatusStore(mdb, "control.status")
validation_store = ValidationStore(mdb, "control.validations")

Expand Down
5 changes: 3 additions & 2 deletions fedn/network/combiner/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fedn.network.storage.statestore.stores.client_store import ClientStore
from fedn.network.storage.statestore.stores.combiner_store import CombinerStore, MongoDBCombinerStore, SQLCombinerStore
from fedn.network.storage.statestore.stores.prediction_store import PredictionStore
from fedn.network.storage.statestore.stores.round_store import RoundStore
from fedn.network.storage.statestore.stores.round_store import MongoDBRoundStore, RoundStore, SQLRoundStore
from fedn.network.storage.statestore.stores.status_store import StatusStore
from fedn.network.storage.statestore.stores.store import Base, MyAbstractBase, engine
from fedn.network.storage.statestore.stores.validation_store import ValidationStore
Expand All @@ -29,7 +29,8 @@
combiner_store: CombinerStore = SQLCombinerStore()
status_store = StatusStore(mdb, "control.status")
prediction_store = PredictionStore(mdb, "control.predictions")
round_store = RoundStore(mdb, "control.rounds")
# round_store: RoundStore = MongoDBRoundStore(mdb, "control.rounds")
round_store: RoundStore = SQLRoundStore()

repository = Repository(modelstorage_config["storage_config"], init_buckets=False)

Expand Down
108 changes: 105 additions & 3 deletions fedn/network/storage/statestore/stores/round_store.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import pymongo
from bson import ObjectId
from pymongo.database import Database
from sqlalchemy import ForeignKey, String, func, select
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import text

from fedn.network.storage.statestore.stores.store import MongoDBStore
from fedn.network.storage.statestore.stores.sql_models import RoundCombinerModel, RoundConfigModel, RoundDataModel, RoundModel
from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, Session, SQLStore, Store

from .shared import EntityNotFound, from_document

Expand All @@ -19,7 +23,11 @@ def __init__(self, id: str, round_id: str, status: str, round_config: dict, comb
self.round_data = round_data


class RoundStore(MongoDBStore[Round]):
class RoundStore(Store[Round]):
pass


class MongoDBRoundStore(MongoDBStore[Round]):
def __init__(self, database: Database, collection: str):
super().__init__(database, collection)
self.database[self.collection].create_index([("round_id", pymongo.DESCENDING)])
Expand Down Expand Up @@ -76,3 +84,97 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI
return: The entities
"""
return super().list(limit, skip, sort_key or "round_id", sort_order, **kwargs)


def from_row(row: dict) -> Round:
return {
"id": row.id,
"committed_at": row.committed_at,
"address": row.address,
"ip": row.ip,
"name": row.name,
"parent": row.parent,
"fqdn": row.fqdn,
"port": row.port,
"updated_at": row.updated_at,
}


class SQLRoundStore(RoundStore, SQLStore[Round]):
def get(self, id: str) -> Round:
raise NotImplementedError

def update(self, id, item):
raise NotImplementedError

def add(self, item: Round) -> Tuple[bool, Any]:
with Session() as session:
round_id = item["round_id"]
stmt = select(RoundModel).where(RoundModel.round_id == round_id)
item = session.scalars(stmt).first()

if item is not None:
return False, "Round with round_id already exists"

round_data = RoundDataModel(
time_commit=item["round_data"]["time_commit"],
reduce_time_aggregate_models=item["round_data"]["reduce_time_aggregate_models"],
reduce_time_fetch_models=item["round_data"]["reduce_time_fetch_models"],
)

round_config = RoundConfigModel(
aggregator=item["round_config"]["aggregator"],
round_timeout=item["round_config"]["round_timeout"],
buffer_size=item["round_config"]["buffer_size"],
delete_models_storage=item["round_config"]["delete_models_storage"],
clients_required=item["round_config"]["clients_required"],
validate=item["round_config"]["validate"],
helper_type=item["round_config"]["helper_type"],
task=item["round_config"]["task"],
)

combiners = []

for combiner in item["combiners"]:
combiners.append(
RoundCombinerModel(
committed_at=combiner["committed_at"],
address=combiner["address"],
ip=combiner["ip"],
name=combiner["name"],
parent=combiner["parent"],
fqdn=combiner["fqdn"],
port=combiner["port"],
updated_at=combiner["updated_at"],
)
)

entity = RoundModel(
round_id=item["round_id"],
status=item["status"],
round_config=round_config,
combiners=combiners,
round_data=round_data,
)

session.add(entity)
session.commit()

return True, from_row(entity)

def delete(self, id: str) -> bool:
raise NotImplementedError

def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs):
raise NotImplementedError

def count(self, **kwargs):
with Session() as session:
stmt = select(func.count()).select_from(RoundModel)

for key, value in kwargs.items():
stmt = stmt.where(getattr(RoundModel, key) == value)

count = session.scalar(stmt)

return count
70 changes: 70 additions & 0 deletions fedn/network/storage/statestore/stores/sql_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,73 @@ class ModelModel(MyAbstractBase):
session_configs: Mapped[List["SessionConfigModel"]] = relationship()
session_id: Mapped[Optional[str]] = mapped_column(ForeignKey("sessions.id"))
session: Mapped[Optional["SessionModel"]] = relationship(back_populates="models")


class RoundConfigModel(MyAbstractBase):
__tablename__ = "round_configs"

aggregator: Mapped[str] = mapped_column(String(255))
round_timeout: Mapped[int]
buffer_size: Mapped[int]
delete_models_storage: Mapped[bool]
clients_required: Mapped[int]
validate: Mapped[bool]
helper_type: Mapped[str] = mapped_column(String(255))
model_id: Mapped[str] = mapped_column(ForeignKey("models.id"))
session_id: Mapped[str] = mapped_column(ForeignKey("sessions.id"))
session: Mapped[Optional["SessionModel"]] = relationship(back_populates="round_configs")
round: Mapped["RoundModel"] = relationship(back_populates="round_config")
task: Mapped[str] = mapped_column(String(255))


class RoundDataModel(MyAbstractBase):
__tablename__ = "round_data"

time_commit: Mapped[float]
reduce_time_aggregate_models: Mapped[float]
reduce_time_fetch_models: Mapped[float]
reduce_time_load_model: Mapped[float]
round: Mapped["RoundModel"] = relationship(back_populates="round_data")


class RoundCombinerModel(MyAbstractBase):
__tablename__ = "round_combiners"

model_id: Mapped[str] = mapped_column(ForeignKey("models.id"))
name: Mapped[str] = mapped_column(String(255))
round_id: Mapped[str]
parent_round_id: Mapped[str] = mapped_column(ForeignKey("rounds.id"))
status: Mapped[str] = mapped_column(String(255))
time_exec_training: Mapped[float]

config__job_id: Mapped[str] = mapped_column(String(255))
config_aggregator: Mapped[str] = mapped_column(String(255))
config_buffer_size: Mapped[int]
config_clients_required: Mapped[int]
config_delete_models_storage: Mapped[bool]
config_helper_type: Mapped[str] = mapped_column(String(255))
config_model_id: Mapped[str] = mapped_column(ForeignKey("models.id"))
config_round_id: Mapped[str]
config_round_timeout: Mapped[int]
config_rounds: Mapped[int]
config_session_id: Mapped[str] = mapped_column(ForeignKey("sessions.id"))
config_task: Mapped[str] = mapped_column(String(255))
config_validate: Mapped[bool]

data_aggregation_time_nr_aggregated_models: Mapped[int]
data_aggregation_time_time_model_aggregation: Mapped[float]
data_aggregation_time_time_model_load: Mapped[float]
data_nr_expected_updates: Mapped[int]
data_nr_required_updates: Mapped[int]
data_time_combination: Mapped[float]
data_timeout: Mapped[float]


class RoundModel(MyAbstractBase):
__tablename__ = "rounds"

round_id: Mapped[str] = mapped_column(unique=True) # TODO: Add unique constraint. Does this work?
status: Mapped[str] = mapped_column()
round_config: Mapped["RoundConfigModel"] = relationship(back_populates="round")
round_data: Mapped["RoundDataModel"] = relationship(back_populates="round")
combiners: Mapped[List["RoundCombinerModel"]] = relationship(back_populates="round")

0 comments on commit 2df0655

Please sign in to comment.