Skip to content

Commit

Permalink
massive refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
visualDust committed Nov 23, 2023
1 parent e724b20 commit e2b000a
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 47 deletions.
2 changes: 1 addition & 1 deletion neetbox/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
DEFAULT_WORKSPACE_CONFIG = {
"name": None,
"version": None,
"logging": {"logdir": None},
"logging": {"level": "INFO", "logdir": None},
"pipeline": {
"updateInterval": 10,
},
Expand Down
5 changes: 3 additions & 2 deletions neetbox/daemon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import neetbox
from neetbox.config import get_module_level_config
from neetbox.daemon._agent import neet_action as action
from neetbox.daemon.client._action_agent import _NeetActionManager as NeetActionManager
from neetbox.daemon.client._daemon_client import connect_daemon
from neetbox.daemon.server.daemonable_process import DaemonableProcess
from neetbox.logging import logger
Expand Down Expand Up @@ -83,4 +83,5 @@ def _try_attach_daemon():
__attach_daemon(_cfg)


__all__ = ["watch", "listen", "action", "_try_attach_daemon"]
action = NeetActionManager.register
__all__ = ["watch", "listen", "action", "NeetActionManager", "_try_attach_daemon"]
34 changes: 16 additions & 18 deletions neetbox/daemon/_agent.py → neetbox/daemon/client/_action_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,29 @@ def eval_call(self, params: dict):
return self.function(**eval_params)


class _NeetAction(metaclass=Singleton):
class _NeetActionManager(metaclass=Singleton):
__ACTION_POOL: Registry = Registry("__NEET_ACTIONS")

def get_action_names():
action_names = _NeetAction.__ACTION_POOL.keys()
action_names = _NeetActionManager.__ACTION_POOL.keys()
actions = {}
for n in action_names:
actions[n] = _NeetAction.__ACTION_POOL[n].argspec
actions[n] = _NeetActionManager.__ACTION_POOL[n].argspec
return actions

def get_action_dict():
action_dict = {}
action_names = _NeetAction.__ACTION_POOL.keys()
action_names = _NeetActionManager.__ACTION_POOL.keys()
for name in action_names:
action = _NeetAction.__ACTION_POOL[name]
action = _NeetActionManager.__ACTION_POOL[name]
action_dict[name] = action.argspec.args
return action_dict

def eval_call(self, name: str, params: dict, callback: None):
if name not in _NeetAction.__ACTION_POOL:
if name not in _NeetActionManager.__ACTION_POOL:
logger.err(f"Could not find action with name {name}, action stopped.")
return False
target_action = _NeetAction.__ACTION_POOL[name]
target_action = _NeetActionManager.__ACTION_POOL[name]
logger.log(
f"Agent runs function '{target_action.name}', blocking = {target_action.blocking}"
)
Expand All @@ -70,35 +70,33 @@ def run_and_callback(target_action, params, callback):
@watch(initiative=True)
def _update_action_dict():
# for status updater
return _NeetAction.get_action_dict()
return _NeetActionManager.get_action_dict()

def register(self, *, name: Optional[str] = None, blocking: bool = False):
return functools.partial(self._register, name=name, blocking=blocking)
def register(name: Optional[str] = None, blocking: bool = False):
return functools.partial(_NeetActionManager._register, name=name, blocking=blocking)

def _register(self, function: Callable, name: str = None, blocking: bool = False):
def _register(function: Callable, name: str = None, blocking: bool = False):
packed = PackedAction(function=function, name=name, blocking=blocking)
_NeetAction.__ACTION_POOL._register(what=packed, name=packed.name, force=True)
_NeetAction._update_action_dict() # update for sync
_NeetActionManager.__ACTION_POOL._register(what=packed, name=packed.name, force=True)
_NeetActionManager._update_action_dict() # update for sync
return function


# singleton
neet_action = _NeetAction()
neet_action = _NeetActionManager()


# example
if __name__ == "__main__":
import time

action = neet_action

@action.register(name="some")
@_NeetActionManager.register(name="some")
def some(a, b):
time.sleep(1)
return f"a = {a}, b = {b}"

print("registered actions:")
action_dict = _NeetAction.get_action_dict()
action_dict = _NeetActionManager.get_action_dict()
print(action_dict)

def callback_fun(text):
Expand Down
4 changes: 2 additions & 2 deletions neetbox/daemon/client/_client_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


from neetbox.config import get_module_level_config
from neetbox.daemon.client._connection import _local_http_client
from neetbox.daemon.client._connection import connection
from neetbox.logging import logger
from neetbox.utils import pkg
from neetbox.utils.framing import get_frame_module_traceback
Expand All @@ -27,6 +27,6 @@ def get_status_of(name=None):
name = name or ""
api_addr = f"{base_addr}/status"
logger.info(f"Fetching from {api_addr}")
r = _local_http_client.get(api_addr)
r = connection.http.get(api_addr)
_data = r.json()
return _data
75 changes: 66 additions & 9 deletions neetbox/daemon/client/_connection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import asyncio
import functools
import logging
from typing import Callable, Optional

import httpx
import websocket

from neetbox.config import get_module_level_config
from neetbox.core import Registry
from neetbox.logging import logger
from neetbox.utils.mvc import Singleton

httpx_logger = logging.getLogger("httpx")
Expand All @@ -13,17 +19,68 @@
"https://": None,
}


def __load_http_client():
__local_http_client = httpx.Client(proxies=__no_proxy) # type: ignore
return __local_http_client
EVENT_TYPE_NAME_KEY = "event-type"
EVENT_PAYLOAD_NAME_KEY = "payload"


# singleton
_local_http_client: httpx.Client = __load_http_client()
class ClientConn(metaclass=Singleton):
http: httpx.Client = None

__ws_client: None # _websocket_client
__ws_subscription = Registry("__client_ws_subscription") # { event-type-name : list(Callable)}

def __init__(self) -> None:
cfg = get_module_level_config()

def __load_http_client():
__local_http_client = httpx.Client(proxies=__no_proxy) # type: ignore
return __local_http_client

# create htrtp client
ClientConn.http = __load_http_client()
# todo establishing socket connection

def __on_ws_message(msg):
logger.debug(f"ws received {msg}")
# message should be json
event_type_name = msg[EVENT_TYPE_NAME_KEY]
if event_type_name not in ClientConn.__ws_subscription:
logger.warn(
f"Client received a(n) {event_type_name} event but nobody subscribes it. Ignoring anyway."
)
for subscriber in ClientConn._ws_subscribe[event_type_name]:
try:
subscriber(msg[EVENT_PAYLOAD_NAME_KEY]) # pass payload message into subscriber
except Exception as e:
# subscriber throws error
logger.err(
f"Subscriber {subscriber} crashed on message event {event_type_name}, ignoring."
)

class Connection(metaclass = Singleton):
_http_client: httpx.Client
# _websocket_client
def __init__(self, cfg) -> None:
def ws_send(msg):
logger.debug(f"ws sending {msg}")
# send to ws if ws is connected, otherwise drop message? idk
pass

def ws_subscribe(event_type_name: str):
"""let a function subscribe to ws messages with event type name.
!!! dfor inner APIs only, do not use this in your code!
!!! developers should contorl blocking on their own functions
Args:
function (Callable): who is subscribing the event type
event_type_name (str, optional): Which event to listen. Defaults to None.
"""
return functools.partial(ClientConn._ws_subscribe, event_type_name=event_type_name)

def _ws_subscribe(function: Callable, event_type_name: str):
if event_type_name not in ClientConn.__ws_subscription:
# create subscriber list for event-type name if not exist
ClientConn.__ws_subscription._register([], event_type_name)
ClientConn.__ws_subscription[event_type_name].append(function)


# singleton
ClientConn() # run init
connection = ClientConn
6 changes: 3 additions & 3 deletions neetbox/daemon/client/_daemon_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Union

from neetbox.config import get_module_level_config
from neetbox.daemon.client._connection import _local_http_client
from neetbox.daemon.client._connection import connection
from neetbox.daemon.server._server import CLIENT_API_ROOT
from neetbox.logging import logger
from neetbox.pipeline._signal_and_slot import _update_value_dict
Expand All @@ -37,7 +37,7 @@ def _upload_thread(daemon_config, base_addr, display_name):
_headers = {"Content-Type": "application/json"}
try:
# upload data
resp = _local_http_client.post(_api_addr, json=_data, headers=_headers)
resp = connection.http.post(_api_addr, json=_data, headers=_headers)
if resp.is_error: # upload failed
raise IOError(f"Failed to upload data to daemon. ({resp.status_code})")
except Exception as e:
Expand Down Expand Up @@ -76,7 +76,7 @@ def connect_daemon(cfg=None, launch_upload_thread=True):
def _check_daemon_alive(_api_addr):
_api_name = "hello"
_api_addr = f"{_api_addr}/{_api_name}"
r = _local_http_client.get(_api_addr)
r = connection.http.get(_api_addr)
if r.is_error:
raise IOError(f"Daemon at {_api_addr} is not alive. ({r.status_code})")

Expand Down
2 changes: 0 additions & 2 deletions neetbox/daemon/server/_daemon_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

# sys.stdout=open(r'D:\Projects\ML\neetbox\logdir\daemon.log', 'a+')


# from neetbox.daemon._local_http_client import _local_http_client
print("========= Daemon =========")


Expand Down
2 changes: 2 additions & 0 deletions neetbox/logging/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from neetbox.config import get_module_level_config
from neetbox.logging.logger import DEFAULT_LOGGER as logger
from neetbox.logging.logger import set_log_level

_cfg = get_module_level_config()
logger.set_log_dir(_cfg["logdir"])
set_log_level(_cfg["level"])
from neetbox.logging.logger import LogSplitStrategies

__all__ = ["logger", "LogSplitStrategies"]
27 changes: 27 additions & 0 deletions neetbox/logging/_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
class LogWriter:
def write(self, raw_msg):
pass


class ConsoleLogWriter(metaclass=LogWriter):
def __init__(self) -> None:
pass

def write(self, raw_msg):
pass


class FileLogWriter(metaclass=LogWriter):
def __init__(self) -> None:
pass

def write(self, raw_msg):
pass


class WebSocketLogWriter(metaclass=LogWriter):
def __init__(self) -> None:
pass

def write(self, raw_msg):
pass
28 changes: 18 additions & 10 deletions neetbox/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

class LogLevel(Enum):
ALL = 4
INFO = 3
DEBUG = 2
DEBUG = 3
INFO = 2
WARNING = 1
ERROR = 0

Expand All @@ -52,15 +52,23 @@ def __ge__(self, other):
style_dict = {}
loggers_dict = {}

_global_log_level = LogLevel.ALL
_GLOBAL_LOG_LEVEL = LogLevel.INFO


def set_log_level(level: LogLevel):
if type(level) is str:
level = {
"ALL": LogLevel.ALL,
"DEBUG": LogLevel.DEBUG,
"INFO": LogLevel.INFO,
"WARNING": LogLevel.WARNING,
"ERROR": LogLevel.ERROR,
}[level]
if type(level) is int:
assert level >= 0 and level <= 3
level = LogLevel(level)
global _global_log_level
_global_log_level = level
global _GLOBAL_LOG_LEVEL
_GLOBAL_LOG_LEVEL = level


class LogMetadata:
Expand Down Expand Up @@ -322,7 +330,7 @@ def log(
return self

def ok(self, *message, flag="OK"):
if _global_log_level >= LogLevel.INFO:
if _GLOBAL_LOG_LEVEL >= LogLevel.INFO:
self.log(
*message,
prefix=f"[{colored_text(flag, 'green')}]",
Expand All @@ -333,7 +341,7 @@ def ok(self, *message, flag="OK"):
return self

def debug(self, *message, flag="DEBUG"):
if _global_log_level >= LogLevel.DEBUG:
if _GLOBAL_LOG_LEVEL >= LogLevel.DEBUG:
self.log(
*message,
prefix=f"[{colored_text(flag, 'cyan')}]",
Expand All @@ -344,7 +352,7 @@ def debug(self, *message, flag="DEBUG"):
return self

def info(self, *message, flag="INFO"):
if _global_log_level >= LogLevel.INFO:
if _GLOBAL_LOG_LEVEL >= LogLevel.INFO:
self.log(
*message,
prefix=f"[{colored_text(flag, 'white')}]",
Expand All @@ -355,7 +363,7 @@ def info(self, *message, flag="INFO"):
return self

def warn(self, *message, flag="WARNING"):
if _global_log_level >= LogLevel.WARNING:
if _GLOBAL_LOG_LEVEL >= LogLevel.WARNING:
self.log(
*message,
prefix=f"[{colored_text(flag, 'yellow')}]",
Expand All @@ -368,7 +376,7 @@ def warn(self, *message, flag="WARNING"):
def err(self, err, flag="ERROR", reraise=False):
if type(err) is not Exception:
err = RuntimeError(str(err))
if _global_log_level >= LogLevel.ERROR:
if _GLOBAL_LOG_LEVEL >= LogLevel.ERROR:
self.log(
str(err),
prefix=f"[{colored_text(flag,'red')}]",
Expand Down
20 changes: 20 additions & 0 deletions tests/test_daemon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
def test_neet_action():
import time

from neetbox.daemon import NeetActionManager, action

@action(name="some_func")
def some(a, b):
time.sleep(1)
return f"a = {a}, b = {b}"

print("registered actions:")
action_dict = NeetActionManager.get_action_dict()
print(action_dict)

def callback_fun(text):
print(f"callback_fun print: {text}")

NeetActionManager.eval_call("some", {"a": "3", "b": "4"}, callback=callback_fun)
print("you should see this line first before callback_fun print")
time.sleep(4)

0 comments on commit e2b000a

Please sign in to comment.