diff --git a/fedn/network/api/v1/client_routes.py b/fedn/network/api/v1/client_routes.py index fb268905b..e1eb7ef5a 100644 --- a/fedn/network/api/v1/client_routes.py +++ b/fedn/network/api/v1/client_routes.py @@ -7,7 +7,6 @@ bp = Blueprint("client", __name__, url_prefix=f"/api/{api_version}/clients") - @bp.route("/", methods=["GET"]) @jwt_auth_required(role="admin") def get_clients(): @@ -109,14 +108,10 @@ def get_clients(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - clients = client_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) - - result = clients["result"] - - response = {"count": clients["count"], "result": result} + response = client_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -195,13 +190,9 @@ def list_clients(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - clients = client_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) - - result = clients["result"] - - response = {"count": clients["count"], "result": result} + response = client_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -357,9 +348,7 @@ def get_client(id: str): type: string """ try: - client = client_store.get(id, use_typing=False) - - response = client + response = client_store.get(id) return jsonify(response), 200 except EntityNotFound: @@ -367,7 +356,7 @@ def get_client(id: str): except Exception: return jsonify({"message": "An unexpected error occurred"}), 500 -# delete client + @bp.route("/", methods=["DELETE"]) @jwt_auth_required(role="admin") def delete_client(id: str): diff --git a/fedn/network/api/v1/combiner_routes.py b/fedn/network/api/v1/combiner_routes.py index 02617b7bb..966aea1dd 100644 --- a/fedn/network/api/v1/combiner_routes.py +++ b/fedn/network/api/v1/combiner_routes.py @@ -102,15 +102,11 @@ def get_combiners(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - combiners = combiner_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) - - result = combiners["result"] - - response = {"count": combiners["count"], "result": result} + response = combiner_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -185,15 +181,11 @@ def list_combiners(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - combiners = combiner_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) - - result = combiners["result"] - - response = {"count": combiners["count"], "result": result} + response = combiner_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -331,8 +323,7 @@ def get_combiner(id: str): type: string """ try: - combiner = combiner_store.get(id, use_typing=False) - response = combiner + response = combiner_store.get(id) return jsonify(response), 200 except EntityNotFound: @@ -340,6 +331,7 @@ def get_combiner(id: str): except Exception: return jsonify({"message": "An unexpected error occurred"}), 500 + @bp.route("/", methods=["DELETE"]) @jwt_auth_required(role="admin") def delete_combiner(id: str): @@ -421,9 +413,7 @@ def number_of_clients_connected(): combiners = combiners.split(",") if combiners else [] response = client_store.connected_client_count(combiners) - result = { - "result": response - } + result = {"result": response} return jsonify(result), 200 except Exception: diff --git a/fedn/network/api/v1/model_routes.py b/fedn/network/api/v1/model_routes.py index 9a3abdc47..c5f4e2fae 100644 --- a/fedn/network/api/v1/model_routes.py +++ b/fedn/network/api/v1/model_routes.py @@ -101,14 +101,10 @@ def get_models(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - models = model_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) - - result = models["result"] - - response = {"count": models["count"], "result": result} + response = model_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -186,14 +182,10 @@ def list_models(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - models = model_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) - - result = models["result"] - - response = {"count": models["count"], "result": result} + response = model_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -335,7 +327,7 @@ def get_model(id: str): type: string """ try: - model = model_store.get(id, use_typing=False) + model = model_store.get(id) response = model @@ -386,7 +378,7 @@ def patch_model(id: str): type: string """ try: - model = model_store.get(id, use_typing=False) + model = model_store.get(id) data = request.get_json() _id = model["id"] @@ -451,7 +443,7 @@ def put_model(id: str): type: string """ try: - model = model_store.get(id, use_typing=False) + model = model_store.get(id) data = request.get_json() _id = model["id"] @@ -511,7 +503,7 @@ def get_descendants(id: str): try: limit = get_limit(request.headers) - descendants = model_store.list_descendants(id, limit or 10, use_typing=False) + descendants = model_store.list_descendants(id, limit or 10) response = descendants @@ -580,7 +572,7 @@ def get_ancestors(id: str): include_self: bool = include_self_param and include_self_param.lower() == "true" - ancestors = model_store.list_ancestors(id, limit or 10, include_self=include_self, reverse=reverse, use_typing=False) + ancestors = model_store.list_ancestors(id, limit or 10, include_self=include_self, reverse=reverse) response = ancestors @@ -626,7 +618,7 @@ def download(id: str): """ try: if minio_repository is not None: - model = model_store.get(id, use_typing=False) + model = model_store.get(id) model_id = model["model"] file = minio_repository.get_artifact_stream(model_id, modelstorage_config["storage_config"]["storage_bucket"]) @@ -680,7 +672,7 @@ def get_parameters(id: str): """ try: if minio_repository is not None: - model = model_store.get(id, use_typing=False) + model = model_store.get(id) model_id = model["model"] file = minio_repository.get_artifact_stream(model_id, modelstorage_config["storage_config"]["storage_bucket"]) diff --git a/fedn/network/api/v1/package_routes.py b/fedn/network/api/v1/package_routes.py index c92d77cd4..4ed138369 100644 --- a/fedn/network/api/v1/package_routes.py +++ b/fedn/network/api/v1/package_routes.py @@ -5,15 +5,12 @@ from fedn.common.config import FEDN_COMPUTE_PACKAGE_DIR from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers, get_use_typing, - package_store, repository) +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, package_store, repository from fedn.network.storage.statestore.stores.shared import EntityNotFound bp = Blueprint("package", __name__, url_prefix=f"/api/{api_version}/packages") - @bp.route("/", methods=["GET"]) @jwt_auth_required(role="admin") def get_packages(): @@ -119,14 +116,10 @@ def get_packages(): """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - packages = package_store.list(limit, skip, sort_key, sort_order, use_typing=True, **kwargs) - - result = [package.__dict__ for package in packages["result"]] - - response = {"count": packages["count"], "result": result} + response = package_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -207,14 +200,10 @@ def list_packages(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - packages = package_store.list(limit, skip, sort_key, sort_order, use_typing=True, **kwargs) - - result = [package.__dict__ for package in packages["result"]] - - response = {"count": packages["count"], "result": result} + response = package_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -379,10 +368,7 @@ def get_package(id: str): type: string """ try: - use_typing: bool = get_use_typing(request.headers) - package = package_store.get(id, use_typing=use_typing) - - response = package.__dict__ if use_typing else package + response = package_store.get(id) return jsonify(response), 200 except EntityNotFound: @@ -420,9 +406,7 @@ def get_active_package(): type: string """ try: - use_typing: bool = get_use_typing(request.headers) - package = package_store.get_active(use_typing=use_typing) - response = package.__dict__ if use_typing else package + response = package_store.get_active() return jsonify(response), 200 except EntityNotFound: diff --git a/fedn/network/api/v1/prediction_routes.py b/fedn/network/api/v1/prediction_routes.py index e5ce8edb7..0ea34224a 100644 --- a/fedn/network/api/v1/prediction_routes.py +++ b/fedn/network/api/v1/prediction_routes.py @@ -4,7 +4,7 @@ from fedn.network.api.auth import jwt_auth_required from fedn.network.api.shared import control -from fedn.network.api.v1.shared import api_version, mdb, get_typed_list_headers, get_post_data_to_kwargs +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.model_store import ModelStore from fedn.network.storage.statestore.stores.prediction_store import PredictionStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -170,14 +170,10 @@ def get_predictions(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - predictions = prediction_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - - result = [prediction.__dict__ for prediction in predictions["result"]] if use_typing else predictions["result"] - - response = {"count": predictions["count"], "result": result} + response = prediction_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -268,14 +264,10 @@ def list_predictions(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - predictions = prediction_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - - result = [prediction.__dict__ for prediction in predictions["result"]] if use_typing else predictions["result"] - - response = {"count": predictions["count"], "result": result} + response = prediction_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: diff --git a/fedn/network/api/v1/round_routes.py b/fedn/network/api/v1/round_routes.py index 14476a091..c4093059c 100644 --- a/fedn/network/api/v1/round_routes.py +++ b/fedn/network/api/v1/round_routes.py @@ -90,15 +90,11 @@ def get_rounds(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - rounds = round_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) - - result = rounds["result"] - - response = {"count": rounds["count"], "result": result} + response = round_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -169,15 +165,11 @@ def list_rounds(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - rounds = round_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) - - result = rounds["result"] - - response = {"count": rounds["count"], "result": result} + response = round_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -305,7 +297,7 @@ def get_round(id: str): type: string """ try: - round = round_store.get(id, use_typing=False) + round = round_store.get(id) response = round return jsonify(response), 200 diff --git a/fedn/network/api/v1/session_routes.py b/fedn/network/api/v1/session_routes.py index 9158b47df..1079566fe 100644 --- a/fedn/network/api/v1/session_routes.py +++ b/fedn/network/api/v1/session_routes.py @@ -90,14 +90,10 @@ def get_sessions(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - sessions = session_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) - - result = sessions["result"] - - response = {"count": sessions["count"], "result": result} + response = session_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -168,14 +164,10 @@ def list_sessions(): type: string """ try: - limit, skip, sort_key, sort_order, _ = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - sessions = session_store.list(limit, skip, sort_key, sort_order, use_typing=False, **kwargs) - - result = sessions["result"] - - response = {"count": sessions["count"], "result": result} + response = session_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -303,8 +295,7 @@ def get_session(id: str): type: string """ try: - session = session_store.get(id, use_typing=False) - response = session + response = session_store.get(id) return jsonify(response), 200 except EntityNotFound: @@ -386,7 +377,7 @@ def start_session(): if not session_id or session_id == "": return jsonify({"message": "Session ID is required"}), 400 - session = session_store.get(session_id, use_typing=False) + session = session_store.get(session_id) session_config = session["session_config"] model_id = session_config["model_id"] @@ -402,7 +393,7 @@ def start_session(): if nr_available_clients < min_clients: return jsonify({"message": f"Number of available clients is lower than the required minimum of {min_clients}"}), 400 - _ = model_store.get(model_id, use_typing=False) + _ = model_store.get(model_id) threading.Thread(target=control.start_session, args=(session_id, rounds, round_timeout)).start() @@ -451,7 +442,7 @@ def patch_session(id: str): type: string """ try: - session = session_store.get(id, use_typing=False) + session = session_store.get(id) data = request.get_json() _id = session["id"] @@ -516,7 +507,7 @@ def put_session(id: str): type: string """ try: - session = session_store.get(id, use_typing=False) + session = session_store.get(id) data = request.get_json() _id = session["id"] diff --git a/fedn/network/api/v1/shared.py b/fedn/network/api/v1/shared.py index 75b5d264b..946382cd5 100644 --- a/fedn/network/api/v1/shared.py +++ b/fedn/network/api/v1/shared.py @@ -3,8 +3,7 @@ import pymongo from pymongo.database import Database -from fedn.network.api.shared import (modelstorage_config, network_id, - statestore_config) +from fedn.network.api.shared import modelstorage_config, network_id, statestore_config from fedn.network.storage.s3.base import RepositoryBase from fedn.network.storage.s3.miniorepository import MINIORepository from fedn.network.storage.s3.repository import Repository @@ -39,11 +38,6 @@ def is_positive_integer(s): return s is not None and s.isdigit() and int(s) > 0 -def get_use_typing(headers: object) -> bool: - skip_typing: str = headers.get("X-Skip-Typing", "false") - return False if skip_typing.lower() == "true" else True - - def get_limit(headers: object) -> int: limit: str = headers.get("X-Limit") if is_positive_integer(limit): @@ -73,25 +67,32 @@ def get_typed_list_headers( limit: int = get_limit(headers) skip: int = get_skip(headers) - use_typing: bool = get_use_typing(headers) if sort_order is not None: sort_order = pymongo.ASCENDING if sort_order.lower() == "asc" else pymongo.DESCENDING else: sort_order = pymongo.DESCENDING - return limit, skip, sort_key, sort_order, use_typing + return limit, skip, sort_key, sort_order def get_post_data_to_kwargs(request: object) -> dict: - request_data = request.form.to_dict() + try: + # Try to get data from form + request_data = request.form.to_dict() + except Exception: + request_data = None if not request_data: - request_data = request.json + try: + # Try to get data from JSON + request_data = request.get_json() + except Exception: + request_data = {} kwargs = {} for key, value in request_data.items(): - if "," in value: + if isinstance(value, str) and "," in value: kwargs[key] = {"$in": value.split(",")} else: kwargs[key] = value diff --git a/fedn/network/api/v1/status_routes.py b/fedn/network/api/v1/status_routes.py index 0716f965b..00c69dae6 100644 --- a/fedn/network/api/v1/status_routes.py +++ b/fedn/network/api/v1/status_routes.py @@ -1,7 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.shared import EntityNotFound from fedn.network.storage.statestore.stores.status_store import StatusStore @@ -121,18 +121,10 @@ def get_statuses(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - # print all the typed headers - print(f"limit: {limit}, skip: {skip}, sort_key: {sort_key}, sort_order: {sort_order}, use_typing: {use_typing}") - print(f"kwargs: {kwargs}") - statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - - print(f"statuses: {statuses}") - result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"] - - response = {"count": statuses["count"], "result": result} + response = status_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -220,14 +212,10 @@ def list_statuses(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - - result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"] - - response = {"count": statuses["count"], "result": result} + response = status_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -393,10 +381,7 @@ def get_status(id: str): type: string """ try: - use_typing: bool = get_use_typing(request.headers) - status = status_store.get(id, use_typing=use_typing) - - response = status.__dict__ if use_typing else status + response = status_store.get(id) return jsonify(response), 200 except EntityNotFound: diff --git a/fedn/network/api/v1/validation_routes.py b/fedn/network/api/v1/validation_routes.py index 665abbb4b..8294d41d4 100644 --- a/fedn/network/api/v1/validation_routes.py +++ b/fedn/network/api/v1/validation_routes.py @@ -1,7 +1,7 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb from fedn.network.storage.statestore.stores.shared import EntityNotFound from fedn.network.storage.statestore.stores.validation_store import ValidationStore @@ -128,14 +128,10 @@ def get_validations(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() - validations = validation_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - - result = [validation.__dict__ for validation in validations["result"]] if use_typing else validations["result"] - - response = {"count": validations["count"], "result": result} + response = validation_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -226,14 +222,10 @@ def list_validations(): type: string """ try: - limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) + limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) - validations = validation_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) - - result = [validation.__dict__ for validation in validations["result"]] if use_typing else validations["result"] - - response = {"count": validations["count"], "result": result} + response = validation_store.list(limit, skip, sort_key, sort_order, **kwargs) return jsonify(response), 200 except Exception: @@ -406,10 +398,7 @@ def get_validation(id: str): type: string """ try: - use_typing: bool = get_use_typing(request.headers) - validation = validation_store.get(id, use_typing=use_typing) - - response = validation.__dict__ if use_typing else validation + response = validation_store.get(id) return jsonify(response), 200 except EntityNotFound: diff --git a/fedn/network/storage/statestore/stores/client_store.py b/fedn/network/storage/statestore/stores/client_store.py index dd521860f..4f5cd18e6 100644 --- a/fedn/network/storage/statestore/stores/client_store.py +++ b/fedn/network/storage/statestore/stores/client_store.py @@ -5,7 +5,7 @@ from bson import ObjectId from pymongo.database import Database -from fedn.network.storage.statestore.stores.store import Store +from fedn.network.storage.statestore.stores.store import MongoDBStore from .shared import EntityNotFound, from_document @@ -21,33 +21,23 @@ def __init__(self, id: str, name: str, combiner: str, combiner_preferred: str, i self.updated_at = updated_at self.last_seen = last_seen - def from_dict(data: dict) -> "Client": - return Client( - id=str(data["_id"]), - name=data["name"] if "name" in data else None, - combiner=data["combiner"] if "combiner" in data else None, - combiner_preferred=data["combiner_preferred"] if "combiner_preferred" in data else None, - ip=data["ip"] if "ip" in data else None, - status=data["status"] if "status" in data else None, - updated_at=data["updated_at"] if "updated_at" in data else None, - last_seen=data["last_seen"] if "last_seen" in data else None, - ) - -class ClientStore(Store[Client]): +class ClientStore(MongoDBStore[Client]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) - def get(self, id: str, use_typing: bool = False) -> Client: + def get(self, id: str) -> Client: """Get an entity by id param id: The id of the entity type: str - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entity """ - response = super().get(id, use_typing=use_typing) - return Client.from_dict(response) if use_typing else response + if ObjectId.is_valid(id): + response = super().get(id) + else: + obj = self._get_client_by_client_id(id) + response = from_document(obj) + return response def _get_client_by_client_id(self, client_id: str) -> Dict: document = self.database[self.collection].find_one({"client_id": client_id}) @@ -85,7 +75,7 @@ def delete(self, id: str) -> bool: return super().delete(document["_id"]) - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[Client]]: + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Client]]: """List entities param limit: The maximum number of entities to return type: int @@ -95,18 +85,12 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI type: str param sort_order: The order to sort by type: pymongo.DESCENDING | pymongo.ASCENDING - param use_typing: Whether to return the entities as typed objects or as dicts - type: bool param kwargs: Additional query parameters type: dict example: {"key": "models"} return: A dictionary with the count and the result """ - response = super().list(limit, skip, sort_key or "last_seen", sort_order, use_typing=use_typing, **kwargs) - - result = [Client.from_dict(item) for item in response["result"]] if use_typing else response["result"] - - return {"count": response["count"], "result": result} + return super().list(limit, skip, sort_key or "last_seen", sort_order, **kwargs) def count(self, **kwargs) -> int: return super().count(**kwargs) diff --git a/fedn/network/storage/statestore/stores/combiner_store.py b/fedn/network/storage/statestore/stores/combiner_store.py index 2ad6437ea..8a938d06c 100644 --- a/fedn/network/storage/statestore/stores/combiner_store.py +++ b/fedn/network/storage/statestore/stores/combiner_store.py @@ -4,7 +4,7 @@ from bson import ObjectId from pymongo.database import Database -from fedn.network.storage.statestore.stores.store import Store +from fedn.network.storage.statestore.stores.store import MongoDBStore from .shared import EntityNotFound, from_document @@ -38,34 +38,16 @@ def __init__( self.status = status self.updated_at = updated_at - def from_dict(data: dict) -> "Combiner": - return Combiner( - id=str(data["_id"]), - name=data["name"] if "name" in data else None, - address=data["address"] if "address" in data else None, - certificate=data["certificate"] if "certificate" in data else None, - config=data["config"] if "config" in data else None, - fqdn=data["fqdn"] if "fqdn" in data else None, - ip=data["ip"] if "ip" in data else None, - key=data["key"] if "key" in data else None, - parent=data["parent"] if "parent" in data else None, - port=data["port"] if "port" in data else None, - status=data["status"] if "status" in data else None, - updated_at=data["updated_at"] if "updated_at" in data else None, - ) - - -class CombinerStore(Store[Combiner]): + +class CombinerStore(MongoDBStore[Combiner]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) - def get(self, id: str, use_typing: bool = False) -> Combiner: + def get(self, id: str) -> Combiner: """Get an entity by id param id: The id of the entity type: str description: The id of the entity, can be either the id or the name (property) - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entity """ if ObjectId.is_valid(id): @@ -77,7 +59,7 @@ def get(self, id: str, use_typing: bool = False) -> Combiner: if document is None: raise EntityNotFound(f"Entity with (id | name) {id} not found") - return Combiner.from_dict(document) if use_typing else from_document(document) + return from_document(document) def update(self, id: str, item: Combiner) -> bool: raise NotImplementedError("Update not implemented for CombinerStore") @@ -98,7 +80,7 @@ def delete(self, id: str) -> bool: return super().delete(document["_id"]) - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[Combiner]]: + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Combiner]]: """List entities param limit: The maximum number of entities to return type: int @@ -108,18 +90,14 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI type: str param sort_order: The order to sort by type: pymongo.DESCENDING | pymongo.ASCENDING - param use_typing: Whether to return the entities as typed objects or as dicts - type: bool param kwargs: Additional query parameters type: dict example: {"key": "models"} return: A dictionary with the count and the result """ - response = super().list(limit, skip, sort_key or "updated_at", sort_order, use_typing=use_typing, **kwargs) - - result = [Combiner.from_dict(item) for item in response["result"]] if use_typing else response["result"] + response = super().list(limit, skip, sort_key or "updated_at", sort_order, **kwargs) - return {"count": response["count"], "result": result} + return response def count(self, **kwargs) -> int: return super().count(**kwargs) diff --git a/fedn/network/storage/statestore/stores/model_store.py b/fedn/network/storage/statestore/stores/model_store.py index 27efcc9a3..d6b96121b 100644 --- a/fedn/network/storage/statestore/stores/model_store.py +++ b/fedn/network/storage/statestore/stores/model_store.py @@ -5,7 +5,7 @@ from bson import ObjectId from pymongo.database import Database -from fedn.network.storage.statestore.stores.store import Store +from fedn.network.storage.statestore.stores.store import MongoDBStore from .shared import EntityNotFound, from_document @@ -19,28 +19,16 @@ def __init__(self, id: str, key: str, model: str, parent_model: str, session_id: self.session_id = session_id self.committed_at = committed_at - def from_dict(data: dict) -> "Model": - return Model( - id=str(data["_id"]), - key=data["key"] if "key" in data else None, - model=data["model"] if "model" in data else None, - parent_model=data["parent_model"] if "parent_model" in data else None, - session_id=data["session_id"] if "session_id" in data else None, - committed_at=data["committed_at"] if "committed_at" in data else None, - ) - -class ModelStore(Store[Model]): +class ModelStore(MongoDBStore[Model]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) - def get(self, id: str, use_typing: bool = False) -> Model: + def get(self, id: str) -> Model: """Get an entity by id param id: The id of the entity type: str description: The id of the entity, can be either the id or the model (property) - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entity """ kwargs = {"key": "models"} @@ -55,7 +43,7 @@ def get(self, id: str, use_typing: bool = False) -> Model: if document is None: raise EntityNotFound(f"Entity with (id | model) {id} not found") - return Model.from_dict(document) if use_typing else from_document(document) + return from_document(document) def _validate(self, item: Model) -> Tuple[bool, str]: if "model" not in item or not item["model"]: @@ -82,7 +70,7 @@ def add(self, item: Model) -> Tuple[bool, Any]: def delete(self, id: str) -> bool: raise NotImplementedError("Delete not implemented for ModelStore") - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[Model]]: + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Model]]: """List entities param limit: The maximum number of entities to return type: int @@ -92,8 +80,6 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI type: str param sort_order: The order to sort by type: pymongo.DESCENDING | pymongo.ASCENDING - param use_typing: Whether to return the entities as typed objects or as dicts - type: bool param kwargs: Additional query parameters type: dict example: {"key": "models"} @@ -101,20 +87,15 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI """ kwargs["key"] = "models" - response = super().list(limit, skip, sort_key or "committed_at", sort_order, use_typing=use_typing, **kwargs) - - result = [Model.from_dict(item) for item in response["result"]] if use_typing else response["result"] - return {"count": response["count"], "result": result} + return super().list(limit, skip, sort_key or "committed_at", sort_order, **kwargs) - def list_descendants(self, id: str, limit: int, use_typing: bool = False) -> List[Model]: + def list_descendants(self, id: str, limit: int) -> List[Model]: """List descendants param id: The id of the entity type: str description: The id of the entity, can be either the id or the model (property) param limit: The maximum number of entities to return type: int - param use_typing: Whether to return the entities as typed objects or as dicts - type: bool return: A list of entities """ kwargs = {"key": "models"} @@ -139,7 +120,7 @@ def list_descendants(self, id: str, limit: int, use_typing: bool = False) -> Lis model: str = self.database[self.collection].find_one({"key": "models", "parent_model": current_model_id}) if model is not None: - formatted_model = Model.from_dict(model) if use_typing else from_document(model) + formatted_model = Model.from_dict(model) result.append(formatted_model) current_model_id = model["model"] else: @@ -149,15 +130,13 @@ def list_descendants(self, id: str, limit: int, use_typing: bool = False) -> Lis return result - def list_ancestors(self, id: str, limit: int, include_self: bool = False, reverse: bool = False, use_typing: bool = False) -> List[Model]: + def list_ancestors(self, id: str, limit: int, include_self: bool = False, reverse: bool = False) -> List[Model]: """List ancestors param id: The id of the entity type: str description: The id of the entity, can be either the id or the model (property) param limit: The maximum number of entities to return type: int - param use_typing: Whether to return the entities as typed objects or as dicts - type: bool return: A list of entities """ kwargs = {"key": "models"} @@ -176,7 +155,7 @@ def list_ancestors(self, id: str, limit: int, include_self: bool = False, revers result: list = [] if include_self: - formatted_model = Model.from_dict(model) if use_typing else from_document(model) + formatted_model = from_document(model) result.append(formatted_model) for _ in range(limit): @@ -186,7 +165,7 @@ def list_ancestors(self, id: str, limit: int, include_self: bool = False, revers model = self.database[self.collection].find_one({"key": "models", "model": current_model_id}) if model is not None: - formatted_model = Model.from_dict(model) if use_typing else from_document(model) + formatted_model = from_document(model) result.append(formatted_model) current_model_id = model["parent_model"] else: diff --git a/fedn/network/storage/statestore/stores/package_store.py b/fedn/network/storage/statestore/stores/package_store.py index de55d888c..44dece2ab 100644 --- a/fedn/network/storage/statestore/stores/package_store.py +++ b/fedn/network/storage/statestore/stores/package_store.py @@ -7,9 +7,28 @@ from pymongo.database import Database from werkzeug.utils import secure_filename -from fedn.network.storage.statestore.stores.store import Store +from fedn.network.storage.statestore.stores.store import MongoDBStore -from .shared import EntityNotFound, from_document +from .shared import EntityNotFound + + +def from_document(data: dict, active_package: dict): + active = False + if active_package: + if "id" in active_package and "id" in data: + active = active_package["id"] == data["id"] + + return { + "id": data["id"] if "id" in data else None, + "key": data["key"] if "key" in data else None, + "committed_at": data["committed_at"] if "committed_at" in data else None, + "description": data["description"] if "description" in data else None, + "file_name": data["file_name"] if "file_name" in data else None, + "helper": data["helper"] if "helper" in data else None, + "name": data["name"] if "name" in data else None, + "storage_file_name": data["storage_file_name"] if "storage_file_name" in data else None, + "active": active, + } class Package: @@ -26,38 +45,16 @@ def __init__( self.storage_file_name = storage_file_name self.active = active - def from_dict(data: dict, active_package: dict) -> "Package": - active = False - if active_package: - if "id" in active_package and "id" in data: - active = active_package["id"] == data["id"] - - return Package( - id=data["id"] if "id" in data else None, - key=data["key"] if "key" in data else None, - committed_at=data["committed_at"] if "committed_at" in data else None, - description=data["description"] if "description" in data else None, - file_name=data["file_name"] if "file_name" in data else None, - helper=data["helper"] if "helper" in data else None, - name=data["name"] if "name" in data else None, - storage_file_name=data["storage_file_name"] if "storage_file_name" in data else None, - active=active, - ) - - -class PackageStore(Store[Package]): + +class PackageStore(MongoDBStore[Package]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) - def get(self, id: str, use_typing: bool = False) -> Package: + def get(self, id: str) -> Package: """Get an entity by id param id: The id of the entity type: str - description: The id of the entity, can be either the id or the model (property) - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool - description: Whether to return the entities as typed objects or as dicts. - If True, and active property will be set based on the active package. + description: The id of the entity, can be either the id or the docuemnt _id return: The entity """ document = self.database[self.collection].find_one({"id": id}) @@ -65,12 +62,9 @@ def get(self, id: str, use_typing: bool = False) -> Package: if document is None: raise EntityNotFound(f"Entity with id {id} not found") - if not use_typing: - return from_document(document) - response_active = self.database[self.collection].find_one({"key": "active"}) - return Package.from_dict(document, response_active) + return from_document(document, response_active) def _validate(self, item: Package) -> Tuple[bool, str]: if "file_name" not in item or not item["file_name"]: @@ -130,19 +124,16 @@ def set_active(self, id: str) -> bool: return True - def get_active(self, use_typing: bool = False) -> Package: + def get_active(self) -> Package: """Get the active entity - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entity """ kwargs = {"key": "active"} response = self.database[self.collection].find_one(kwargs) - if response is None: raise EntityNotFound("Entity not found") - return Package.from_dict(response, response) if use_typing else from_document(response) + return from_document(response, {"id": response["id"]}) def set_active_helper(self, helper: str) -> bool: """Set the active helper @@ -217,7 +208,7 @@ def delete_active(self): return super().delete(document_active["_id"]) - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[Package]]: + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Package]]: """List entities param limit: The maximum number of entities to return type: int @@ -227,10 +218,6 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI type: str param sort_order: The order to sort by type: pymongo.DESCENDING | pymongo.ASCENDING - param use_typing: Whether to return the entities as typed objects or as dicts - type: bool - description: Whether to return the entities as typed objects or as dicts. - If True, and active property will be set based on the active package. param kwargs: Additional query parameters type: dict example: {"key": "models"} @@ -238,13 +225,15 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI """ kwargs["key"] = "package_trail" - response = super().list(limit, skip, sort_key or "committed_at", sort_order, use_typing=True, **kwargs) + response = self.database[self.collection].find(kwargs).sort(sort_key or "committed_at", sort_order).skip(skip or 0).limit(limit or 0) + + count = self.database[self.collection].count_documents(kwargs) response_active = self.database[self.collection].find_one({"key": "active"}) - result = [Package.from_dict(item, response_active) for item in response["result"]] + result = [from_document(item, response_active) for item in response] - return {"count": response["count"], "result": result} + return {"count": count, "result": result} def count(self, **kwargs) -> int: kwargs["key"] = "package_trail" diff --git a/fedn/network/storage/statestore/stores/prediction_store.py b/fedn/network/storage/statestore/stores/prediction_store.py index 1ae29b94c..5b918c41e 100644 --- a/fedn/network/storage/statestore/stores/prediction_store.py +++ b/fedn/network/storage/statestore/stores/prediction_store.py @@ -3,7 +3,7 @@ import pymongo from pymongo.database import Database -from fedn.network.storage.statestore.stores.store import Store +from fedn.network.storage.statestore.stores.store import MongoDBStore class Prediction: @@ -20,35 +20,20 @@ def __init__( self.sender = sender self.receiver = receiver - def from_dict(data: dict) -> "Prediction": - return Prediction( - id=str(data["_id"]), - model_id=data["modelId"] if "modelId" in data else None, - data=data["data"] if "data" in data else None, - correlation_id=data["correlationId"] if "correlationId" in data else None, - timestamp=data["timestamp"] if "timestamp" in data else None, - prediction_id=data["predictionId"] if "predictionId" in data else None, - meta=data["meta"] if "meta" in data else None, - sender=data["sender"] if "sender" in data else None, - receiver=data["receiver"] if "receiver" in data else None, - ) - -class PredictionStore(Store[Prediction]): +class PredictionStore(MongoDBStore[Prediction]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) - def get(self, id: str, use_typing: bool = False) -> Prediction: + def get(self, id: str) -> Prediction: """Get an entity by id param id: The id of the entity type: str description: The id of the entity, can be either the id or the Prediction (property) - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entity """ - response = super().get(id, use_typing=use_typing) - return Prediction.from_dict(response) if use_typing else response + response = super().get(id) + return response def update(self, id: str, item: Prediction) -> bool: raise NotImplementedError("Update not implemented for PredictionStore") @@ -59,7 +44,7 @@ def add(self, item: Prediction) -> Tuple[bool, Any]: def delete(self, id: str) -> bool: raise NotImplementedError("Delete not implemented for PredictionStore") - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[Prediction]]: + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Prediction]]: """List entities param limit: The maximum number of entities to return type: int @@ -73,12 +58,6 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI param sort_order: The order to sort by type: pymongo.DESCENDING description: The order to sort by - param use_typing: Whether to return the entities as typed objects or as dicts - type: bool - description: Whether to return the entities as typed objects or as dicts return: A dictionary with the count and a list of entities """ - response = super().list(limit, skip, sort_key or "timestamp", sort_order, use_typing=use_typing, **kwargs) - - result = [Prediction.from_dict(item) for item in response["result"]] if use_typing else response["result"] - return {"count": response["count"], "result": result} + return super().list(limit, skip, sort_key or "timestamp", sort_order, **kwargs) diff --git a/fedn/network/storage/statestore/stores/round_store.py b/fedn/network/storage/statestore/stores/round_store.py index 2eff2a993..9148f0c63 100644 --- a/fedn/network/storage/statestore/stores/round_store.py +++ b/fedn/network/storage/statestore/stores/round_store.py @@ -3,7 +3,7 @@ import pymongo from pymongo.database import Database -from fedn.network.storage.statestore.stores.store import Store +from fedn.network.storage.statestore.stores.store import MongoDBStore class Round: @@ -15,42 +15,29 @@ def __init__(self, id: str, round_id: str, status: str, round_config: dict, comb self.combiners = combiners self.round_data = round_data - def from_dict(data: dict) -> "Round": - return Round( - id=str(data["_id"]), - round_id=data["round_id"] if "round_id" in data else None, - status=data["status"] if "status" in data else None, - round_config=data["round_config"] if "round_config" in data else None, - combiners=data["combiners"] if "combiners" in data else None, - round_data=data["round_data"] if "round_data" in data else None - ) - -class RoundStore(Store[Round]): +class RoundStore(MongoDBStore[Round]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) - def get(self, id: str, use_typing: bool = False) -> Round: + def get(self, id: str) -> Round: """Get an entity by id param id: The id of the entity type: str - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entity """ - response = super().get(id, use_typing=use_typing) - return Round.from_dict(response) if use_typing else response + return super().get(id) def update(self, id: str, item: Round) -> bool: raise NotImplementedError("Update not implemented for RoundStore") - def add(self, item: Round)-> Tuple[bool, Any]: + def add(self, item: Round) -> Tuple[bool, Any]: raise NotImplementedError("Add not implemented for RoundStore") def delete(self, id: str) -> bool: raise NotImplementedError("Delete not implemented for RoundStore") - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[Round]]: + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Round]]: """List entities param limit: The maximum number of entities to return type: int @@ -64,15 +51,6 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI param sort_order: The order to sort by type: pymongo.DESCENDING description: The order to sort by - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entities """ - response = super().list(limit, skip, sort_key or "round_id", sort_order, use_typing=use_typing, **kwargs) - - result = [Round.from_dict(item) for item in response["result"]] if use_typing else response["result"] - - return { - "count": response["count"], - "result": result - } + return super().list(limit, skip, sort_key or "round_id", sort_order, **kwargs) diff --git a/fedn/network/storage/statestore/stores/session_store.py b/fedn/network/storage/statestore/stores/session_store.py index 7b29354ce..cd0a333de 100644 --- a/fedn/network/storage/statestore/stores/session_store.py +++ b/fedn/network/storage/statestore/stores/session_store.py @@ -6,7 +6,7 @@ from bson import ObjectId from pymongo.database import Database -from fedn.network.storage.statestore.stores.store import Store +from fedn.network.storage.statestore.stores.store import MongoDBStore from .shared import EntityNotFound, from_document @@ -18,16 +18,8 @@ def __init__(self, id: str, session_id: str, status: str, session_config: dict = self.status = status self.session_config = session_config - def from_dict(data: dict) -> "Session": - return Session( - id=str(data["_id"]), - session_id=data["session_id"] if "session_id" in data else None, - status=data["status"] if "status" in data else None, - session_config=data["session_config"] if "session_config" in data else None, - ) - -class SessionStore(Store[Session]): +class SessionStore(MongoDBStore[Session]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) @@ -101,13 +93,11 @@ def _complement(self, item: Session): if "session_id" not in item or item["session_id"] == "" or not isinstance(item["session_id"], str): item["session_id"] = str(uuid.uuid4()) - def get(self, id: str, use_typing: bool = False) -> Session: + def get(self, id: str) -> Session: """Get an entity by id param id: The id of the entity type: str description: The id of the entity, can be either the id or the session_id (property) - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entity """ if ObjectId.is_valid(id): @@ -119,7 +109,7 @@ def get(self, id: str, use_typing: bool = False) -> Session: if document is None: raise EntityNotFound(f"Entity with (id | session_id) {id} not found") - return Session.from_dict(document) if use_typing else from_document(document) + return from_document(document) def update(self, id: str, item: Session) -> Tuple[bool, Any]: valid, message = self._validate(item) @@ -146,7 +136,7 @@ def add(self, item: Session) -> Tuple[bool, Any]: def delete(self, id: str) -> bool: raise NotImplementedError("Delete not implemented for SessionStore") - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[Session]]: + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Session]]: """List entities param limit: The maximum number of entities to return type: int @@ -160,16 +150,9 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI param sort_order: The order to sort by type: pymongo.DESCENDING description: The order to sort by - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool - description: Whether to return the entities as typed objects or as dicts. param kwargs: Additional query parameters type: dict description: Additional query parameters return: The entities """ - response = super().list(limit, skip, sort_key or "session_id", sort_order, use_typing=use_typing, **kwargs) - - result = [Session.from_dict(item) for item in response["result"]] if use_typing else response["result"] - - return {"count": response["count"], "result": result} + return super().list(limit, skip, sort_key or "session_id", sort_order, **kwargs) diff --git a/fedn/network/storage/statestore/stores/status_store.py b/fedn/network/storage/statestore/stores/status_store.py index 9233d0b23..a6aae34e8 100644 --- a/fedn/network/storage/statestore/stores/status_store.py +++ b/fedn/network/storage/statestore/stores/status_store.py @@ -3,7 +3,7 @@ import pymongo from pymongo.database import Database -from fedn.network.storage.statestore.stores.store import Store +from fedn.network.storage.statestore.stores.store import MongoDBStore class Status: @@ -21,36 +21,19 @@ def __init__( self.session_id = session_id self.sender = sender - def from_dict(data: dict) -> "Status": - return Status( - id=str(data["_id"]), - status=data["status"] if "status" in data else None, - timestamp=data["timestamp"] if "timestamp" in data else None, - log_level=data["logLevel"] if "logLevel" in data else None, - data=data["data"] if "data" in data else None, - correlation_id=data["correlationId"] if "correlationId" in data else None, - type=data["type"] if "type" in data else None, - extra=data["extra"] if "extra" in data else None, - session_id=data["sessionId"] if "sessionId" in data else None, - sender=data["sender"] if "sender" in data else None, - ) - -class StatusStore(Store[Status]): +class StatusStore(MongoDBStore[Status]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) - def get(self, id: str, use_typing: bool = False) -> Status: + def get(self, id: str) -> Status: """Get an entity by id param id: The id of the entity type: str description: The id of the entity, can be either the id or the status (property) - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entity """ - response = super().get(id, use_typing=use_typing) - return Status.from_dict(response) if use_typing else response + return super().get(id) def update(self, id: str, item: Status) -> bool: raise NotImplementedError("Update not implemented for StatusStore") @@ -61,7 +44,7 @@ def add(self, item: Status) -> Tuple[bool, Any]: def delete(self, id: str) -> bool: raise NotImplementedError("Delete not implemented for StatusStore") - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[Status]]: + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Status]]: """List entities param limit: The maximum number of entities to return type: int @@ -75,12 +58,5 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI param sort_order: The order to sort by type: pymongo.DESCENDING description: The order to sort by - param use_typing: Whether to return the entities as typed objects or as dicts - type: bool - description: Whether to return the entities as typed objects or as dicts. """ - response = super().list(limit, skip, sort_key or "timestamp", sort_order, use_typing=use_typing, **kwargs) - - result = [Status.from_dict(item) for item in response["result"]] if use_typing else response["result"] - - return {"count": response["count"], "result": result} + return super().list(limit, skip, sort_key or "timestamp", sort_order, **kwargs) diff --git a/fedn/network/storage/statestore/stores/store.py b/fedn/network/storage/statestore/stores/store.py index f6a8f67e0..ec5e4e9be 100644 --- a/fedn/network/storage/statestore/stores/store.py +++ b/fedn/network/storage/statestore/stores/store.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Any, Dict, Generic, List, Tuple, TypeVar import pymongo @@ -9,26 +10,51 @@ T = TypeVar("T") -class Store(Generic[T]): +class Store(ABC, Generic[T]): + @abstractmethod + def get(self, id: str) -> T: + pass + + @abstractmethod + def update(self, id: str, item: T) -> Tuple[bool, Any]: + pass + + @abstractmethod + def add(self, item: T) -> Tuple[bool, Any]: + pass + + @abstractmethod + def delete(self, id: str) -> bool: + pass + + @abstractmethod + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[T]]: + pass + + @abstractmethod + def count(self, **kwargs) -> int: + pass + + +class MongoDBStore(Store[T], Generic[T]): def __init__(self, database: Database, collection: str): self.database = database self.collection = collection - def get(self, id: str, use_typing: bool = False) -> T: + def get(self, id: str) -> T: """Get an entity by id param id: The id of the entity type: str - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entity """ + if not ObjectId.is_valid(id): + raise EntityNotFound(f"Invalid id {id}") id_obj = ObjectId(id) document = self.database[self.collection].find_one({"_id": id_obj}) - if document is None: raise EntityNotFound(f"Entity with id {id} not found") - return from_document(document) if not use_typing else document + return from_document(document) def update(self, id: str, item: T) -> Tuple[bool, Any]: try: @@ -54,7 +80,7 @@ def delete(self, id: str) -> bool: result = self.database[self.collection].delete_one({"_id": ObjectId(id)}) return result.deleted_count == 1 - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[T]]: + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[T]]: """List entities param limit: The maximum number of entities to return type: int @@ -64,8 +90,6 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI type: str param sort_order: The order to sort by type: pymongo.DESCENDING | pymongo.ASCENDING - param use_typing: Whether to return the entities as typed objects or as dicts - type: bool param kwargs: Additional query parameters type: dict example: {"key": "models"} @@ -75,7 +99,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI count = self.database[self.collection].count_documents(kwargs) - result = [document for document in cursor] if use_typing else [from_document(document) for document in cursor] + result = [from_document(document) for document in cursor] return {"count": count, "result": result} diff --git a/fedn/network/storage/statestore/stores/validation_store.py b/fedn/network/storage/statestore/stores/validation_store.py index 59b5a0730..f5e9ef604 100644 --- a/fedn/network/storage/statestore/stores/validation_store.py +++ b/fedn/network/storage/statestore/stores/validation_store.py @@ -3,7 +3,7 @@ import pymongo from pymongo.database import Database -from fedn.network.storage.statestore.stores.store import Store +from fedn.network.storage.statestore.stores.store import MongoDBStore class Validation: @@ -20,35 +20,19 @@ def __init__( self.sender = sender self.receiver = receiver - def from_dict(data: dict) -> "Validation": - return Validation( - id=str(data["_id"]), - model_id=data["modelId"] if "modelId" in data else None, - data=data["data"] if "data" in data else None, - correlation_id=data["correlationId"] if "correlationId" in data else None, - timestamp=data["timestamp"] if "timestamp" in data else None, - session_id=data["sessionId"] if "sessionId" in data else None, - meta=data["meta"] if "meta" in data else None, - sender=data["sender"] if "sender" in data else None, - receiver=data["receiver"] if "receiver" in data else None, - ) - -class ValidationStore(Store[Validation]): +class ValidationStore(MongoDBStore[Validation]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) - def get(self, id: str, use_typing: bool = False) -> Validation: + def get(self, id: str) -> Validation: """Get an entity by id param id: The id of the entity type: str description: The id of the entity, can be either the id or the validation (property) - param use_typing: Whether to return the entity as a typed object or as a dict - type: bool return: The entity """ - response = super().get(id, use_typing=use_typing) - return Validation.from_dict(response) if use_typing else response + return super().get(id) def update(self, id: str, item: Validation) -> bool: raise NotImplementedError("Update not implemented for ValidationStore") @@ -59,7 +43,7 @@ def add(self, item: Validation) -> Tuple[bool, Any]: def delete(self, id: str) -> bool: raise NotImplementedError("Delete not implemented for ValidationStore") - def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[Validation]]: + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Validation]]: """List entities param limit: The maximum number of entities to return type: int @@ -73,12 +57,6 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI param sort_order: The order to sort by type: pymongo.DESCENDING description: The order to sort by - param use_typing: Whether to return the entities as typed objects or as dicts - type: bool - description: Whether to return the entities as typed objects or as dicts return: A dictionary with the count and a list of entities """ - response = super().list(limit, skip, sort_key or "timestamp", sort_order, use_typing=use_typing, **kwargs) - - result = [Validation.from_dict(item) for item in response["result"]] if use_typing else response["result"] - return {"count": response["count"], "result": result} + return super().list(limit, skip, sort_key or "timestamp", sort_order, **kwargs)