Skip to content

Commit

Permalink
refator
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman committed May 17, 2024
1 parent b4de80c commit 012e8ae
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 43 deletions.
40 changes: 3 additions & 37 deletions fedn/network/api/server.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
import os
import threading

from flask import Flask, jsonify, request

from fedn.common.config import get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config
from fedn.common.config import get_controller_config
from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.interface import API
from fedn.network.api.v1 import _routes, session_store
from fedn.network.controller.control import Control
from fedn.network.storage.statestore.mongostatestore import MongoStateStore
from fedn.network.api.v1 import _routes
from fedn.network.api.shared import statestore, control

statestore_config = get_statestore_config()
network_id = get_network_config()
modelstorage_config = get_modelstorage_config()
statestore = MongoStateStore(network_id, statestore_config["mongo_config"])
statestore.set_storage_backend(modelstorage_config)
control = Control(statestore=statestore)

custom_url_prefix = os.environ.get("FEDN_CUSTOM_URL_PREFIX", False)
api = API(statestore, control)
Expand All @@ -31,32 +23,6 @@ def health_check():
return "OK", 200


@app.route("/api/v1/sessions/start", methods=["POST"])
@jwt_auth_required(role="admin")
def start_session_v2():
try:
data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
session_id: str = data.get("session_id")
rounds: int = data.get("rounds", "")

if not session_id or session_id == "":
return jsonify({"message": "Session ID is required"}), 400

if not rounds or rounds == "":
return jsonify({"message": "Rounds is required"}), 400

if not isinstance(rounds, int):
return jsonify({"message": "Rounds must be an integer"}), 400

_ = session_store.get(session_id, use_typing=False)

threading.Thread(target=control.start_session, args=(session_id, rounds)).start()

return jsonify({"message": "Session started"}), 200
except Exception as e:
return jsonify({"message": str(e)}), 500


if custom_url_prefix:
app.add_url_rule(f"{custom_url_prefix}/health", view_func=health_check, methods=["GET"])

Expand Down
11 changes: 11 additions & 0 deletions fedn/network/api/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config
from fedn.network.controller.control import Control
from fedn.network.storage.statestore.mongostatestore import MongoStateStore

statestore_config = get_statestore_config()
modelstorage_config = get_modelstorage_config()
network_id = get_network_config()

statestore = MongoStateStore(network_id, statestore_config["mongo_config"])
statestore.set_storage_backend(modelstorage_config)
control = Control(statestore=statestore)
3 changes: 2 additions & 1 deletion fedn/network/api/v1/model_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from flask import Blueprint, jsonify, request, send_file

from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.v1.shared import api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers, mdb, modelstorage_config
from fedn.network.api.v1.shared import api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers, mdb
from fedn.network.api.shared import modelstorage_config
from fedn.network.storage.s3.base import RepositoryBase
from fedn.network.storage.s3.miniorepository import MINIORepository
from fedn.network.storage.statestore.stores.model_store import ModelStore
Expand Down
40 changes: 40 additions & 0 deletions fedn/network/api/v1/session_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb
from fedn.network.storage.statestore.stores.session_store import SessionStore
from fedn.network.storage.statestore.stores.shared import EntityNotFound
from .model_routes import model_store
from ..shared import control

bp = Blueprint("session", __name__, url_prefix=f"/api/{api_version}/sessions")

Expand Down Expand Up @@ -307,6 +309,7 @@ def get_session(id: str):
except Exception as e:
return jsonify({"message": str(e)}), 500


@bp.route("/", methods=["POST"])
@jwt_auth_required(role="admin")
def post():
Expand Down Expand Up @@ -348,3 +351,40 @@ def post():
return jsonify(response), status_code
except Exception as e:
return jsonify({"message": str(e)}), 500


@bp.route("/start", methods=["POST"])
@jwt_auth_required(role="admin")
def start_session():
"""Start a new session.
param: session_id: The session id to start.
type: session_id: str
param: rounds: The number of rounds to run.
type: rounds: int
"""
try:
data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
session_id: str = data.get("session_id")
rounds: int = data.get("rounds", "")

if not session_id or session_id == "":
return jsonify({"message": "Session ID is required"}), 400

if not rounds or rounds == "":
return jsonify({"message": "Rounds is required"}), 400

if not isinstance(rounds, int):
return jsonify({"message": "Rounds must be an integer"}), 400

session = session_store.get(session_id, use_typing=False)

session_config = session["session_config"]
model_id = session_config["model_id"]

_ = model_store.get(model_id, use_typing=False)

threading.Thread(target=control.start_session, args=(session_id, rounds)).start()

return jsonify({"message": "Session started"}), 200
except Exception as e:
return jsonify({"message": str(e)}), 500
6 changes: 1 addition & 5 deletions fedn/network/api/v1/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@
import pymongo
from pymongo.database import Database

from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config
from ..shared import statestore_config, network_id

api_version = "v1"

statestore_config = get_statestore_config()
modelstorage_config = get_modelstorage_config()
network_id = get_network_config()

mc = pymongo.MongoClient(**statestore_config["mongo_config"])
mc.server_info()
mdb: Database = mc[network_id]
Expand Down

0 comments on commit 012e8ae

Please sign in to comment.