diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index 6f3b993b6..4370c0428 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -768,6 +768,24 @@ def get_models(self, session_id: str = None, limit: str = None, skip: str = None return jsonify(result) + def get_model(self, model_id: str): + result = self.statestore.get_model(model_id) + + if result is None: + return ( + jsonify({"success": False, "message": "No model found."}), + 404, + ) + + payload = { + "committed_at": result["committed_at"], + "parent_model": result["parent_model"], + "model": result["model"], + "session_id": result["session_id"], + } + + return jsonify(payload) + def get_model_trail(self): """Get the model trail for a given session. @@ -784,6 +802,86 @@ def get_model_trail(self): {"success": False, "message": "No model trail available."} ) + def get_model_ancestors(self, model_id: str, limit: str = None): + """Get the model ancestors for a given model. + + :param model_id: The model id to get the model ancestors for. + :type model_id: str + :param limit: The number of ancestors to return. + :type limit: str + :return: The model ancestors for the given model as a json response. + :rtype: :class:`flask.Response` + """ + if model_id is None: + return jsonify( + {"success": False, "message": "No model id provided."} + ) + + limit: int = int(limit) if limit is not None else 10 # if limit is None, default to 10 + + response = self.statestore.get_model_ancestors(model_id, limit) + if response: + + arr: list = [] + + for element in response: + obj = { + "model": element["model"], + "committed_at": element["committed_at"], + "session_id": element["session_id"], + "parent_model": element["parent_model"], + } + arr.append(obj) + + result = {"result": arr} + + return jsonify(result) + else: + return jsonify( + {"success": False, "message": "No model ancestors available."} + ) + + def get_model_descendants(self, model_id: str, limit: str = None): + """Get the model descendants for a given model. + + :param model_id: The model id to get the model descendants for. + :type model_id: str + :param limit: The number of descendants to return. + :type limit: str + :return: The model descendants for the given model as a json response. + :rtype: :class:`flask.Response` + """ + + if model_id is None: + return jsonify( + {"success": False, "message": "No model id provided."} + ) + + limit: int = int(limit) if limit is not None else 10 + + response: list = self.statestore.get_model_descendants(model_id, limit) + + if response: + + arr: list = [] + + for element in response: + obj = { + "model": element["model"], + "committed_at": element["committed_at"], + "session_id": element["session_id"], + "parent_model": element["parent_model"], + } + arr.append(obj) + + result = {"result": arr} + + return jsonify(result) + else: + return jsonify( + {"success": False, "message": "No model descendants available."} + ) + def get_all_rounds(self): """Get all rounds. diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index 0b385c566..e88c140ca 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -29,6 +29,38 @@ def get_model_trail(): return api.get_model_trail() +@app.route("/get_model_ancestors", methods=["GET"]) +def get_model_ancestors(): + """Get the ancestors of a model. + param: model: The model id to get the ancestors for. + type: model: str + param: limit: The maximum number of ancestors to return. + type: limit: int + return: A list of model objects that the model derives from. + rtype: json + """ + model = request.args.get("model", None) + limit = request.args.get("limit", None) + + return api.get_model_ancestors(model, limit) + + +@app.route("/get_model_descendants", methods=["GET"]) +def get_model_descendants(): + """Get the ancestors of a model. + param: model: The model id to get the child for. + type: model: str + param: limit: The maximum number of descendants to return. + type: limit: int + return: A list of model objects that are descendents of the provided model id. + rtype: json + """ + model = request.args.get("model", None) + limit = request.args.get("limit", None) + + return api.get_model_descendants(model, limit) + + @app.route("/list_models", methods=["GET"]) def list_models(): """Get models from the statestore. @@ -50,6 +82,21 @@ def list_models(): return api.get_models(session_id, limit, skip, include_active) +@app.route("/get_model", methods=["GET"]) +def get_model(): + """Get a model from the statestore. + param: model: The model id to get. + type: model: str + return: The model as a json object. + rtype: json + """ + model = request.args.get("model", None) + if model is None: + return jsonify({"success": False, "message": "Missing model id."}), 400 + + return api.get_model(model) + + @app.route("/delete_model_trail", methods=["GET", "POST"]) def delete_model_trail(): """Delete the model trail for a given session. diff --git a/fedn/fedn/network/storage/statestore/mongostatestore.py b/fedn/fedn/network/storage/statestore/mongostatestore.py index fe6f93c51..aa7671802 100644 --- a/fedn/fedn/network/storage/statestore/mongostatestore.py +++ b/fedn/fedn/network/storage/statestore/mongostatestore.py @@ -184,11 +184,20 @@ def set_latest_model(self, model_id, session_id=None): """ committed_at = datetime.now() + current_model = self.model.find_one({"key": "current_model"}) + parent_model = None + + # if session_id is set the it means the model is generated from a session + # and we need to set the parent model + # if not the model is uploaded by the user and we don't need to set the parent model + if session_id is not None: + parent_model = current_model["model"] if current_model and "model" in current_model else None self.model.insert_one( { "key": "models", "model": model_id, + "parent_model": parent_model, "session_id": session_id, "committed_at": committed_at, } @@ -534,6 +543,71 @@ def get_model_trail(self): except (KeyError, IndexError): return None + def get_model_ancestors(self, model_id: str, limit: int): + """Get the model ancestors. + + :param model_id: The model id. + :type model_id: str + :param limit: The maximum number of ancestors to return. + :type limit: int + :return: List of model ancestors. + :rtype: list + """ + model = self.model.find_one({"key": "models", "model": model_id}) + current_model_id = model["parent_model"] if model is not None else None + result = [] + + for _ in range(limit): + if current_model_id is None: + break + + model = self.model.find_one({"key": "models", "model": current_model_id}) + + if model is not None: + result.append(model) + current_model_id = model["parent_model"] + + return result + + def get_model_descendants(self, model_id: str, limit: int): + """Get the model descendants. + + :param model_id: The model id. + :type model_id: str + :param limit: The maximum number of descendants to return. + :type limit: int + :return: List of model descendants. + :rtype: list + """ + + model: object = self.model.find_one({"key": "models", "model": model_id}) + current_model_id: str = model["model"] if model is not None else None + result: list = [] + + for _ in range(limit): + if current_model_id is None: + break + + model: str = self.model.find_one({"key": "models", "parent_model": current_model_id}) + + if model is not None: + result.append(model) + current_model_id = model["model"] + + result.reverse() + + return result + + def get_model(self, model_id): + """Get model with id. + + :param model_id: id of model to get + :type model_id: str + :return: model with id + :rtype: ObjectId + """ + return self.model.find_one({"key": "models", "model": model_id}) + def get_events(self, **kwargs): """Get events from the database.