From 42733d2937ce6ff0ad2617e64dad72011bbcfd4d Mon Sep 17 00:00:00 2001
From: voidZXL
Date: Tue, 10 Dec 2024 16:56:15 +0800
Subject: [PATCH] black all code, compat OrderBy for django 3.0

---
 utilmeta/__init__.py | 22 +-
 utilmeta/bin/base.py | 139 +++--
 utilmeta/bin/commands/base.py | 57 +-
 utilmeta/bin/commands/setup.py | 140 +++--
 utilmeta/bin/constant.py | 61 +-
 utilmeta/bin/meta.py | 203 ++++---
 utilmeta/bin/template/full/config/conf.py | 21 +-
 utilmeta/bin/template/full/config/env.py | 4 +-
 utilmeta/bin/template/full/config/service.py | 15 +-
 utilmeta/bin/template/full/main.py | 4 +-
 utilmeta/bin/template/full/service/api.py | 4 +-
 utilmeta/bin/template/lite/server.py | 27 +-
 utilmeta/bin/utils.py | 8 +-
 utilmeta/conf/base.py | 6 +-
 utilmeta/conf/env.py | 59 +-
 utilmeta/conf/http.py | 49 +-
 utilmeta/conf/pool.py | 5 +-
 utilmeta/conf/preference.py | 2 +-
 utilmeta/conf/time.py | 29 +-
 utilmeta/core/api/__init__.py | 17 +-
 utilmeta/core/api/base.py | 253 ++++----
 utilmeta/core/api/chain.py | 71 ++-
 utilmeta/core/api/decorator.py | 154 +++--
 utilmeta/core/api/endpoint.py | 179 +++---
 utilmeta/core/api/hook.py | 97 +--
 utilmeta/core/api/plugins/base.py | 7 +-
 utilmeta/core/api/plugins/cors.py | 117 ++--
 utilmeta/core/api/plugins/retry.py | 125 ++--
 utilmeta/core/api/route.py | 173 +++---
 utilmeta/core/api/specs/base.py | 2 +-
 utilmeta/core/api/specs/openapi.py | 451 ++++++++------
 utilmeta/core/auth/basic.py | 18 +-
 utilmeta/core/auth/jwt.py | 81 +--
 utilmeta/core/auth/oauth2.py | 22 +-
 utilmeta/core/auth/plugins/require.py | 16 +-
 utilmeta/core/auth/properties.py | 86 ++-
 utilmeta/core/auth/session/base.py | 53 +-
 utilmeta/core/auth/session/cache.py | 38 +-
 utilmeta/core/auth/session/cached_db.py | 20 +-
 utilmeta/core/auth/session/db.py | 33 +-
 utilmeta/core/auth/session/django.py | 32 +-
 utilmeta/core/auth/session/schema.py | 108 ++--
 utilmeta/core/auth/signature.py | 61 +-
 utilmeta/core/cache/backends/django.py | 40 +-
 .../core/cache/backends/redis/aioredis.py | 36 +-
 utilmeta/core/cache/backends/redis/config.py | 20 +-
 utilmeta/core/cache/backends/redis/entity.py | 94 ++-
 utilmeta/core/cache/backends/redis/lock.py | 7 +-
 .../cache/backends/redis/scripts/__init__.py | 10 +-
 utilmeta/core/cache/base.py | 21 +-
 utilmeta/core/cache/config.py | 171 +++---
 utilmeta/core/cache/lock.py | 13 +-
 utilmeta/core/cache/plugins/api.py | 257 ++++----
 utilmeta/core/cache/plugins/base.py | 182 +++---
 utilmeta/core/cache/plugins/entity.py | 112 ++--
 utilmeta/core/cache/plugins/sdk.py | 91 +--
 utilmeta/core/cli/backends/aiohttp.py | 10 +-
 utilmeta/core/cli/backends/base.py | 11 +-
 utilmeta/core/cli/backends/httpx.py | 10 +-
 utilmeta/core/cli/backends/requests.py | 9 +-
 utilmeta/core/cli/backends/urllib.py | 16 +-
 utilmeta/core/cli/base.py | 554 ++++++++++--------
 utilmeta/core/cli/chain.py | 66 ++-
 utilmeta/core/cli/endpoint.py | 128 ++--
 utilmeta/core/cli/hook.py | 10 +-
 utilmeta/core/cli/specs/base.py | 2 +-
 utilmeta/core/cli/specs/openapi.py | 489 +++++++++-------
 utilmeta/core/file/backends/base.py | 23 +-
 utilmeta/core/file/backends/response.py | 1 +
 utilmeta/core/file/base.py | 63 +-
 utilmeta/core/orm/__init__.py | 10 +-
 utilmeta/core/orm/backends/base.py | 52 +-
 utilmeta/core/orm/backends/django/compiler.py | 281 +++++----
 utilmeta/core/orm/backends/django/constant.py | 251 ++++----
 utilmeta/core/orm/backends/django/database.py | 38 +-
 utilmeta/core/orm/backends/django/deletion.py | 50 +-
 .../core/orm/backends/django/exceptions.py | 12 +-
 .../core/orm/backends/django/expressions.py | 37 +-
 utilmeta/core/orm/backends/django/field.py | 156 +++--
 .../core/orm/backends/django/generator.py | 43 +-
 utilmeta/core/orm/backends/django/model.py | 95 +--
 utilmeta/core/orm/backends/django/models.py | 122 ++--
 utilmeta/core/orm/backends/django/query.py | 44 +-
 utilmeta/core/orm/backends/django/queryset.py | 91 +--
 utilmeta/core/orm/backends/peewee/example.py | 19 +-
 utilmeta/core/orm/backends/peewee/peewee.py | 8 +-
 utilmeta/core/orm/compiler.py | 77 ++-
 utilmeta/core/orm/context.py | 4 +-
 utilmeta/core/orm/databases/base.py | 4 +-
 utilmeta/core/orm/databases/config.py | 110 ++--
 utilmeta/core/orm/databases/encode.py | 100 ++--
 utilmeta/core/orm/encoder.py | 2 +-
 utilmeta/core/orm/exceptions.py | 9 +-
 utilmeta/core/orm/fields/field.py | 274 +++++----
 utilmeta/core/orm/fields/filter.py | 49 +-
 utilmeta/core/orm/fields/order.py | 95 +--
 utilmeta/core/orm/fields/pagination.py | 6 +-
 utilmeta/core/orm/fields/scope.py | 22 +-
 utilmeta/core/orm/generator.py | 29 +-
 utilmeta/core/orm/parser.py | 33 +-
 utilmeta/core/orm/plugins/atomic.py | 21 +-
 utilmeta/core/orm/schema.py | 163 +++---
 utilmeta/core/request/backends/base.py | 38 +-
 utilmeta/core/request/backends/django.py | 47 +-
 utilmeta/core/request/backends/sanic.py | 4 +-
 utilmeta/core/request/backends/starlette.py | 15 +-
 utilmeta/core/request/backends/tornado.py | 4 +-
 utilmeta/core/request/backends/werkzeug.py | 4 +-
 utilmeta/core/request/base.py | 37 +-
 utilmeta/core/request/client.py | 72 ++-
 utilmeta/core/request/properties.py | 339 ++++++-----
 utilmeta/core/request/var.py | 54 +-
 utilmeta/core/response/backends/aiohttp.py | 5 +-
 utilmeta/core/response/backends/base.py | 21 +-
 utilmeta/core/response/backends/django.py | 10 +-
 utilmeta/core/response/backends/httpx.py | 4 +-
 utilmeta/core/response/backends/sanic.py | 4 +-
 utilmeta/core/response/backends/starlette.py | 12 +-
 utilmeta/core/response/backends/urllib.py | 3 +-
 utilmeta/core/response/backends/werkzeug.py | 4 +-
 utilmeta/core/response/base.py | 259 ++++----
 utilmeta/core/server/backends/apiflask.py | 20 +-
 utilmeta/core/server/backends/base.py | 34 +-
 .../core/server/backends/django/adaptor.py | 201 ++++---
 utilmeta/core/server/backends/django/cmd.py | 144 +++--
 .../core/server/backends/django/settings.py | 428 ++++++++------
 utilmeta/core/server/backends/django/utils.py | 21 +-
 utilmeta/core/server/backends/fastapi.py | 10 +-
 utilmeta/core/server/backends/flask.py | 82 +--
 utilmeta/core/server/backends/sanic.py | 99 ++--
 utilmeta/core/server/backends/starlette.py | 138 +++--
 utilmeta/core/server/backends/tornado.py | 57 +-
 utilmeta/core/server/backends/werkzeug.py | 10 +-
 utilmeta/core/server/service.py | 256 +++++---
 utilmeta/core/websocket/__init__.py | 1 -
 utilmeta/ops/__init__.py | 4 +-
 utilmeta/ops/aggregation.py | 139 ++---
 utilmeta/ops/api/__init__.py | 183 +++---
 utilmeta/ops/api/data.py | 64 +-
 utilmeta/ops/api/log.py | 157 ++---
 utilmeta/ops/api/servers.py | 289 +++++----
 utilmeta/ops/api/token.py | 30 +-
 utilmeta/ops/api/utils.py | 18 +-
 utilmeta/ops/client.py | 300 ++++++----
 utilmeta/ops/cmd.py | 356 +++++++----
 utilmeta/ops/config.py | 389 +++++++-----
 utilmeta/ops/connect.py | 161 ++---
 utilmeta/ops/key.py | 14 +-
 utilmeta/ops/log.py | 340 ++++++-----
 utilmeta/ops/migrations/0001_initial.py | 301 ++++++++--
 .../ops/migrations/0003_aggregationlog.py | 7 +-
 utilmeta/ops/models.py | 428 +++++++++----
 utilmeta/ops/monitor.py | 167 ++++--
 utilmeta/ops/proxy.py | 18 +-
 utilmeta/ops/query.py | 40 +-
 utilmeta/ops/resources.py | 305 +++++-----
 utilmeta/ops/schema.py | 45 +-
 utilmeta/ops/task.py | 472 ++++++++-------
 utilmeta/utils/__init__.py | 1 +
 utilmeta/utils/adaptor.py | 28 +-
 utilmeta/utils/base.py | 35 +-
 utilmeta/utils/constant/data.py | 271 +++++----
 utilmeta/utils/constant/i18n.py | 296 ++++++----
 utilmeta/utils/constant/vendor.py | 36 +-
 utilmeta/utils/constant/web.py | 257 ++++----
 utilmeta/utils/context.py | 22 +-
 utilmeta/utils/datastructure.py | 43 +-
 utilmeta/utils/decorator.py | 119 +++-
 utilmeta/utils/error.py | 50 +-
 utilmeta/utils/exceptions/config.py | 6 +-
 utilmeta/utils/exceptions/http.py | 81 ++-
 utilmeta/utils/exceptions/runtime.py | 6 +-
 utilmeta/utils/functional/data.py | 238 +++++---
 utilmeta/utils/functional/orm.py | 45 +-
 utilmeta/utils/functional/py.py | 149 +++--
 utilmeta/utils/functional/sys.py | 300 ++++++----
 utilmeta/utils/functional/time.py | 67 ++-
 utilmeta/utils/functional/web.py | 304 ++++++----
 utilmeta/utils/logical.py | 34 +-
 utilmeta/utils/plugin.py | 191 ++++--
 utilmeta/utils/schema/backends/attrs.py | 4 +-
 utilmeta/utils/schema/backends/dataclass.py | 4 +-
 utilmeta/utils/schema/base.py | 1 -
 183 files changed, 10207 insertions(+), 6867 deletions(-)

diff --git a/utilmeta/__init__.py b/utilmeta/__init__.py
index dc6a885..41e527a 100644
--- a/utilmeta/__init__.py
+++ b/utilmeta/__init__.py
@@ -1,7 +1,7 @@
-__website__ = 'https://utilmeta.com'
-__homepage__ = 'https://utilmeta.com/py'
-__author__ = 'Xulin Zhou (@voidZXL)'
-__version__ = '2.7.0'
+__website__ = "https://utilmeta.com"
+__homepage__ = "https://utilmeta.com/py"
+__author__ = "Xulin Zhou (@voidZXL)"
+__version__ = "2.7.0"


 def version_info() -> str:
@@ -10,12 +10,14 @@ def version_info() -> str:
     from pathlib import Path

     info = {
-        'utilmeta version': __version__,
-        'installed path': Path(__file__).resolve().parent,
-        'python version': sys.version,
-        'platform': platform.platform(),
+        "utilmeta version": __version__,
+        "installed path": Path(__file__).resolve().parent,
+        "python version": sys.version,
+        "platform": platform.platform(),
     }
-    return '\n'.join('{:>30} {}'.format(k + ':', str(v).replace('\n', ' ')) for k, v in info.items())
+    return "\n".join(
+        "{:>30} {}".format(k + ":", str(v).replace("\n", " ")) for k, v in info.items()
+    )


 def init_settings():
@@ -35,6 +37,6 @@ def init_settings():
 from .core.server.service import UtilMeta

-service: 'UtilMeta'  # current service in this process
+service: "UtilMeta"  # current service in this process

 _cmd_env = False
diff --git a/utilmeta/bin/base.py b/utilmeta/bin/base.py
index a234f06..f4183ce 100644
--- a/utilmeta/bin/base.py
+++ b/utilmeta/bin/base.py
@@ -11,7 +11,7 @@
 from functools import partial

-__all__ = ['command', 'Arg', 'BaseCommand']
+__all__ = ["command", "Arg", "BaseCommand"]


 def command(name: str = None, *aliases, options: Options = None):
@@ -27,50 +27,52 @@ def wrapper(f):
         f.__command__ = f.__name__ if name is None else name
         f.__aliases__ = aliases
         return f
+
     return wrapper


 class Arg(Field):
-    def __init__(self,
-                 alias: str = None,
-                 alias_from: Union[str, List[str], Callable, List[Callable]] = None,
-                 *,
-                 required: bool = False,
-                 default=None,
-                 default_factory: Callable = None,
-                 case_insensitive: bool = None,
-                 mode: str = None,
-                 deprecated: Union[bool, str] = False,
-                 discriminator=None,  # discriminate the schema union by it's field
-                 no_input: Union[bool, str, Callable] = False,
-                 on_error: Literal["exclude", "preserve", "throw"] = None,  # follow the options
-                 dependencies: Union[list, str, property] = None,
-                 # --- ANNOTATES ---
-                 title: str = None,
-                 description: str = None,
-                 example=unprovided,
-                 # --- CONSTRAINTS ---
-                 const=unprovided,
-                 enum: Iterable = None,
-                 gt=None,
-                 ge=None,
-                 lt=None,
-                 le=None,
-                 regex: str = None,
-                 length: Union[int, ConstraintMode] = None,
-                 max_length: Union[int, ConstraintMode] = None,
-                 min_length: int = None,
-                 # number
-                 max_digits: Union[int, ConstraintMode] = None,
-                 decimal_places: Union[int, ConstraintMode] = None,
-                 round: int = None,
-                 multiple_of: Union[int, ConstraintMode] = None,
-                 # array
-                 contains: type = None,
-                 max_contains: int = None,
-                 min_contains: int = None,
-                 unique_items: Union[bool, ConstraintMode] = None,
-                 ):
+    def __init__(
+        self,
+        alias: str = None,
+        alias_from: Union[str, List[str], Callable, List[Callable]] = None,
+        *,
+        required: bool = False,
+        default=None,
+        default_factory: Callable = None,
+        case_insensitive: bool = None,
+        mode: str = None,
+        deprecated: Union[bool, str] = False,
+        discriminator=None,  # discriminate the schema union by it's field
+        no_input: Union[bool, str, Callable] = False,
+        on_error: Literal["exclude", "preserve", "throw"] = None,  # follow the options
+        dependencies: Union[list, str, property] = None,
+        # --- ANNOTATES ---
+        title: str = None,
+        description: str = None,
+        example=unprovided,
+        # --- CONSTRAINTS ---
+        const=unprovided,
+        enum: Iterable = None,
+        gt=None,
+        ge=None,
+        lt=None,
+        le=None,
+        regex: str = None,
+        length: Union[int, ConstraintMode] = None,
+        max_length: Union[int, ConstraintMode] = None,
+        min_length: int = None,
+        # number
+        max_digits: Union[int, ConstraintMode] = None,
+        decimal_places: Union[int, ConstraintMode] = None,
+        round: int = None,
+        multiple_of: Union[int, ConstraintMode] = None,
+        # array
+        contains: type = None,
+        max_contains: int = None,
+        min_contains: int = None,
+        unique_items: Union[bool, ConstraintMode] = None,
+    ):
         if required:
             default = default_factory = unprovided
         kwargs = dict(locals())
@@ -80,7 +82,7 @@ def __init__(self,


 class BaseCommand:
-    _commands: Dict[str, Union[Type['BaseCommand'], Callable]] = {}
+    _commands: Dict[str, Union[Type["BaseCommand"], Callable]] = {}
     _documents: Dict[str, str] = {}
     _aliases: Dict[str, str] = {}
@@ -101,7 +103,7 @@ def __init_subclass__(cls, **kwargs):
         documents = {}
         aliases = {}
-        for base in reversed(cls.__bases__):    # mro
+        for base in reversed(cls.__bases__):  # mro
             if issubclass(base, BaseCommand):
                 commands.update(base._commands)
                 documents.update(base._documents)
@@ -111,13 +113,13 @@ def __init_subclass__(cls, **kwargs):
             if isinstance(cmd, type) and issubclass(cmd, BaseCommand):
                 commands[name] = cmd
                 cmd_documents = cmd._documents
-                documents[name] = get_doc(cmd) or cmd_documents.get('')
+                documents[name] = get_doc(cmd) or cmd_documents.get("")

         for name, func in cls.__dict__.items():
             name: str
             if name in commands:
                 continue
-            if name.startswith('_'):
+            if name.startswith("_"):
                 continue
             cls_func = False
             if isinstance(func, classmethod):
@@ -126,10 +128,10 @@ def __init_subclass__(cls, **kwargs):
             #     func.__classmethod__ = True
             if not inspect.isfunction(func):
                 continue
-            command_name = getattr(func, '__command__', None)
+            command_name = getattr(func, "__command__", None)
             if command_name is None:
                 continue
-            command_aliases = getattr(func, '__aliases__', [])
+            command_aliases = getattr(func, "__aliases__", [])
             documents[command_name] = get_doc(func)
             # if cls_func:
             #     func = partial(func, cls)
@@ -150,15 +152,15 @@ def __init__(self, *argv: str, cwd: str):
         if argv:
             self.arg_name, *self.args = argv
         else:
-            self.arg_name = ''
+            self.arg_name = ""
             self.args = []

-    def get_command_cls(self, name: str) -> Union[Type['BaseCommand'], Callable]:
+    def get_command_cls(self, name: str) -> Union[Type["BaseCommand"], Callable]:
         alias = self._aliases.get(name, name)
         return self._commands.get(alias)

     def command_not_found(self):
-        print(RED % F'{self.script_name or "meta"}: command not found: {self.arg_name}')
+        print(RED % f'{self.script_name or "meta"}: command not found: {self.arg_name}')
         exit(1)

     def __call__(self, **kwargs):
@@ -180,13 +182,15 @@ def __call__(self, **kwargs):
                 return
         elif self.name:
             # subclasses
-            root_cmd = self.get_command_cls('')
+            root_cmd = self.get_command_cls("")
             if root_cmd:
                 # the arg_name is actually the calling args for root cmd
                 cmd_cls = root_cmd
                 self.args = self.argv
             else:
-                raise ValueError(f'{self.script_name or "meta"} {self.name or ""}: Invalid command: {self.argv}')
+                raise ValueError(
+                    f'{self.script_name or "meta"} {self.name or ""}: Invalid command: {self.argv}'
+                )
         else:
             self.command_not_found()
@@ -212,15 +216,15 @@ def __call__(self, **kwargs):
         args = []
         for arg in self.args:
             arg = str(arg)
-            if arg.startswith('--'):
-                if '=' in arg:
-                    key, *values = arg.split('=')
-                    val = '='.join(values)
+            if arg.startswith("--"):
+                if "=" in arg:
+                    key, *values = arg.split("=")
+                    val = "=".join(values)
                     kwargs[key] = val  # = kwargs[str(key).strip('--')]
                 else:
                     kwargs[arg] = True  # = kwargs[str(arg).strip('--')]
-            elif arg.startswith('-'):
-                kwargs[arg] = True  # kwargs[arg.strip('-')] =
+            elif arg.startswith("-"):
+                kwargs[arg] = True  # kwargs[arg.strip('-')] =
             else:
                 args.append(arg)
         try:
@@ -230,27 +234,34 @@ def handle_parse_error(self, e: Exception):
         if isinstance(e, AbsenceError):
-            message = f'required command argument: {repr(e.item)} is absence'
+            message = f"required command argument: {repr(e.item)} is absent"
         else:
             message = str(e)
         error = Error(e)
         error.setup()
         print(error.full_info)
-        print(RED % F'{self.script_name or "meta"} {self.name or ""}: command [{self.arg_name}] failed: {message}')
+        print(
+            RED
+            % f'{self.script_name or "meta"} {self.name or ""}: command [{self.arg_name}] failed: {message}'
+        )
         exit(1)

     @classmethod
-    def mount(cls, cmd_cls: Type['BaseCommand'], name: str = '', *aliases: str):
+    def mount(cls, cmd_cls: Type["BaseCommand"], name: str = "", *aliases: str):
         if not issubclass(cmd_cls, BaseCommand):
-            raise TypeError(f'Invalid command class: {cmd_cls}, should be BaseCommand subclass')
+            raise TypeError(
+                f"Invalid command class: {cmd_cls}, should be BaseCommand subclass"
+            )
         for alias in aliases:
             cls._aliases[alias] = name
         cls._commands[name] = cmd_cls

     @classmethod
-    def merge(cls, cmd_cls: Type['BaseCommand']):
+    def merge(cls, cmd_cls: Type["BaseCommand"]):
         if not issubclass(cmd_cls, BaseCommand):
-            raise TypeError(f'Invalid command class: {cmd_cls}, should be BaseCommand subclass')
+            raise TypeError(
+                f"Invalid command class: {cmd_cls}, should be BaseCommand subclass"
+            )
         cls._commands.update(cmd_cls._commands)
         cls._aliases.update(cmd_cls._aliases)
         cls._documents.update(cmd_cls._documents)
diff --git a/utilmeta/bin/commands/base.py b/utilmeta/bin/commands/base.py
index 6adf987..a9c8e21 100644
--- a/utilmeta/bin/commands/base.py
+++ b/utilmeta/bin/commands/base.py
@@ -10,19 +10,21 @@
 class BaseServiceCommand(BaseCommand):
     META_INI = META_INI
-    script_name = 'meta'
+    script_name = "meta"

     def __init__(self, exe: str = None, *args: str, cwd: str = os.getcwd()):
-        self.exe = exe    # absolute path of meta command tool
+        self.exe = exe  # absolute path of meta command tool
         self.sys_args = list(args)
         if exe:
-            os.environ.setdefault('META_ABSOLUTE_PATH', exe)
+            os.environ.setdefault("META_ABSOLUTE_PATH", exe)
         if not os.path.isabs(cwd):
             cwd = path_join(os.getcwd(), cwd)
-        self.cwd = cwd.replace('\\', '/')
-        self.ini_path = search_file('utilmeta.ini', path=cwd) or search_file(META_INI, path=cwd)
+        self.cwd = cwd.replace("\\", "/")
+        self.ini_path = search_file("utilmeta.ini", path=cwd) or search_file(
+            META_INI, path=cwd
+        )
         self.base_path = os.path.dirname(self.ini_path) if self.ini_path else self.cwd
         self.service_config = {}
         self._service = None
@@ -37,33 +39,38 @@ def __init__(self, exe: str = None, *args: str, cwd: str = os.getcwd()):
         super().__init__(*self.sys_args, cwd=self.cwd)

     def command_not_found(self):
-        print(RED % F'{self.script_name or "meta"}: command not found: {self.arg_name}')
+        print(RED % f'{self.script_name or "meta"}: command not found: {self.arg_name}')
         if not self.ini_path:
-            print(f'It probably due to your utilmeta project not initialized')
-            print(f'please use {BLUE % "meta init"} in the project directory to initialize your project first')
+            print(f"This is probably because your utilmeta project is not initialized")
+            print(
+                f'please use {BLUE % "meta init"} in the project directory to initialize your project first'
+            )
         exit(1)

     def load_meta(self) -> dict:
         config = load_ini(read_from(self.ini_path), parse_key=True)
-        return config.get('utilmeta') or config.get('service') or {}
+        return config.get("utilmeta") or config.get("service") or {}

     @property
     def service_ref(self):
-        return self.service_config.get('service')
+        return self.service_config.get("service")

     @property
     def main_file(self) -> Optional[str]:
-        file: str = self.service_config.get('main')
+        file: str = self.service_config.get("main")
         if not file:
             return None
-        return os.path.join(self.service.project_dir, file if file.endswith('.py') else f'{file}.py')
+        return os.path.join(
+            self.service.project_dir, file if file.endswith(".py") else f"{file}.py"
+        )

     @property
     def application_ref(self):
-        return self.service_config.get('app')
+        return self.service_config.get("app")

     def load_service(self):
         import utilmeta
+
         utilmeta._cmd_env = True

         if not self.service_ref:
@@ -72,18 +79,24 @@ def load_service(self):
             try:
                 from utilmeta import service
             except ImportError:
-                raise RuntimeError('UtilMeta service not configured, '
-                                   'make sure you are inside a path with meta.ini, '
-                                   'and service is declared in meta.ini')
+                raise RuntimeError(
+                    "UtilMeta service not configured, "
+                    "make sure you are inside a path with meta.ini, "
+                    "and service is declared in meta.ini"
+                )
             else:
                 self._service = service
                 return service
         else:
-            raise RuntimeError('UtilMeta service not configured, make sure you are inside a path with meta.ini')
+            raise RuntimeError(
+                "UtilMeta service not configured, make sure you are inside a path with meta.ini"
+            )
         service = import_obj(self.service_ref)
         if not isinstance(service, UtilMeta):
-            raise RuntimeError(f'Invalid UtilMeta service: {self.service}, should be an UtilMeta instance')
+            raise RuntimeError(
+                f"Invalid UtilMeta service: {self.service}, should be an UtilMeta instance"
+            )
         self._service = service
         return service
@@ -98,13 +111,15 @@ def service(self) -> UtilMeta:
         # raise RuntimeError('UtilMeta service not configured, make sure you are inside a path with meta.ini')

     @classmethod
-    @command('-h')
+    @command("-h")
     def help(cls):
         """
        for helping
         """
-        print(f'meta management tool usage: ')
+        print(
+            f"meta management tool usage: "
+        )
         for key, doc in cls._documents.items():
             if not key:
                 continue
-            print(' ', BLUE % key, doc, '\n')
+            print(" ", BLUE % key, doc, "\n")
diff --git a/utilmeta/bin/commands/setup.py b/utilmeta/bin/commands/setup.py
index de8dae6..88dadbe 100644
--- a/utilmeta/bin/commands/setup.py
+++ b/utilmeta/bin/commands/setup.py
@@ -47,16 +47,9 @@


 class SetupCommand(BaseServiceCommand):
-    SERVER_BACKENDS = 'utilmeta.core.server.backends'
-    DEFAULT_SUPPORTS = [
-        'django',
-        'flask',
-        'fastapi',
-        'starlette',
-        'sanic',
-        'tornado'
-    ]
-    name = 'setup'
+    SERVER_BACKENDS = "utilmeta.core.server.backends"
+    DEFAULT_SUPPORTS = ["django", "flask", "fastapi", "starlette", "sanic", "tornado"]
+    name = "setup"

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -74,17 +67,22 @@ def default_host(self):
         if not self.project_name:
             return
-        return f'{self.project_name}.com'.lower()
+        return f"{self.project_name}.com".lower()

     @property
     def default_backend(self):
-        return 'django'
-
-    @command('')
-    def setup(self, name: str = '', *,
-              template: Literal['full', 'lite'] = Arg('--temp', alias_from=['--t', '--template'], default='lite'),
-              with_operations: bool = Arg('--ops', default=False),
-              ):
+        return "django"
+
+    @command("")
+    def setup(
+        self,
+        name: str = "",
+        *,
+        template: Literal["full", "lite"] = Arg(
+            "--temp", alias_from=["--t", "--template"], default="lite"
+        ),
+        with_operations: bool = Arg("--ops", default=False),
+    ):
         """
         Set up a new project
         --t: select template: full / lite
@@ -92,51 +90,63 @@ def setup(self, name: str = '', *,
         """

         if self.ini_path:
-            print(RED % 'meta Error: you are already inside an utilmeta project, '
-                        'please chose a empty dir to setup your new project')
+            print(
+                RED % "meta Error: you are already inside an utilmeta project, "
+                "please choose an empty dir to set up your new project"
+            )
             exit(1)

         while not name:
-            print(f'Enter the name of your project:')
-            name = input('>>> ').strip()
-
-        if not re.fullmatch(r'[A-Za-z0-9_-]+', name):
-            print(RED % 'UtilMeta project name can only contains alphanumeric characters, '
-                        'underscore "_" and hyphen "-"')
+            print(f"Enter the name of your project:")
+            name = input(">>> ").strip()
+
+        if not re.fullmatch(r"[A-Za-z0-9_-]+", name):
+            print(
+                RED
+                % "UtilMeta project name can only contain alphanumeric characters, "
+                'underscore "_" and hyphen "-"'
+            )
             exit(1)

         project_dir = self.cwd
-        if os.sep in name or '/' in name:
-            print(RED % f'meta Error: project name ({repr(name)}) should not contains path separator')
+        if os.sep in name or "/" in name:
+            print(
+                RED
+                % f"meta Error: project name ({repr(name)}) should not contain a path separator"
+            )
             exit(1)

         self.project_name = name
-        print(f'meta: setting up project [{name}]')
+        print(f"meta: setting up project [{name}]")

-        print(f'description of this project (optional)')
-        self.description = input('>>> ') or ''
+        print(f"description of this project (optional)")
+        self.description = input(">>> ") or ""

-        print(f'Choose the http backend of your project')
+        print(f"Choose the http backend of your project")
         for pkg in self.DEFAULT_SUPPORTS:
-            print(f' - {pkg}%s' % (' (default)' if pkg == self.default_backend else ''))
+            print(f" - {pkg}%s" % (" (default)" if pkg == self.default_backend else ""))
         while not self.backend:
-            self.backend = (input('>>> ') or self.default_backend).lower()
+            self.backend = (input(">>> ") or self.default_backend).lower()
             try:
-                import_obj(f'{self.SERVER_BACKENDS}.{self.backend}')
+                import_obj(f"{self.SERVER_BACKENDS}.{self.backend}")
             except ModuleNotFoundError:
                 if self.backend in self.DEFAULT_SUPPORTS:
                     requires(self.backend)
                 else:
-                    print(f'backend: {repr(self.backend)} not supported or not installed, please enter again')
+                    print(
+                        f"backend: {repr(self.backend)} not supported or not installed, please enter again"
+                    )
                     self.backend = None

         # if self.backend == 'starlette':
         #     check_requirement('uvicorn', install_when_require=True)

-        print(f'Enter the production host of your service (default: {self.default_host})')
-        self.host = input('>>> ') or self.default_host
+        print(
+            f"Enter the production host of your service (default: {self.default_host})"
+        )
+        self.host = input(">>> ") or self.default_host

         temp_path = os.path.join(TEMP_PATH, template)
         project_path = os.path.join(project_dir, name)
@@ -148,7 +158,10 @@ def setup(self, name: str = '', *,
         # --------------------------------------

         if os.path.exists(project_path):
-            print(RED % f'meta Error: project path {project_path} already exist, chose a different name')
+            print(
+                RED
+                % f"meta Error: project path {project_path} already exists, choose a different name"
+            )
             exit(1)

         shutil.copytree(temp_path, project_path)
@@ -157,24 +170,33 @@ def setup(self, name: str = '', *,

         for file in files:
             path = os.path.join(ab_path, file)
-            if str(file).endswith('.py') or str(file).endswith('.ini'):
+            if str(file).endswith(".py") or str(file).endswith(".ini"):
                 content = read_from(path)
-                write_to(path, self.render(
-                    content.replace('# noqa', ''),
-                    with_operations=with_operations,
-                    template=template
-                ))
+                write_to(
+                    path,
+                    self.render(
+                        content.replace("# noqa", ""),
+                        with_operations=with_operations,
+                        template=template,
+                    ),
+                )

         # add gitignore file
-        write_to(os.path.join(project_path, '.gitignore'), DEFAULT_GIT_IGNORE)
-        requirements = ['utilmeta']
+        write_to(os.path.join(project_path, ".gitignore"), DEFAULT_GIT_IGNORE)
+        requirements = ["utilmeta"]
         if self.backend:
             requirements.append(self.backend)
-        write_to(os.path.join(project_path, 'requirements.txt'), '\n'.join(requirements))
+        write_to(
+            os.path.join(project_path, "requirements.txt"), "\n".join(requirements)
+        )

-        print(f'UtilMeta project <{BLUE % self.project_name}> successfully setup at path: {project_path}')
+        print(
+            f"UtilMeta project <{BLUE % self.project_name}> successfully set up at path: {project_path}"
+        )

-    def render(self, content, with_operations: bool = False, template: str = None) -> str:
+    def render(
+        self, content, with_operations: bool = False, template: str = None
+    ) -> str:
         def _format(text: str, *args, **kwargs):
             for i in range(0, len(args)):
                 k = "{" + str(i) + "}"
@@ -184,7 +206,7 @@ def _format(text: str, *args, **kwargs):
                 text = text.replace(k, str(val))
             return text

-        if template == 'full':
+        if template == "full":
             operations_text = """
from utilmeta.ops.config import Operations
from utilmeta.conf.time import Time
@@ -201,7 +223,9 @@ def _format(text: str, *args, **kwargs):
        engine='sqlite3',
    ),
))
-""".format(name=self.project_name)
+""".format(
+                name=self.project_name
+            )
         else:
             operations_text = """
from utilmeta.ops.config import Operations
@@ -226,16 +250,20 @@ def _format(text: str, *args, **kwargs):
        engine='sqlite3',
    ),
))
-""".format(name=self.project_name)
+""".format(
+                name=self.project_name
+            )

         return _format(
             content,
             name=self.project_name,
             backend=self.backend,
-            import_backend=f'import {self.backend}',
+            import_backend=f"import {self.backend}",
             description=self.description,
             host=self.host,
-            operations=operations_text if with_operations else '',
+            operations=operations_text if with_operations else "",
             plugins="""
-@api.CORS(allow_origin='*')""" if with_operations else ''
+@api.CORS(allow_origin='*')"""
+            if with_operations
+            else "",
         )
diff --git a/utilmeta/bin/constant.py b/utilmeta/bin/constant.py
index bfacd03..40b3f2f 100644
--- a/utilmeta/bin/constant.py
+++ b/utilmeta/bin/constant.py
@@ -1,26 +1,27 @@
 import sys
+
-LINUX = 'linux'
-BSD = 'bsd'
-WIN = 'win'
-META_INI = 'meta.ini'
-INIT_FILE = '__init__.py'
-UWSGI = 'uwsgi'
-GUNICORN = 'gunicorn'
-NGINX = 'nginx'
-APACHE = 'apache'
-DOT = '●'
-PRODUCTION = 'PRODUCTION'
-DEBUG = 'DEBUG'
-PYTHON = 'python'
-SPACE_4 = ' ' * 4
-JOINER = '\n' + SPACE_4
+LINUX = "linux"
+BSD = "bsd"
+WIN = "win"
+META_INI = "meta.ini"
+INIT_FILE = "__init__.py"
+UWSGI = "uwsgi"
+GUNICORN = "gunicorn"
+NGINX = "nginx"
+APACHE = "apache"
+DOT = "●"
+PRODUCTION = "PRODUCTION"
+DEBUG = "DEBUG"
+PYTHON = "python"
+SPACE_4 = " " * 4
+JOINER = "\n" + SPACE_4
 JOINER2 = JOINER + SPACE_4
-PY_NAMES = (UWSGI, GUNICORN, PYTHON, '%s.exe' % PYTHON)
+PY_NAMES = (UWSGI, GUNICORN, PYTHON, "%s.exe" % PYTHON)
 SERVER_NAMES = (NGINX, APACHE, UWSGI, GUNICORN)
-UWSGI_CONFIG = f'{UWSGI}.ini'
-GUNICORN_CONFIG = f'{GUNICORN}.conf'
-NGINX_CONFIG = f'conf.{NGINX}'
-APACHE_CONFIG = f'{APACHE}.conf'
+UWSGI_CONFIG = f"{UWSGI}.ini"
+GUNICORN_CONFIG = f"{GUNICORN}.conf"
+NGINX_CONFIG = f"conf.{NGINX}"
+APACHE_CONFIG = f"{APACHE}.conf"
 WSGI_CHOICES = [UWSGI, GUNICORN]
 WEB_CHOICES = [NGINX, APACHE]
 DEFAULT_GITIGNORE = """
@@ -41,11 +42,11 @@
 """

-GREEN = '\033[1;32m%s\033[0m'
-RED = '\033[1;31m%s\033[0m'
-BLUE = '\033[1;34m%s\033[0m'
-YELLOW = '\033[1;33m%s\033[0m'
-BANNER = '\033[1;30;47m%s\033[0m'
+GREEN = "\033[1;32m%s\033[0m"
+RED = "\033[1;31m%s\033[0m"
+BLUE = "\033[1;34m%s\033[0m"
+YELLOW = "\033[1;33m%s\033[0m"
+BANNER = "\033[1;30;47m%s\033[0m"
 LINE_WIDTH = 100

 if sys.platform != LINUX:
@@ -56,8 +57,8 @@
         colorama.init(autoreset=True)
     except ModuleNotFoundError:
         # DOWN GRADE
-        GREEN = '%s'
-        RED = '%s'
-        BLUE = '%s'
-        YELLOW = '%s'
-        BANNER = '%s'
+        GREEN = "%s"
+        RED = "%s"
+        BLUE = "%s"
+        YELLOW = "%s"
+        BANNER = "%s"
diff --git a/utilmeta/bin/meta.py b/utilmeta/bin/meta.py
index 58f8048..9da6300 100644
--- a/utilmeta/bin/meta.py
+++ b/utilmeta/bin/meta.py
@@ -22,11 +22,17 @@ def __init__(self, *args, **kwargs):
     def command_not_found(self):
         if self.ini_path:
             # maybe figure out some non-invasion method in the future
-            if self.arg_name in ['connect', 'stats', 'sync', 'migrate_ops']:
-                print(YELLOW % f'meta {self.arg_name}: Operations config not integrated to application, '
-                               'please follow the document at https://docs.utilmeta.com/py/en/guide/ops/')
-            elif self.arg_name in ['add', 'makemigrations', 'migrate']:
-                print(YELLOW % f'meta {self.arg_name}: DjangoSettings config not used in application')
+            if self.arg_name in ["connect", "stats", "sync", "migrate_ops"]:
+                print(
+                    YELLOW
+                    % f"meta {self.arg_name}: Operations config not integrated to application, "
+                    "please follow the document at https://docs.utilmeta.com/py/en/guide/ops/"
+                )
+            elif self.arg_name in ["add", "makemigrations", "migrate"]:
+                print(
+                    YELLOW
+                    % f"meta {self.arg_name}: DjangoSettings config not used in application"
+                )
         super().command_not_found()

     def load_commands(self):
@@ -42,20 +48,21 @@ def load_commands(self):
             self.mount(cmd, name=name)

     @classmethod
-    @command('')
+    @command("")
     def intro(cls):
-        print(f'UtilMeta v{__version__} Management Command Line Tool')
-        print('use meta -h for help')
+        print(f"UtilMeta v{__version__} Management Command Line Tool")
+        print("use meta -h for help")
help") cls.help() - @command('init') - def init(self, - name: str = Arg(default=None), - app: str = Arg(alias='--app', default=None), - service: str = Arg(alias='--service', default=None), - main_file: str = Arg(alias='--main', default=None), - pid_file: str = Arg(alias='--pid', default='service.pid'), - ): + @command("init") + def init( + self, + name: str = Arg(default=None), + app: str = Arg(alias="--app", default=None), + service: str = Arg(alias="--service", default=None), + main_file: str = Arg(alias="--main", default=None), + pid_file: str = Arg(alias="--pid", default="service.pid"), + ): """ Initialize utilmeta project with a meta.ini file --app: specify the wsgi / asgi application @@ -64,103 +71,118 @@ def init(self, """ if not self.ini_path: self.ini_path = os.path.join(self.cwd, self.META_INI) - print(f'Initialize UtilMeta project with {self.ini_path}') + print(f"Initialize UtilMeta project with {self.ini_path}") else: config = self.service_config if config: - print('UtilMeta project already initialized at {}'.format(self.ini_path)) + print( + "UtilMeta project already initialized at {}".format(self.ini_path) + ) return - print(f'Re-initialize UtilMeta project at {self.ini_path}') + print(f"Re-initialize UtilMeta project at {self.ini_path}") while True: if not app: - print(f'Please specify the reference of your WSGI / ASGI application, like package.to.your.app') - app = input('>>> ') + print( + f"Please specify the reference of your WSGI / ASGI application, like package.to.your.app" + ) + app = input(">>> ") # try to load try: app_obj = import_obj(app) if inspect.ismodule(app_obj): - raise ValueError(f'--app should be a python application object, got module: {app_obj}') + raise ValueError( + f"--app should be a python application object, got module: {app_obj}" + ) break except Exception as e: err = Error(e) err.setup() print(err.message) - print(RED % f'python application reference: {repr(app)} failed to load: {e}') + print( + RED + % f"python application reference: {repr(app)} failed to load: {e}" + ) app = None if not name: base_name = os.path.basename(os.path.dirname(self.ini_path)) - print(f'Please enter your project name (default: {base_name})') - name = input('>>> ') or base_name + print(f"Please enter your project name (default: {base_name})") + name = input(">>> ") or base_name settings = dict(app=app) if name: - settings['name'] = name + settings["name"] = name if main_file: - settings['main'] = main_file + settings["main"] = main_file if service: - settings['service'] = service + settings["service"] = service if pid_file: - settings['pidfile'] = pid_file + settings["pidfile"] = pid_file - print(f'Initializing UtilMeta project [{BLUE % name}] with python application: {BLUE % app}') + print( + f"Initializing UtilMeta project [{BLUE % name}] with python application: {BLUE % app}" + ) from utilmeta.utils import write_config - write_config({ - 'utilmeta': settings - }, self.ini_path) + + write_config({"utilmeta": settings}, self.ini_path) def _get_openapi(self): from utilmeta.ops.config import Operations + ops_config = self.service.get_config(Operations) if ops_config: return ops_config.openapi from utilmeta.core.api.specs.openapi import OpenAPI + return OpenAPI(self.service)() @command() - def gen_openapi(self, to: str = Arg(alias='--to', default='openapi.json')): + def gen_openapi(self, to: str = Arg(alias="--to", default="openapi.json")): """ Generate OpenAPI document file for current service --to: target file name, default to be openapi.json """ self.service.setup() # 
         self.service.setup()  # setup here
-        print(f'generate openapi document file for service: [{self.service.name}]')
+        print(f"generate openapi document file for service: [{self.service.name}]")
         from utilmeta.core.api.specs.openapi import OpenAPI
+
         openapi = self._get_openapi()
         path = OpenAPI.save_to(openapi, to)
-        print(f'OpenAPI document generated at {path}')
+        print(f"OpenAPI document generated at {path}")

     @command
-    def gen_client(self,
-                   openapi: str = Arg(alias='--openapi', default=None),
-                   to: str = Arg(alias='--to', default='client.py'),
-                   split_body_params: str = Arg(alias='--split-body-params', default=True),
-                   black: str = Arg(alias='--black', default=True),
-                   space_indent: str = Arg(alias='--spaces-indent', default=True),
-                   ):
+    def gen_client(
+        self,
+        openapi: str = Arg(alias="--openapi", default=None),
+        to: str = Arg(alias="--to", default="client.py"),
+        split_body_params: str = Arg(alias="--split-body-params", default=True),
+        black: str = Arg(alias="--black", default=True),
+        space_indent: str = Arg(alias="--spaces-indent", default=True),
+    ):
         """
         Generate UtilMeta Client code for current service or specified OpenAPI document (url or file)
         --openapi: specify target OpenAPI document (url / filepath / document string),
         default to be the document of current UtilMeta service
         --to: target file name, default to be openapi.json
         """
         from utilmeta.core.cli.specs.openapi import OpenAPIClientGenerator
+
         if openapi:
-            print(f'generate client file based on openapi: {repr(openapi)}')
+            print(f"generate client file based on openapi: {repr(openapi)}")
             generator = OpenAPIClientGenerator.generate_from(openapi)
         else:
             self.service.setup()  # setup here
-            print(f'generate client file for service: [{self.service.name}]')
+            print(f"generate client file for service: [{self.service.name}]")
             openapi_docs = self._get_openapi()
             generator = OpenAPIClientGenerator(openapi_docs)
         generator.space_ident = space_indent
         generator.black_format = black
         generator.split_body_params = split_body_params
         path = generator(to)
-        print(f'Client file generated at {path}')
+        print(f"Client file generated at {path}")

-    @command('-v', 'version')
+    @command("-v", "version")
     def version(self):
         """
         display the current UtilMeta version and service meta-info
@@ -168,41 +190,58 @@ def version(self):
         import platform
         import sys
         from utilmeta.bin.constant import BLUE, GREEN, DOT
-        print(f' UtilMeta: v{ __version__}')
+
+        print(f" UtilMeta: v{ __version__}")
         try:
             _ = self.service
         except RuntimeError:
             # service not detect
-            print(f' service:', 'not detected')
+            print(f" service:", "not detected")
         else:
-            print(f' service:', BLUE % self.service.name, f'({self.service.version_str})')
-            print(f' stage:', (BLUE % f'{DOT} production') if self.service.production else (GREEN % f'{DOT} debug'))
-            print(f' backend:', f'{self.service.backend_name} ({self.service.backend_version})',
-                  f'| asynchronous' if self.service.asynchronous else '')
-            print(f' environment:', sys.version, platform.platform())
+            print(
+                f" service:",
+                BLUE % self.service.name,
+                f"({self.service.version_str})",
+            )
+            print(
+                f" stage:",
+                (BLUE % f"{DOT} production")
+                if self.service.production
+                else (GREEN % f"{DOT} debug"),
+            )
+            print(
+                f" backend:",
+                f"{self.service.backend_name} ({self.service.backend_version})",
+                f"| asynchronous" if self.service.asynchronous else "",
+            )
+            print(f" environment:", sys.version, platform.platform())

     @command
-    def run(self,
-            daemon: bool = Arg('-d', default=False),
-            connect: bool = Arg('-c', default=False),
-            log: str = Arg('--log', default='service.log'),
-            ):
+    def run(
+        self,
+        daemon: bool = Arg("-d", default=False),
+        connect: bool = Arg("-c", default=False),
+        log: str = Arg("--log", default="service.log"),
+    ):
         """
         run utilmeta service and start to serve requests (for debug only)
         """
         if not self.main_file:
-            print(RED % 'meta run: no main file specified in meta.ini')
+            print(RED % "meta run: no main file specified in meta.ini")
             exit(1)
-        print(f'UtilMeta service {BLUE % self.service.name} running at {self.main_file}')
-        cmd = f'{sys.executable} {self.main_file}'
+        print(
+            f"UtilMeta service {BLUE % self.service.name} running at {self.main_file}"
+        )
+        cmd = f"{sys.executable} {self.main_file}"
         if daemon:
-            if os.name == 'posix':
-                print(f'running service with nohup in background, writing log to {log}')
-                cmd = f'nohup {cmd} > {log} 2>&1 &'
+            if os.name == "posix":
+                print(f"running service with nohup in background, writing log to {log}")
+                cmd = f"nohup {cmd} > {log} 2>&1 &"
             else:
-                print(YELLOW % 'ignoring daemon mode since only posix system support')
+                print(YELLOW % "ignoring daemon mode since only posix systems support it")
         if connect:
             from utilmeta.ops.cmd import try_to_connect
+
             try_to_connect()
         run(cmd)
@@ -212,38 +251,46 @@ def down(self):
         pid = self.service.pid
         if not pid:
             if self.service.pid_file:
-                print(RED % f'meta down: PID not found in pidfile, service may not started yet')
+                print(
+                    RED
+                    % f"meta down: PID not found in pidfile, service may not have started yet"
+                )
             else:
-                print(RED % f'meta down: requires pidfile set in meta.ini, no pid found')
+                print(
+                    RED % f"meta down: requires pidfile set in meta.ini, no pid found"
+                )
             exit(1)
         try:
             proc = psutil.Process(pid)
         except psutil.NoSuchProcess:
-            print(f'meta down: service [{self.service.name}](pid={pid}) already stopped')
+            print(
+                f"meta down: service [{self.service.name}](pid={pid}) already stopped"
+            )
             return
         except psutil.Error as e:
-            print(RED % f'meta down: load process: {pid} failed with error: {e}')
+            print(RED % f"meta down: load process: {pid} failed with error: {e}")
             exit(1)
         proc.kill()
-        print(f'meta down: service [{self.service.name}](pid={pid}) stopped')
+        print(f"meta down: service [{self.service.name}](pid={pid}) stopped")

     @command
-    def restart(self,
-                connect: bool = Arg('-c', default=False),
-                log: str = Arg('--log', default='service.log'),
-                ):
+    def restart(
+        self,
+        connect: bool = Arg("-c", default=False),
+        log: str = Arg("--log", default="service.log"),
+    ):
         pid = self.service.pid
         if not pid:
             if self.service.pid_file:
                 return self.run(daemon=True, connect=connect, log=log)
-            print(RED % f'meta restart: requires pidfile set in meta.ini, no pid found')
+            print(RED % f"meta restart: requires pidfile set in meta.ini, no pid found")
             exit(1)
         try:
             proc = psutil.Process(pid)
         except psutil.NoSuchProcess:
             return self.run(daemon=True, connect=connect, log=log)
         except psutil.Error as e:
-            print(RED % f'meta restart: load process: {pid} failed with error: {e}')
+            print(RED % f"meta restart: load process: {pid} failed with error: {e}")
             exit(1)
         proc.kill()
         return self.run(daemon=True, connect=connect, log=log)
@@ -253,5 +300,5 @@ def main():
     MetaCommand(*sys.argv)()


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/utilmeta/bin/template/full/config/conf.py b/utilmeta/bin/template/full/config/conf.py
index 14b4ce3..bb6450c 100644
--- a/utilmeta/bin/template/full/config/conf.py
+++ b/utilmeta/bin/template/full/config/conf.py
@@ -6,14 +6,15 @@ def configure(service: UtilMeta):
     from utilmeta.core.server.backends.django import DjangoSettings
     from utilmeta.core.orm import DatabaseConnections, Database

-    service.use(DjangoSettings(
-        apps_package='domain',
-        secret_key=env.DJANGO_SECRET_KEY
-    ))
-    service.use(DatabaseConnections({
-        'default': Database(
-            name='db',
-            engine='sqlite3',
+    service.use(DjangoSettings(apps_package="domain", secret_key=env.DJANGO_SECRET_KEY))
+    service.use(
+        DatabaseConnections(
+            {
+                "default": Database(
+                    name="db",
+                    engine="sqlite3",
+                )
+            }
         )
-    }))
-    {operations} # noqa
+    )
+    {operations}  # noqa
diff --git a/utilmeta/bin/template/full/config/env.py b/utilmeta/bin/template/full/config/env.py
index d7d3021..cbf1fc9 100644
--- a/utilmeta/bin/template/full/config/env.py
+++ b/utilmeta/bin/template/full/config/env.py
@@ -3,7 +3,7 @@

 class ServiceEnvironment(Env):
     PRODUCTION: bool = False
-    DJANGO_SECRET_KEY: str = ''
+    DJANGO_SECRET_KEY: str = ""


-env = ServiceEnvironment(sys_env='META_')
+env = ServiceEnvironment(sys_env="META_")
diff --git a/utilmeta/bin/template/full/config/service.py b/utilmeta/bin/template/full/config/service.py
index ce4e3c6..06cde77 100644
--- a/utilmeta/bin/template/full/config/service.py
+++ b/utilmeta/bin/template/full/config/service.py
@@ -1,19 +1,20 @@
 from utilmeta import UtilMeta
 from config.conf import configure
 from config.env import env
-{import_backend} # noqa
+
+{import_backend}  # noqa

 service = UtilMeta(
     __name__,
-    name='{name}',
-    description='{description}',
+    name="{name}",
+    description="{description}",
     backend={backend},  # noqa
     production=env.PRODUCTION,
     version=(0, 1, 0),
-    host='127.0.0.1',
+    host="127.0.0.1",
     port=8000,
-    origin='https://{host}' if env.PRODUCTION else None,
-    api='service.api.RootAPI',
-    route='/api'
+    origin="https://{host}" if env.PRODUCTION else None,
+    api="service.api.RootAPI",
+    route="/api",
 )
 configure(service)
diff --git a/utilmeta/bin/template/full/main.py b/utilmeta/bin/template/full/main.py
index 5168ef6..1f8a318 100644
--- a/utilmeta/bin/template/full/main.py
+++ b/utilmeta/bin/template/full/main.py
@@ -2,5 +2,5 @@

 app = service.application()

-if __name__ == '__main__':
-    service.run()
\ No newline at end of file
+if __name__ == "__main__":
+    service.run()
diff --git a/utilmeta/bin/template/full/service/api.py b/utilmeta/bin/template/full/service/api.py
index 383b9be..e5edca9 100644
--- a/utilmeta/bin/template/full/service/api.py
+++ b/utilmeta/bin/template/full/service/api.py
@@ -1,7 +1,9 @@
 from utilmeta.core import api

 {plugins} # noqa
+
+
 class RootAPI(api.API):
     @api.get
     def hello(self):
-        return 'world'
+        return "world"
diff --git a/utilmeta/bin/template/lite/server.py b/utilmeta/bin/template/lite/server.py
index cca9a1e..3f42e99 100644
--- a/utilmeta/bin/template/lite/server.py
+++ b/utilmeta/bin/template/lite/server.py
@@ -4,34 +4,37 @@
 from utilmeta import UtilMeta
 from utilmeta.core import api
 import os
-{import_backend} # noqa
+
+{import_backend}  # noqa
 {plugins} # noqa
+
+
 class RootAPI(api.API):
     @api.get
     def hello(self):
-        return 'world'
+        return "world"

-production = bool(os.getenv('UTILMETA_PRODUCTION'))
+production = bool(os.getenv("UTILMETA_PRODUCTION"))

 service = UtilMeta(
     __name__,
-    name='{name}',
-    description='{description}',
+    name="{name}",
+    description="{description}",
     backend={backend},  # noqa
     production=production,
     version=(0, 1, 0),
-    host='127.0.0.1',
+    host="127.0.0.1",
     port=8000,
-    origin='https://{host}' if production else None,
-    route='/api',
-    api=RootAPI
+    origin="https://{host}" if production else None,
+    route="/api",
+    api=RootAPI,
 )
-{operations} # noqa
+{operations}  # noqa

-app = service.application() # used in wsgi/asgi server
+app = service.application()  # used in wsgi/asgi server

-if __name__ == '__main__':
+if __name__ == "__main__":
     service.run()
     # try: http://127.0.0.1:8000/api/hello
diff --git a/utilmeta/bin/utils.py b/utilmeta/bin/utils.py
index 50ab814..502be9b 100644
--- a/utilmeta/bin/utils.py
+++ b/utilmeta/bin/utils.py
@@ -6,14 +6,16 @@ def update_meta_ini_file(path: str = None, /, **settings):
     if not settings:
         return
     cwd = path or os.getcwd()
-    ini_path = search_file('utilmeta.ini', path=cwd) or search_file('meta.ini', path=cwd)
+    ini_path = search_file("utilmeta.ini", path=cwd) or search_file(
+        "meta.ini", path=cwd
+    )
     if not ini_path:
         return
     config = load_ini(read_from(ini_path), parse_key=True)
-    service_config = dict(config.get('utilmeta') or config.get('service') or {})
+    service_config = dict(config.get("utilmeta") or config.get("service") or {})
     for key, val in settings.items():
         if val is None:
             pop(service_config, key)
         else:
             service_config[key] = val
-    write_config({'utilmeta': service_config}, ini_path)
+    write_config({"utilmeta": service_config}, ini_path)
diff --git a/utilmeta/conf/base.py b/utilmeta/conf/base.py
index 911e5f1..e34c429 100644
--- a/utilmeta/conf/base.py
+++ b/utilmeta/conf/base.py
@@ -2,7 +2,7 @@
 from utype import DataClass
 from utilmeta.utils import pop

-T = TypeVar('T')
+T = TypeVar("T")


 class Config(DataClass):
@@ -10,8 +10,8 @@ class Config(DataClass):

     def __init__(self, kwargs=None):
         if kwargs:
-            pop(kwargs, '__class__')
-            pop(kwargs, 'self')
+            pop(kwargs, "__class__")
+            pop(kwargs, "self")
         self._kwargs = kwargs or {}
         super().__init__(**self._kwargs)
diff --git a/utilmeta/conf/env.py b/utilmeta/conf/env.py
index f993566..e5a66a6 100644
--- a/utilmeta/conf/env.py
+++ b/utilmeta/conf/env.py
@@ -12,35 +12,37 @@ class EnvVarUndefined(ValueError):
 class Env(Schema):
     __options__ = Options(case_insensitive=True)

-    def __init__(self,
-                 data: Union[Mapping, dict] = None,
-                 sys_env: Union[bool, str] = None,
-                 ref: str = None,
-                 file: str = None):
+    def __init__(
+        self,
+        data: Union[Mapping, dict] = None,
+        sys_env: Union[bool, str] = None,
+        ref: str = None,
+        file: str = None,
+    ):
         self._data = data or {}
         self._sys_env = bool(sys_env)
-        self._sys_env_prefix = sys_env if isinstance(sys_env, str) else ''
+        self._sys_env_prefix = sys_env if isinstance(sys_env, str) else ""
         self._ref = ref
         self._file = file

         for items in (
             self._load_from_ref(),
             self._load_from_file(),
-            self._load_from_sys_env()
+            self._load_from_sys_env(),
         ):
             if items:
                 self._data.update(items)

         try:
             super().__init__(**self._data)
         except AbsenceError as e:
-            route = '.'.join(e.routes) if e.routes else e.item
+            route = ".".join(e.routes) if e.routes else e.item
             if self._sys_env_prefix:
                 msg = f'Environment variable "{self._sys_env_prefix}{route}" not set"'
             elif self._file:
-                msg = f'variable not set in file: {self._file}'
+                msg = f"variable not set in file: {self._file}"
             else:
-                msg = 'variable not set'
+                msg = "variable not set"
             raise EnvVarUndefined(
-                f'{self.__class__.__name__} initialize failed: required env var [{route}] undefined: {msg}'
+                f"{self.__class__.__name__} initialize failed: required env var [{route}] undefined: {msg}"
             ) from e

     def _load_from_sys_env(self) -> Mapping:
@@ -49,7 +51,7 @@ def _load_from_sys_env(self) -> Mapping:
         data = {}
         for key, value in os.environ.items():
             if key.lower().startswith(self._sys_env_prefix.lower()):
-                data[key[len(self._sys_env_prefix):]] = value
+                data[key[len(self._sys_env_prefix) :]] = value
         return data

     def _load_from_file(self) -> Mapping:
@@ -58,30 +60,36 @@ def _load_from_file(self) -> Mapping:
         if not os.path.exists(self._file):
             rel_file = os.path.join(os.getcwd(), self._file)
             if not os.path.exists(rel_file):
-                raise FileNotFoundError(f'{self.__class__}: file: {repr(self._file)} not exists')
+                raise FileNotFoundError(
+                    f"{self.__class__}: file: {repr(self._file)} does not exist"
+                )
             else:
                 self._file = rel_file
-        if self._file.endswith('.json'):
-            return json.load(open(self._file, 'r'))
+        if self._file.endswith(".json"):
+            return json.load(open(self._file, "r"))

-        if self._file.endswith('.yml') or self._file.endswith('.yaml'):
+        if self._file.endswith(".yml") or self._file.endswith(".yaml"):
             from utilmeta.utils import requires
-            requires(yaml='pyyaml')
+
+            requires(yaml="pyyaml")
             import yaml
-            return yaml.safe_load(open(self._file, 'r'))

-        content = open(self._file, 'r').read()
+            return yaml.safe_load(open(self._file, "r"))
+
+        content = open(self._file, "r").read()
         data = {}
         for line in content.splitlines():
             if not line.strip():  # empty line
                 continue
             try:
-                key, value = line.split('=')
+                key, value = line.split("=")
             except ValueError as e:
-                raise ValueError(f'{self.__class__}: file: {repr(self._file)} invalid line: {repr(line)}, '
-                                 f'should be <key>=<value>') from e
+                raise ValueError(
+                    f"{self.__class__}: file: {repr(self._file)} invalid line: {repr(line)}, "
+                    f"should be <key>=<value>"
+                ) from e
             key = str(key).strip()
             value = str(value).strip()
             if key:
@@ -92,9 +100,12 @@ def _load_from_ref(self) -> Mapping:
         if not self._ref:
             return {}
         from utilmeta.utils import import_obj
+
         obj = import_obj(self._ref)
         if isinstance(obj, Mapping):
             return obj
-        if hasattr(obj, '__dict__') and isinstance(obj.__dict__, Mapping):
+        if hasattr(obj, "__dict__") and isinstance(obj.__dict__, Mapping):
             return obj.__dict__
-        raise TypeError(f'{self.__class__}: invalid ref: {repr(self._ref)}, dict or class expetect, got {obj}')
+        raise TypeError(
+            f"{self.__class__}: invalid ref: {repr(self._ref)}, dict or class expected, got {obj}"
+        )
diff --git a/utilmeta/conf/http.py b/utilmeta/conf/http.py
index f247582..ead9478 100644
--- a/utilmeta/conf/http.py
+++ b/utilmeta/conf/http.py
@@ -3,39 +3,42 @@


 class Cookie(Config):
-    Lax = 'Lax'
-    Strict = 'Strict'
+    Lax = "Lax"
+    Strict = "Strict"

     # ---- class attribute hint
     age: int = 31449600
     domain: Optional[str] = None
-    name: str = ''
-    path: str = '/'
+    name: str = ""
+    path: str = "/"
     secure: bool = False
     # cross_domain: bool = False
     http_only: bool = False
-    same_site: Optional[str] = 'Lax'
+    same_site: Optional[str] = "Lax"

-    def __init__(self,
-                 age: int = 31449600,
-                 domain: Optional[str] = None,
-                 name: str = '',
-                 path: str = '/',
-                 secure: bool = False,
-                 # cross_domain: bool = False,
-                 http_only: bool = False,
-                 same_site: Optional[str] = 'Lax'
-                 ):
+    def __init__(
+        self,
+        age: int = 31449600,
+        domain: Optional[str] = None,
+        name: str = "",
+        path: str = "/",
+        secure: bool = False,
+        # cross_domain: bool = False,
+        http_only: bool = False,
+        same_site: Optional[str] = "Lax",
+    ):
         super().__init__(locals())

     def as_django(self, prefix: str = None):
         config = {
-            'AGE': self.age,
-            'DOMAIN': self.domain,
-            'HTTPONLY': self.http_only,
-            'NAME': self.name,
-            'PATH': self.path,
-            'SAMESITE': str(self.same_site),
-            'SECURE': self.secure
+            "AGE": self.age,
+            "DOMAIN": self.domain,
+            "HTTPONLY": self.http_only,
+            "NAME": self.name,
+            "PATH": self.path,
+            "SAMESITE": str(self.same_site),
+            "SECURE": self.secure,
+        }
+        return {
+ (f"{prefix}_{key}" if prefix else key): val for key, val in config.items() } - return {(f'{prefix}_{key}' if prefix else key): val for key, val in config.items()} diff --git a/utilmeta/conf/pool.py b/utilmeta/conf/pool.py index f6cfea0..35e65bd 100644 --- a/utilmeta/conf/pool.py +++ b/utilmeta/conf/pool.py @@ -7,7 +7,9 @@ class ThreadPool(Config): max_workers: Optional[int] timeout: Optional[int] - def __init__(self, max_workers: Optional[int] = None, timeout: Optional[int] = None): + def __init__( + self, max_workers: Optional[int] = None, timeout: Optional[int] = None + ): super().__init__(locals()) self._pool = ThreadPoolExecutor(self.max_workers) @@ -19,4 +21,5 @@ def get_result(self, func, *args, **kwargs): def submit(self, func, *args, **kwargs): self._pool.submit(func, *args, **kwargs) + # pool = ThreadPool() diff --git a/utilmeta/conf/preference.py b/utilmeta/conf/preference.py index 2f9444d..1f37f96 100644 --- a/utilmeta/conf/preference.py +++ b/utilmeta/conf/preference.py @@ -56,5 +56,5 @@ def __init__( super().__init__(locals()) @classmethod - def get(cls) -> 'Preference': + def get(cls) -> "Preference": return cls.config() or cls() diff --git a/utilmeta/conf/time.py b/utilmeta/conf/time.py index 82b474e..66e5ccd 100644 --- a/utilmeta/conf/time.py +++ b/utilmeta/conf/time.py @@ -2,13 +2,14 @@ from utilmeta.utils import TimeZone, get_timezone from typing import Optional import time + SERVER_UTCOFFSET = -time.timezone from datetime import datetime, timezone, timedelta class Time(Config): - DATE_DEFAULT = '%Y-%m-%d' - TIME_DEFAULT = '%H:%M:%S' + DATE_DEFAULT = "%Y-%m-%d" + TIME_DEFAULT = "%H:%M:%S" DATETIME_DEFAULT = "%Y-%m-%d %H:%M:%S" # ---- @@ -18,15 +19,17 @@ class Time(Config): use_tz: Optional[bool] = None time_zone: Optional[str] = None - def __init__(self, *, - date_format: str = DATE_DEFAULT, - time_format: str = TIME_DEFAULT, - datetime_format: str = DATETIME_DEFAULT, - # to_timestamp: bool = False, - # to_ms_timestamp: bool = False, - use_tz: Optional[bool] = True, - time_zone: Optional[str] = None, - ): + def __init__( + self, + *, + date_format: str = DATE_DEFAULT, + time_format: str = TIME_DEFAULT, + datetime_format: str = DATETIME_DEFAULT, + # to_timestamp: bool = False, + # to_ms_timestamp: bool = False, + use_tz: Optional[bool] = True, + time_zone: Optional[str] = None, + ): super().__init__(locals()) self.time_format = time_format self.date_format = date_format @@ -115,5 +118,7 @@ def convert_time(self, dt: datetime) -> datetime: if dt.tzinfo: return self.time_local(dt) if self.timezone_utcoffset != self.server_utcoffset: - return dt + timedelta(seconds=self.timezone_utcoffset - self.server_utcoffset) + return dt + timedelta( + seconds=self.timezone_utcoffset - self.server_utcoffset + ) return dt diff --git a/utilmeta/core/api/__init__.py b/utilmeta/core/api/__init__.py index 6129d25..3656eb8 100644 --- a/utilmeta/core/api/__init__.py +++ b/utilmeta/core/api/__init__.py @@ -5,15 +5,16 @@ from .plugins.base import APIPlugin as Plugin from .plugins.retry import RetryPlugin as Retry from .plugins.cors import CORSPlugin as CORS + # from .plugins.rate import RateLimitPlugin as RateLimit route = decorator.APIDecoratorWrapper(None) -get = decorator.APIDecoratorWrapper('get') -put = decorator.APIDecoratorWrapper('put') -post = decorator.APIDecoratorWrapper('post') -patch = decorator.APIDecoratorWrapper('patch') -delete = decorator.APIDecoratorWrapper('delete') +get = decorator.APIDecoratorWrapper("get") +put = decorator.APIDecoratorWrapper("put") +post = 
decorator.APIDecoratorWrapper("post") +patch = decorator.APIDecoratorWrapper("patch") +delete = decorator.APIDecoratorWrapper("delete") # below is SDK-only method -head = decorator.APIDecoratorWrapper('head') -options = decorator.APIDecoratorWrapper('options') -trace = decorator.APIDecoratorWrapper('trace') +head = decorator.APIDecoratorWrapper("head") +options = decorator.APIDecoratorWrapper("options") +trace = decorator.APIDecoratorWrapper("trace") diff --git a/utilmeta/core/api/base.py b/utilmeta/core/api/base.py index 12cf2ef..0328511 100644 --- a/utilmeta/core/api/base.py +++ b/utilmeta/core/api/base.py @@ -1,8 +1,14 @@ from typing import Union, Dict, Type, List, Any, Optional from utilmeta.utils.error import Error from utilmeta.utils.context import ParserProperty -from utilmeta.utils import Header, EndpointAttr, COMMON_METHODS, awaitable, \ - classonlymethod, distinct_add +from utilmeta.utils import ( + Header, + EndpointAttr, + COMMON_METHODS, + awaitable, + classonlymethod, + distinct_add, +) from utilmeta.utils import exceptions as exc import inspect @@ -24,10 +30,10 @@ from utype.utils.exceptions import ParseError from utilmeta.conf import Preference -setup_class = PluginEvent('setup_class', synchronous_only=True) +setup_class = PluginEvent("setup_class", synchronous_only=True) # enter_route = PluginEvent('enter_route') # exit_route = PluginEvent('exit_route') -setup_instance = PluginEvent('setup_instance') +setup_instance = PluginEvent("setup_instance") class APIRef: @@ -36,13 +42,16 @@ def __init__(self, ref_string: str): self._api = None @property - def api(self) -> Type['API']: + def api(self) -> Type["API"]: if self._api: return self._api from utilmeta.utils import import_obj + api = import_obj(self.ref) if not issubclass(api, API): - raise TypeError(f'Invalid ref: {repr(self.ref)}, should be an API class, got {api}') + raise TypeError( + f"Invalid ref: {repr(self.ref)}, should be an API class, got {api}" + ) self._api = api return api @@ -74,19 +83,19 @@ class API(PluginTarget): def _parse_bases(cls): base_routes = [] error_hooks = {} - properties = {} # take all base's properties as well + properties = {} # take all base's properties as well base_properties = {} annotations = {} - for base in reversed(cls.__bases__): # mro + for base in reversed(cls.__bases__): # mro if issubclass(base, API) and base.__bases__ != (object,): - annotations.update(getattr(base, '_annotations', {})) - base_routes.extend(getattr(base, '_routes', [])) - error_hooks.update(getattr(base, '_default_error_hooks', {})) - properties.update(getattr(base, '_properties', {})) + annotations.update(getattr(base, "_annotations", {})) + base_routes.extend(getattr(base, "_routes", [])) + error_hooks.update(getattr(base, "_default_error_hooks", {})) + properties.update(getattr(base, "_properties", {})) else: # if base have no annotations, there won't be any __annotations__ attribute - annotations.update(getattr(base, '__annotations__', {})) + annotations.update(getattr(base, "__annotations__", {})) # other common class mixin with no bases, check the properties for key, val in base.__dict__.items(): if inspect.isclass(val) and issubclass(val, Property): @@ -104,7 +113,7 @@ def _parse_bases(cls): cls._routes = base_routes cls._default_error_hooks = error_hooks - if Response.is_cls(getattr(cls, 'response', None)): + if Response.is_cls(getattr(cls, "response", None)): cls._response_cls = cls.response else: cls._response_cls = None @@ -120,26 +129,34 @@ def _check_unit_name(cls, name: str): if isinstance(base_attr, 
Endpoint): # override base class's unit (or hook) continue - raise AttributeError(f'{cls} function <{name}> is already a baseclass ({base}) attribute or method, ' - f'cannot make it a hook or api function, please change it to another name') + raise AttributeError( + f"{cls} function <{name}> is already a baseclass ({base}) attribute or method, " + f"cannot make it a hook or api function, please change it to another name" + ) def __init_subclass__(cls, **kwargs): - cls.__annotations__ = cls.__dict__.get('__annotations__', {}) + cls.__annotations__ = cls.__dict__.get("__annotations__", {}) cls._parse_bases() cls._generate_routes() cls._validate_routes() - cls._request_cls: Type[Request] = cls._annotations.get('request') or Request + cls._request_cls: Type[Request] = cls._annotations.get("request") or Request if not issubclass(cls._request_cls, Request): - raise TypeError(f'Invalid request class: {cls._request_cls}, must be subclass of Request') - req = getattr(cls, 'request', None) + raise TypeError( + f"Invalid request class: {cls._request_cls}, must be subclass of Request" + ) + req = getattr(cls, "request", None) if req is not None: if not isinstance(req, Request): - raise TypeError(f'Invalid "request" attribute: {req}, {cls} should use other attr names') + raise TypeError( + f'Invalid "request" attribute: {req}, {cls} should use other attr names' + ) cls._request_cls = req.__class__ - resp = getattr(cls, 'response', None) + resp = getattr(cls, "response", None) if resp is not None: if not issubclass(resp, Response): - raise TypeError(f'Invalid "response" attribute: {resp}, {cls} should use other attr names') + raise TypeError( + f'Invalid "response" attribute: {resp}, {cls} should use other attr names' + ) setup_class(cls, **kwargs) super().__init_subclass__(**kwargs) @@ -153,7 +170,7 @@ def _generate_routes(cls): # COLLECT ANNOTATIONS FORM for key, api in cls.__annotations__.items(): - if key.startswith('_'): + if key.startswith("_"): continue val = cls.__dict__.get(key) @@ -162,27 +179,31 @@ def _generate_routes(cls): elif is_annotated(api): # param: Annotated[str, request.QueryParam()] - for m in getattr(api, '__metadata__', []): + for m in getattr(api, "__metadata__", []): if inspect.isclass(m) and issubclass(m, Property): m = m() if isinstance(m, Property): cls._make_property(key, m) break - api = getattr(api, '__origin__', None) + api = getattr(api, "__origin__", None) if inspect.isclass(api) and issubclass(api, API): kwargs = dict(route=key, name=key, parent=cls) if not val: - val = getattr(api, '_generator', None) + val = getattr(api, "_generator", None) if isinstance(val, decorator.APIGenerator): kwargs.update(val.kwargs) elif inspect.isfunction(val): - raise TypeError(f'{cls.__name__}: generate route [{repr(key)}] failed: conflict api and endpoint') + raise TypeError( + f"{cls.__name__}: generate route [{repr(key)}] failed: conflict api and endpoint" + ) handlers.append(api) try: route = cls._route_cls(api, **kwargs) except Exception as e: - raise e.__class__(f'{cls.__name__}: generate route [{repr(key)}] failed with error: {e}') from e + raise e.__class__( + f"{cls.__name__}: generate route [{repr(key)}] failed with error: {e}" + ) from e if route.private: continue # route.initialize(cls) @@ -194,11 +215,11 @@ def _generate_routes(cls): if inspect.isclass(val) and issubclass(val, Property): # eg: logger: Logger if key not in cls.__dict__: - context = getattr(val, '__context__', None) + context = getattr(val, "__context__", None) if context and isinstance(context, Property): 
cls._make_property(key, context) - local_vars = {k: v for k, v in cls.__dict__.items() if not k.startswith('_')} + local_vars = {k: v for k, v in cls.__dict__.items() if not k.startswith("_")} for key, val in cls.__dict__.items(): if val in handlers: @@ -209,14 +230,16 @@ def _generate_routes(cls): if inspect.isclass(val) and issubclass(val, API): kwargs = dict(route=key, name=key, parent=cls) - generator = getattr(val, '_generator', None) + generator = getattr(val, "_generator", None) if isinstance(generator, decorator.APIGenerator): kwargs.update(generator.kwargs) handlers.append(val) try: route = cls._route_cls(val, **kwargs) except Exception as e: - raise e.__class__(f'{cls.__name__}: generate route [{repr(key)}] failed with error: {e}') from e + raise e.__class__( + f"{cls.__name__}: generate route [{repr(key)}] failed with error: {e}" + ) from e if route.private: continue # route.initialize(cls) @@ -238,13 +261,13 @@ def _generate_routes(cls): # 4. @api(method='CUSTOM') (method='custom') try: val = cls._endpoint_cls.apply_for( - val, cls, - name=key, - local_vars=local_vars + val, cls, name=key, local_vars=local_vars ) except Exception as e: - raise e.__class__(f'{cls.__name__}: ' - f'generate endpoint [{repr(key)}] failed with error: {e}') from e + raise e.__class__( + f"{cls.__name__}: " + f"generate endpoint [{repr(key)}] failed with error: {e}" + ) from e elif hook_type: val = cls._hook_cls.dispatch_for(val, hook_type) else: @@ -262,15 +285,17 @@ def _generate_routes(cls): val, name=key, route=val.route, - summary=val.getattr('summary'), - tags=val.getattr('tags'), - description=val.getattr('description'), - deprecated=val.getattr('deprecated'), - private=val.getattr('private'), - priority=val.getattr('priority') + summary=val.getattr("summary"), + tags=val.getattr("tags"), + description=val.getattr("description"), + deprecated=val.getattr("deprecated"), + private=val.getattr("private"), + priority=val.getattr("priority"), ) except Exception as e: - raise e.__class__(f'{cls.__name__}: generate route [{repr(key)}] failed with error: {e}') from e + raise e.__class__( + f"{cls.__name__}: generate route [{repr(key)}] failed with error: {e}" + ) from e routes.append(route) continue @@ -295,7 +320,7 @@ def _generate_routes(cls): # if not any() hooked, the target expression of the hook maybe invalid # we will give it a warning if not hook.hook_all: - msg = f'{cls}: unmatched hook: {hook} with targets: {hook.hook_targets}' + msg = f"{cls}: unmatched hook: {hook} with targets: {hook.hook_targets}" warnings.warn(msg) # from utilmeta.conf import config # if config.preference.ignore_unmatched_hooks: @@ -316,32 +341,35 @@ def _get_route_pattern(cls): for route in cls._routes: patterns.extend(route.get_patterns()) if not patterns: - return '' - return '^(%s)$' % '|'.join(patterns) + return "" + return "^(%s)$" % "|".join(patterns) @classonlymethod def _global_vars(cls): import sys + return sys.modules[cls.__module__].__dict__ @classonlymethod def _make_property(cls, name: str, prop: Property): - _in = getattr(prop.__in__, '__ident__', None) - if prop.__ident__ == 'body' or _in == 'body': - raise ValueError(f'{cls.__name__}: API class cannot define ' - f'Body or BodyParam common params: [{repr(name)}]') + _in = getattr(prop.__in__, "__ident__", None) + if prop.__ident__ == "body" or _in == "body": + raise ValueError( + f"{cls.__name__}: API class cannot define " + f"Body or BodyParam common params: [{repr(name)}]" + ) field = cls._parser_field_cls.generate( attname=name, default=prop, 
annotation=cls._annotations.get(name), options=cls.__options__, - global_vars=cls._global_vars() + global_vars=cls._global_vars(), ) inst = prop.init(field) - def getter(self: 'API'): + def getter(self: "API"): if name in self.__dict__: return self.__dict__[name] value = inst.get(self.request) @@ -350,23 +378,25 @@ def getter(self: 'API'): if not unprovided(default): # NOT CACHE return default - raise exc.BadRequest(f'{cls.__name__}: ' - f'{prop.__class__.__name__}({repr(field.name)}) not provided') + raise exc.BadRequest( + f"{cls.__name__}: " + f"{prop.__class__.__name__}({repr(field.name)}) not provided" + ) try: value = field.parse_value( - value, - context=self.__options__.make_context(cls) + value, context=self.__options__.make_context(cls) ) except ParseError as e: raise exc.BadRequest(str(e), detail=e.get_detail()) from e - self.__dict__[name] = value # auto-cached + self.__dict__[name] = value # auto-cached return value getter.__field__ = prop setter = None if prop.setter != Property.setter: - def setter(self: 'API', value): + + def setter(self: "API", value): inst.set(self.request, value) self.__dict__[name] = value # auto-cached @@ -380,8 +410,10 @@ def _validate_routes(cls): for api_route in cls._routes: if api_route.ident in route_idents: - raise ValueError(f'{cls}: api {api_route.handler} conflict with ' - f'{route_idents[api_route.ident]} on identity: {repr(api_route.ident)}') + raise ValueError( + f"{cls}: api {api_route.handler} conflict with " + f"{route_idents[api_route.ident]} on identity: {repr(api_route.ident)}" + ) route_idents[api_route.ident] = api_route.handler if not api_route.method: api_routes[api_route.route] = api_route.handler @@ -389,14 +421,16 @@ def _validate_routes(cls): for api_route in cls._routes: if api_route.method: if api_route.route in api_routes: - raise ValueError(f'{cls}: api function: {api_route.handler} ' - f'route: {repr(api_route.route)} conflict ' - f'with api class: {api_routes[api_route.route]}') + raise ValueError( + f"{cls}: api function: {api_route.handler} " + f"route: {repr(api_route.route)} conflict " + f"with api class: {api_routes[api_route.route]}" + ) # TODO: test if any static route is override by a higher priority dynamic route @classonlymethod def __reproduce_with__(cls, generator: decorator.APIGenerator): - plugins = generator.kwargs.get('plugins') + plugins = generator.kwargs.get("plugins") if plugins: cls._add_plugins(*plugins) cls._generator = generator @@ -407,14 +441,16 @@ def __reproduce_with__(cls, generator: decorator.APIGenerator): def __as__(cls, backend, route: str, *, asynchronous: bool = None): from utilmeta import UtilMeta from utilmeta.core.server.backends.base import ServerAdaptor + if isinstance(backend, UtilMeta): service = backend else: try: from utilmeta import service except ImportError: - service = UtilMeta(None, backend=backend, - name=route.strip('/').replace('/', '_')) + service = UtilMeta( + None, backend=backend, name=route.strip("/").replace("/", "_") + ) service._auto_created = True service.mount_to_api(cls, route=route) # backend can be a module name or application @@ -422,15 +458,19 @@ def __as__(cls, backend, route: str, *, asynchronous: bool = None): return adaptor.adapt(cls, route=route, asynchronous=asynchronous) @classonlymethod - def __mount__(cls, handler: Union[APIRoute, Type['API'], APIRef, Endpoint, str], route: str = '', - before_hooks: List[BeforeHook] = (), - after_hooks: List[AfterHook] = (), - error_hooks: Dict[Type[Exception], ErrorHook] = None, - ): + def __mount__( + cls, + 
handler: Union[APIRoute, Type["API"], APIRef, Endpoint, str], + route: str = "", + before_hooks: List[BeforeHook] = (), + after_hooks: List[AfterHook] = (), + error_hooks: Dict[Type[Exception], ErrorHook] = None, + ): if isinstance(handler, APIRef): handler = handler.api if isinstance(handler, str): from utilmeta.utils import import_obj + handler = import_obj(handler) if isinstance(handler, APIRoute): cls._routes.append(handler) @@ -438,19 +478,19 @@ def __mount__(cls, handler: Union[APIRoute, Type['API'], APIRef, Endpoint, str], api_route = cls._route_cls( handler=handler, route=route, - name=route.replace('/', '_'), + name=route.replace("/", "_"), before_hooks=before_hooks, after_hooks=after_hooks, - error_hooks=error_hooks + error_hooks=error_hooks, ) api_route.compile_route() cls._routes.append(api_route) - cls._validate_routes() # validate each time there is a new api mount + cls._validate_routes() # validate each time there is a new api mount def __init__(self, request): super().__init__() self.request = self._request_cls.apply_for(request) - self.response = getattr(self, 'response', Response) + self.response = getattr(self, "response", Response) # set request before setup instance, because this hook may depend on the request context for key, val in self.__class__.__dict__.items(): if isinstance(val, Endpoint): @@ -471,10 +511,7 @@ def _init_properties(self): value = prop.get(self.request) if not unprovided(value): try: - value = prop.field.parse_value( - value, - context=context - ) + value = prop.field.parse_value(value, context=context) except ParseError as e: if self.request.is_options: # ignore parse error for OPTIONS request @@ -507,17 +544,18 @@ def _resolve(self) -> APIRoute: for route in method_routes.values(): distinct_add(headers, route.header_names) allow_headers.set(headers) - route_var.set('') + route_var.set("") if self.request.method not in method_routes: raise exc.MethodNotAllowed( - method=self.request.method, - allows=allow_methods.get() + method=self.request.method, allows=allow_methods.get() ) return method_routes[self.request.method] raise exc.NotFound(path=self.request.path) def _handle_error(self, error: Error): - hook = error.get_hook(self._error_hooks, exact=isinstance(error.exception, exc.Redirect)) + hook = error.get_hook( + self._error_hooks, exact=isinstance(error.exception, exc.Redirect) + ) # hook is applied before the handle_error plugin event if hook: result = hook(self, error) @@ -527,7 +565,9 @@ def _handle_error(self, error: Error): raise error.throw() async def _async_handle_error(self, error: Error): - hook = error.get_hook(self._error_hooks, exact=isinstance(error.exception, exc.Redirect)) + hook = error.get_hook( + self._error_hooks, exact=isinstance(error.exception, exc.Redirect) + ) # hook is applied before the handle_error plugin event if hook: result = hook(self, error) @@ -548,19 +588,18 @@ def _make_response(self, response, force: bool = False): for i, resp_type in enumerate(self._response_types): try: return resp_type( - response, - request=request, - strict=pref.api_default_strict_response + response, request=request, strict=pref.api_default_strict_response ) except Exception as e: - if i == len(self._response_types) - 1 and pref.api_default_strict_response: + if ( + i == len(self._response_types) - 1 + and pref.api_default_strict_response + ): raise e from e continue if self._response_cls: return self._response_cls( - response, - request=request, - strict=pref.api_default_strict_response + response, request=request,
strict=pref.api_default_strict_response ) if force: return Response(response, request=request) @@ -579,7 +618,9 @@ async def __async_handler__(self): return await route.aserve(self) def __call__(self) -> Union[Response, Any]: - handler = self._chain_cls(self).build_api_handler(self.__class__.__handler__, asynchronous=False) + handler = self._chain_cls(self).build_api_handler( + self.__class__.__handler__, asynchronous=False + ) try: resp = handler(self) except Exception as e: @@ -588,7 +629,9 @@ def __call__(self) -> Union[Response, Any]: @awaitable(__call__) async def __call__(self) -> Union[Response, Any]: - handler = self._chain_cls(self).build_api_handler(self.__class__.__async_handler__, asynchronous=True) + handler = self._chain_cls(self).build_api_handler( + self.__class__.__async_handler__, asynchronous=True + ) try: resp = await handler(self) except Exception as e: @@ -596,14 +639,20 @@ async def __call__(self) -> Union[Response, Any]: return self._make_response(resp) def options(self): - return Response(headers={ - Header.ALLOW: ','.join(set([m.upper() for m in var.allow_methods.getter(self.request)])), - Header.LENGTH: '0' - }) + return Response( + headers={ + Header.ALLOW: ",".join( + set([m.upper() for m in var.allow_methods.getter(self.request)]) + ), + Header.LENGTH: "0", + } + ) def __serve__(self, unit): if isinstance(unit, Endpoint): - handler = self._chain_cls(self, unit).build_api_handler(unit.handler, asynchronous=False) + handler = self._chain_cls(self, unit).build_api_handler( + unit.handler, asynchronous=False + ) var.endpoint_ref.setter(self.request, unit.ref) self._response_types = unit.response_types return handler(self) @@ -612,7 +661,9 @@ def __serve__(self, unit): async def __aserve__(self, unit): if isinstance(unit, Endpoint): - handler = self._chain_cls(self, unit).build_api_handler(unit.async_handler, asynchronous=True) + handler = self._chain_cls(self, unit).build_api_handler( + unit.async_handler, asynchronous=True + ) var.endpoint_ref.setter(self.request, unit.ref) self._response_types = unit.response_types return await handler(self) diff --git a/utilmeta/core/api/chain.py b/utilmeta/core/api/chain.py index ffaf547..2469aef 100644 --- a/utilmeta/core/api/chain.py +++ b/utilmeta/core/api/chain.py @@ -13,11 +13,13 @@ def __init__(self, *targets: PluginTarget): self.targets = targets self.pref = Preference.get() - def chain_plugins(self, *events: PluginEvent, - required: bool = False, - reverse: bool = False, - asynchronous: bool = None - ) -> Tuple[Callable, ...]: + def chain_plugins( + self, + *events: PluginEvent, + required: bool = False, + reverse: bool = False, + asynchronous: bool = None, + ) -> Tuple[Callable, ...]: targets = self.targets _classes = set() for target in reversed(targets) if reverse else targets: @@ -29,12 +31,17 @@ def chain_plugins(self, *events: PluginEvent, if not plugins or not isinstance(plugins, dict): continue - for plugin_cls, plugin in reversed(plugins.items()) if reverse else plugins.items(): + for plugin_cls, plugin in ( + reversed(plugins.items()) if reverse else plugins.items() + ): if plugin_cls in _classes: # in case for more than 1 plugin target continue - handlers = [event.get(plugin, target=target, asynchronous=asynchronous) for event in events] + handlers = [ + event.get(plugin, target=target, asynchronous=asynchronous) + for event in events + ] if not any(handlers): continue @@ -75,8 +82,9 @@ async def async_process(cls, obj, handler: Callable): class APIChainBuilder(BaseChainBuilder): def __init__(self, api, 
endpoint: Endpoint = None): from utilmeta.core.api import API + if not isinstance(api, API): - raise TypeError(f'Invalid API: {api}') + raise TypeError(f"Invalid API: {api}") super().__init__(endpoint or api) self.api = api self.endpoint = endpoint @@ -99,8 +107,7 @@ async def async_api_handler( while True: try: api.request.adaptor.update_context( - retry_index=retry_index, - idempotent=self.idempotent + retry_index=retry_index, idempotent=self.idempotent ) req = api.request if request_handler: @@ -114,8 +121,7 @@ async def async_api_handler( response = req if response_handler: res = await self.async_process( - api._make_response(response, force=True), - response_handler + api._make_response(response, force=True), response_handler ) else: # successfully get response without response handler @@ -135,7 +141,9 @@ async def async_api_handler( break retry_index += 1 if retry_index >= self.pref.api_max_retry_loops: - raise exceptions.MaxRetriesExceed(max_retries=self.pref.api_max_retry_loops) + raise exceptions.MaxRetriesExceed( + max_retries=self.pref.api_max_retry_loops + ) return res def api_handler( @@ -150,8 +158,7 @@ def api_handler( while True: try: api.request.adaptor.update_context( - retry_index=retry_index, - idempotent=self.idempotent + retry_index=retry_index, idempotent=self.idempotent ) req = api.request if request_handler: @@ -163,8 +170,7 @@ def api_handler( response = req if response_handler: res = self.process( - api._make_response(response, force=True), - response_handler + api._make_response(response, force=True), response_handler ) else: # successfully get response without response handler @@ -184,7 +190,9 @@ def api_handler( break retry_index += 1 if retry_index >= self.pref.api_max_retry_loops: - raise exceptions.MaxRetriesExceed(max_retries=self.pref.api_max_retry_loops) + raise exceptions.MaxRetriesExceed( + max_retries=self.pref.api_max_retry_loops + ) return res def chain_api_handler( @@ -193,38 +201,49 @@ def chain_api_handler( request_handler=None, response_handler=None, error_handler=None, - asynchronous: bool = None + asynchronous: bool = None, ): if not any([request_handler, response_handler, error_handler]): return handler from utilmeta.core.api import API + if asynchronous: + @wraps(handler) async def wrapper(api: API = self.api): return await self.async_api_handler( - api, handler, + api, + handler, request_handler=request_handler, response_handler=response_handler, - error_handler=error_handler + error_handler=error_handler, ) + else: + @wraps(handler) def wrapper(api: API = self.api): return self.api_handler( - api, handler, + api, + handler, request_handler=request_handler, response_handler=response_handler, - error_handler=error_handler + error_handler=error_handler, ) + return wrapper def build_api_handler(self, handler, asynchronous: bool = None): # --- if asynchronous is None: - asynchronous = inspect.iscoroutinefunction(handler) or inspect.isasyncgenfunction(handler) + asynchronous = inspect.iscoroutinefunction( + handler + ) or inspect.isasyncgenfunction(handler) for request_handler, response_handler, error_handler in self.chain_plugins( - process_request, process_response, handle_error, + process_request, + process_response, + handle_error, required=False, asynchronous=asynchronous, ): @@ -233,6 +252,6 @@ def build_api_handler(self, handler, asynchronous: bool = None): request_handler=request_handler, response_handler=response_handler, error_handler=error_handler, - asynchronous=asynchronous + asynchronous=asynchronous, ) return handler
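For orientation before the next file: build_api_handler above wraps an endpoint handler once per plugin that defines any of the process_request / process_response / handle_error events, and loops when a handler yields a Request again (bounded by Preference.api_max_retry_loops). A minimal sketch of a plugin that participates in this chain (illustrative only, not part of this commit; it assumes the APIPlugin base from utilmeta/core/api/plugins/base.py further below, and the header name is made up):

    from utilmeta.core.api.plugins.base import APIPlugin
    from utilmeta.core.request import Request
    from utilmeta.core.response import Response

    class ServerHeaderPlugin(APIPlugin):
        def process_request(self, request: Request):
            # returning the Request lets the chain proceed to the endpoint;
            # a plugin may instead return a response-like value to
            # short-circuit (handled by api_handler above)
            return request

        def process_response(self, response: Response):
            # returning a Request from here re-enters the retry loop
            # (see retry.py below); returning the Response ends the chain
            response.set_header("X-Served-By", "utilmeta")  # illustrative header
            return response

chain_plugins collects such plugins from the chain's target (the endpoint when present, otherwise the API instance) and deduplicates them by class, so the same plugin only runs once per request.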
diff --git a/utilmeta/core/api/decorator.py b/utilmeta/core/api/decorator.py index 60b7f46..49e8c93 100644 --- a/utilmeta/core/api/decorator.py +++ b/utilmeta/core/api/decorator.py @@ -4,17 +4,17 @@ import warnings -T = TypeVar('T') +T = TypeVar("T") __all__ = [ # bare route: only for API/Module # 'APIDecoratorWrapper', # 'APIGenerator', # hooks - 'before', - 'after', - 'handle', - 'plugin', + "before", + "after", + "handle", + "plugin", ] @@ -22,31 +22,39 @@ def set_hook(f, hook_type: str, value, priority: int = None): if not f: return f if not inspect.isfunction(f): - raise TypeError(f'Invalid hook: {f}, must be a function') + raise TypeError(f"Invalid hook: {f}, must be a function") - if f.__name__.startswith('_'): - raise ValueError(f'{hook_type} hook func: <{f.__name__}> is startswith "_", which will not be' - f' recognized as a api hook') + if f.__name__.startswith("_"): + raise ValueError( + f'{hook_type} hook func: <{f.__name__}> starts with "_", which will not be' + f" recognized as an api hook" + ) if f.__name__.lower() in CommonMethod.gen(): - raise ValueError(f'{hook_type} hook func: <{f.__name__}> name is a HTTP method, which will not be' - f' recognized as a api hook') + raise ValueError( + f"{hook_type} hook func: <{f.__name__}> name is an HTTP method, which will not be" + f" recognized as an api hook" + ) if hasattr(f, EndpointAttr.method): - raise ValueError(f'{hook_type} hook func: {f} has HTTP method set, which means it is a api endpoint') + raise ValueError( + f"{hook_type} hook func: {f} has HTTP method set, which means it is an api endpoint" + ) assert hook_type in HOOK_TYPES for t in HOOK_TYPES: if hasattr(f, t): if t != hook_type: - raise AttributeError(f'Function: {f.__name__} is already ' - f'hook for <{t}>, cannot hook for {hook_type}') + raise AttributeError( + f"Function: {f.__name__} is already " + f"a hook for <{t}>, cannot hook for {hook_type}" + ) else: return - setattr(f, EndpointAttr.hook, hook_type) # indicate the hook type - setattr(f, hook_type, value) # indicate the hook params + setattr(f, EndpointAttr.hook, hook_type) # indicate the hook type + setattr(f, hook_type, value) # indicate the hook params if priority: - setattr(f, 'priority', priority) + setattr(f, "priority", priority) return f @@ -90,14 +98,18 @@ def decorator(self, func, generator: APIGenerator = None): kwargs = generator.kwargs if generator else {} if isinstance(func, (staticmethod, classmethod)): - raise TypeError(f'@api can only decorate instance method or API class, got {func}') + raise TypeError( + f"@api can only decorate an instance method or an API class, got {func}" + ) if inspect.isclass(func): if self.method: - raise ValueError(f'@api.{self.method} cannot decorate an API class: {func}') + raise ValueError( + f"@api.{self.method} cannot decorate an API class: {func}" + ) - rep_func = getattr(func, '__reproduce_with__', None) + rep_func = getattr(func, "__reproduce_with__", None) if not rep_func: - raise ValueError(f'@api decorated an unsupported class: {func}') + raise ValueError(f"@api decorated an unsupported class: {func}") return rep_func(generator) name = func.__name__.lower() @@ -105,14 +117,16 @@ def decorator(self, func, generator: APIGenerator = None): if self.method: if self.method != name: raise ValueError( - f'HTTP Method: {self.method} ' - f'must not decorate another Http Method named function: {name}' + f"HTTP Method: {self.method} " + f"must not decorate a function named after another HTTP method: {name}" ) self.method = name - if name.startswith('_'): - raise ValueError(f'Endpoint func:
<{func.__name__}> is startswith "_", which will not be' - f' recognized as a api function') + raise ValueError( + f'Endpoint func: <{func.__name__}> starts with "_", which will not be' + f" recognized as an api function" + ) # if kwargs.get('route') == func.__name__: # warnings.warn(f'Endpoint alias is same as function name: {func.__name__},' @@ -124,27 +138,31 @@ def decorator(self, func, generator: APIGenerator = None): return func - def __call__(self, *fn_or_routes, - cls=None, - summary: str = None, - # alias: Union[str, List[str]] = None, - deprecated: bool = None, - idempotent: bool = None, - private: bool = None, - priority: int = None, - eager: bool = None, - tags: List[Union[dict, str]] = None, - description: str = None, - extension: dict = None, - **kwargs, - ): + def __call__( + self, + *fn_or_routes, + cls=None, + summary: str = None, + # alias: Union[str, List[str]] = None, + deprecated: bool = None, + idempotent: bool = None, + private: bool = None, + priority: int = None, + eager: bool = None, + tags: List[Union[dict, str]] = None, + description: str = None, + extension: dict = None, + **kwargs, + ): if len(fn_or_routes) == 1: f = fn_or_routes[0] - if inspect.isfunction(f): # no parameter @api.get will wrap a callable as this param - if getattr(f, 'method', None): + if inspect.isfunction( + f + ): # no parameter @api.get will wrap a callable as this param + if getattr(f, "method", None): # already decorated - route = getattr(f, 'route', f.__name__) + route = getattr(f, "route", f.__name__) else: return self.decorator(f) elif isinstance(f, str): @@ -160,7 +178,7 @@ def __call__(self, *fn_or_routes, del f del fn_or_routes for key, val in locals().items(): - if key == 'self' or key == 'kwargs': + if key == "self" or key == "kwargs": continue if val is None: continue @@ -178,19 +196,24 @@ def before(*units, excludes=None, priority: int = None): # units = ['*'] # return wrapper(func) if not units: - units = ['*'] - if '*' in units: - assert len(units) == 1, f'@api.before("*") means hook all units, remove the redundant units' + units = ["*"] + if "*" in units: + assert ( + len(units) == 1 + ), f'@api.before("*") means hook all units, remove the redundant units' elif excludes: - raise ValueError('@api.before excludes only affect when @api.before("*") is applied') + raise ValueError( + '@api.before excludes only take effect when @api.before("*") is applied' + ) def wrapper(f: T) -> T: return set_hook( set_excludes(f, excludes), EndpointAttr.before_hook, units, - priority=priority + priority=priority, ) + return wrapper @@ -200,19 +223,21 @@ def after(*units, excludes=None, priority: int = None): # units = ['*'] # return wrapper(func) if not units: - units = ['*'] - if '*' in units: - assert len(units) == 1, f'@api.after("*") means hook all units, remove the redundant units' + units = ["*"] + if "*" in units: + assert ( + len(units) == 1 + ), f'@api.after("*") means hook all units, remove the redundant units' elif excludes: - raise ValueError('@api.after excludes only affect when @api.after("*") is applied') + raise ValueError( + '@api.after excludes only take effect when @api.after("*") is applied' + ) def wrapper(f: T) -> T: return set_hook( - set_excludes(f, excludes), - EndpointAttr.after_hook, - units, - priority=priority + set_excludes(f, excludes), EndpointAttr.after_hook, units, priority=priority ) + return wrapper @@ -226,20 +251,20 @@ def handle(*unit_and_errors, excludes=None, priority: int = None): units.append(e) if not units: - units =
["*"] if not errors: errors = (Exception,) - if '*' in units: - assert len(units) == 1, f'@api.accept("*") means hook all units, remove the redundant units' + if "*" in units: + assert ( + len(units) == 1 + ), f'@api.accept("*") means hook all units, remove the redundant units' def wrapper(f: T) -> T: setattr(f, EndpointAttr.errors, errors) return set_hook( - set_excludes(f, excludes), - EndpointAttr.error_hook, - units, - priority=priority + set_excludes(f, excludes), EndpointAttr.error_hook, units, priority=priority ) + return wrapper @@ -248,4 +273,5 @@ def wrapper(func): for plg in plugins: plg(func) return func + return wrapper diff --git a/utilmeta/core/api/endpoint.py b/utilmeta/core/api/endpoint.py index 917675a..361803f 100644 --- a/utilmeta/core/api/endpoint.py +++ b/utilmeta/core/api/endpoint.py @@ -18,9 +18,9 @@ if TYPE_CHECKING: from .base import API -process_request = PluginEvent('process_request', streamline_result=True) -handle_error = PluginEvent('handle_error') -process_response = PluginEvent('process_response', streamline_result=True) +process_request = PluginEvent("process_request", streamline_result=True) +handle_error = PluginEvent("handle_error") +process_response = PluginEvent("process_response", streamline_result=True) class RequestContextWrapper(ContextWrapper): @@ -33,18 +33,20 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # used in generate allow headers - def init_prop(self, prop: Property, val: ParserField): # noqa, to be inherit - if prop.__ident__ == 'header': + def init_prop(self, prop: Property, val: ParserField): # noqa, to be inherit + if prop.__ident__ == "header": for origin in val.input_origins: - parser = getattr(origin, '__parser__', None) + parser = getattr(origin, "__parser__", None) if isinstance(parser, BaseParser): - utils.distinct_add(self.header_names, [str(v).lower() for v in parser.fields]) - elif prop.__in__ and getattr(prop.__in__, '__ident__', None) == 'header': + utils.distinct_add( + self.header_names, [str(v).lower() for v in parser.fields] + ) + elif prop.__in__ and getattr(prop.__in__, "__ident__", None) == "header": name = val.name.lower() if name not in self.header_names: self.header_names += name else: - headers = getattr(prop, 'headers', None) + headers = getattr(prop, "headers", None) if headers and utils.multi(headers): utils.distinct_add(self.header_names, [str(v).lower() for v in headers]) return prop.init(val) @@ -54,20 +56,29 @@ def contains_file(cls, field: ParserField): def file_like(file_cls): from utilmeta.core.file import File from io import BytesIO + if inspect.isclass(file_cls): if issubclass(file_cls, (File, BytesIO, bytearray)): return True return False - + from utype import Rule + if isinstance(field.type, type) and issubclass(field.type, Rule): # try to find List[schema] - if isinstance(field.type.__origin__, LogicalType) and field.type.__origin__.combinator: + if ( + isinstance(field.type.__origin__, LogicalType) + and field.type.__origin__.combinator + ): for arg in field.type.__origin__.args: if file_like(arg): return True else: - if field.type.__origin__ and issubclass(field.type.__origin__, list) and field.type.__args__: + if ( + field.type.__origin__ + and issubclass(field.type.__origin__, list) + and field.type.__args__ + ): # we only accept list, not tuple/set arg = field.type.__args__[0] if file_like(arg): @@ -83,18 +94,26 @@ def file_like(file_cls): def validate_method(self, method: str): # 1. only PUT / PATCH / POST allow to have body params # 2. 
File params not allowed in - allow_body = method.lower() in ['post', 'put', 'patch'] + allow_body = method.lower() in ["post", "put", "patch"] for key, val in self.properties.items(): field = val.field prop = val.prop - if prop.__ident__ == 'body' or prop.__in__ and getattr(prop.__in__, '__ident__', None) == 'body': + if ( + prop.__ident__ == "body" + or prop.__in__ + and getattr(prop.__in__, "__ident__", None) == "body" + ): if not allow_body: - raise InvalidDeclaration(f'body param: {repr(key)} not supported in method: {repr(method)}') + raise InvalidDeclaration( + f"body param: {repr(key)} not supported in method: {repr(method)}" + ) else: if self.contains_file(field): - raise InvalidDeclaration(f'request param: {repr(key)} used file as type declaration,' - f' which is not supported, file must be declared in Body or BodyParam') + raise InvalidDeclaration( + f"request param: {repr(key)} used file as type declaration," + f" which is not supported, file must be declared in Body or BodyParam" + ) class BaseEndpoint(PluginTarget): @@ -108,21 +127,24 @@ class BaseEndpoint(PluginTarget): # result will be parsed in the end of endpoint.serve # STRICT_RESULT = False - def __init__(self, f: Callable, *, - method: str, - name: str = None, - # the attribute name of API/SDK class - # instead of function name (maybe affected by not-@wrap decorator function) - plugins: list = None, - idempotent: bool = None, - eager: bool = False, - local_vars: dict = None, - ): + def __init__( + self, + f: Callable, + *, + method: str, + name: str = None, + # the attribute name of API/SDK class + # instead of function name (maybe affected by not-@wrap decorator function) + plugins: list = None, + idempotent: bool = None, + eager: bool = False, + local_vars: dict = None, + ): super().__init__(plugins=plugins) if not inspect.isfunction(f): - raise TypeError(f'Invalid endpoint function: {f}') + raise TypeError(f"Invalid endpoint function: {f}") self.f = f self.method = method @@ -137,13 +159,12 @@ def __init__(self, f: Callable, *, parser.resolve_forward_refs(local_vars=local_vars) # resolve even if local_vars is None self.wrapper = self.wrapper_cls( - parser, - default_properties=self.default_wrapper_properties + parser, default_properties=self.default_wrapper_properties ) self.executor = self.parser.wrap( eager_parse=self.eager, parse_params=self.PARSE_PARAMS, - parse_result=self.PARSE_RESULT + parse_result=self.PARSE_RESULT, ) self.sync_wrapper = None @@ -152,28 +173,28 @@ def __init__(self, f: Callable, *, self.async_executor = None # -- adapt @awaitable - if getattr(self.f, '_awaitable', False): - sync_func = getattr(self.f, '_syncfunc', None) - async_func = getattr(self.f, '_asyncfunc', None) + if getattr(self.f, "_awaitable", False): + sync_func = getattr(self.f, "_syncfunc", None) + async_func = getattr(self.f, "_asyncfunc", None) if sync_func: self.sync_wrapper = self.wrapper_cls( self.parser_cls.apply_for(sync_func), - default_properties=self.default_wrapper_properties + default_properties=self.default_wrapper_properties, ) self.sync_executor = self.sync_wrapper.parser.wrap( eager_parse=self.eager, parse_params=self.PARSE_PARAMS, - parse_result=self.PARSE_RESULT + parse_result=self.PARSE_RESULT, ) if async_func: self.async_wrapper = self.wrapper_cls( self.parser_cls.apply_for(async_func), - default_properties=self.default_wrapper_properties + default_properties=self.default_wrapper_properties, ) self.async_executor = self.async_wrapper.parser.wrap( eager_parse=self.eager, parse_params=self.PARSE_PARAMS, - 
parse_result=self.PARSE_RESULT + parse_result=self.PARSE_RESULT, ) self.response_types: List[Type[Response]] = parse_responses(self.return_type) @@ -216,9 +237,7 @@ def get_executor(self, asynchronous: bool = False): @property def default_wrapper_properties(self): - return { - key: PathParam for key in self.path_names - } + return {key: PathParam for key in self.path_names} def iter_plugins(self): for cls, plugin in self._plugins.items(): @@ -229,7 +248,7 @@ def getattr(self, name: str, default=None): @property def module_name(self): - return getattr(self.f, '__module__', None) + return getattr(self.f, "__module__", None) @property def is_method(self): @@ -241,7 +260,7 @@ def is_passed(self): @property def route(self): - return self.getattr('route', '' if self.is_method else self.name) + return self.getattr("route", "" if self.is_method else self.name) @property def parser(self): @@ -250,11 +269,14 @@ def parser(self): class Endpoint(BaseEndpoint): @classmethod - def apply_for(cls, func: Callable, api: Type['API'] = None, - name: str = None, - local_vars: dict = None, - ): - _cls = getattr(func, 'cls', None) + def apply_for( + cls, + func: Callable, + api: Type["API"] = None, + name: str = None, + local_vars: dict = None, + ): + _cls = getattr(func, "cls", None) if not _cls or not issubclass(_cls, Endpoint): # override current class _cls = cls @@ -274,21 +296,24 @@ def apply_for(cls, func: Callable, api: Type['API'] = None, kwargs.update(local_vars=local_vars) return _cls(func, **kwargs) - def __init__(self, f: Callable, *, - api: Type['API'] = None, - method: str, - name: str = None, - plugins: list = None, - idempotent: bool = None, - eager: bool = False, - # openapi specs: - operation_id: str = None, - tags: list = None, - summary: str = None, - description: str = None, - local_vars: dict = None, - extension: dict = None, - ): + def __init__( + self, + f: Callable, + *, + api: Type["API"] = None, + method: str, + name: str = None, + plugins: list = None, + idempotent: bool = None, + eager: bool = False, + # openapi specs: + operation_id: str = None, + tags: list = None, + summary: str = None, + description: str = None, + local_vars: dict = None, + extension: dict = None, + ): super().__init__( f, @@ -297,10 +322,10 @@ def __init__(self, f: Callable, *, name=name, idempotent=idempotent, eager=eager, - local_vars=local_vars + local_vars=local_vars, ) self.api = api - self.response_type = getattr(api, 'response', None) + self.response_type = getattr(api, "response", None) if self.response_type and not Response.is_cls(self.response_type): self.response_type = None self.operation_id = operation_id @@ -313,7 +338,7 @@ def __call__(self, *args, **kwargs): executor = self.get_executor(False) r = executor(*args, **kwargs) if inspect.isawaitable(r): - raise exc.ServerError('awaitable detected in sync function') + raise exc.ServerError("awaitable detected in sync function") return r @utils.awaitable(__call__) @@ -331,25 +356,25 @@ def openapi_extension(self): if not self.extension or not isinstance(self.extension, dict): return ext for key, val in self.extension.items(): - key = str(key).lower().replace('_', '-') - if not key.startswith('x-'): - key = f'x-{key}' + key = str(key).lower().replace("_", "-") + if not key.startswith("x-"): + key = f"x-{key}" ext[key] = val return ext @property def ref(self) -> str: if self.api: - return f'{self.api.__ref__}.{self.name}' + return f"{self.api.__ref__}.{self.name}" if self.module_name: - return f'{self.module_name}.{self.name}' + return 
f"{self.module_name}.{self.name}" return self.name - def handler(self, api: 'API'): + def handler(self, api: "API"): args, kwargs = self.parse_request(api.request) return self(api, *args, **kwargs) - async def async_handler(self, api: 'API'): + async def async_handler(self, api: "API"): args, kwargs = await self.async_parse_request(api.request) return await self(api, *args, **kwargs) @@ -358,7 +383,9 @@ def parse_request(self, request: Request): kwargs = dict(var.path_params.getter(request)) wrapper = self.get_wrapper(False) kwargs.update(wrapper.parse_context(request)) - return wrapper.parser.parse_params((), kwargs, context=wrapper.parser.options.make_context()) + return wrapper.parser.parse_params( + (), kwargs, context=wrapper.parser.options.make_context() + ) except utype.exc.ParseError as e: raise exc.BadRequest(str(e), detail=e.get_detail()) from e @@ -367,7 +394,9 @@ async def async_parse_request(self, request: Request): kwargs = dict(await var.path_params.getter(request)) wrapper = self.get_wrapper(True) kwargs.update(await wrapper.async_parse_context(request)) - return wrapper.parser.parse_params((), kwargs, context=wrapper.parser.options.make_context()) + return wrapper.parser.parse_params( + (), kwargs, context=wrapper.parser.options.make_context() + ) # in base Endpoint, args is not supported except utype.exc.ParseError as e: raise exc.BadRequest(str(e), detail=e.get_detail()) from e diff --git a/utilmeta/core/api/hook.py b/utilmeta/core/api/hook.py index b0504be..95e3327 100644 --- a/utilmeta/core/api/hook.py +++ b/utilmeta/core/api/hook.py @@ -22,7 +22,9 @@ class Hook: parse_result = False @classmethod - def dispatch_for(cls, func: Callable, hook_type: str, target_type: str = 'api') -> 'Hook': + def dispatch_for( + cls, func: Callable, hook_type: str, target_type: str = "api" + ) -> "Hook": for hook in cls.__subclasses__(): hook: Type[Hook] try: @@ -31,29 +33,33 @@ def dispatch_for(cls, func: Callable, hook_type: str, target_type: str = 'api') continue if cls.hook_type == hook_type and cls.target_type == target_type: return cls.apply_for(func) - raise NotImplementedError(f'{cls}: cannot dispatch for hook: {hook_type} in target: {repr(target_type)}') + raise NotImplementedError( + f"{cls}: cannot dispatch for hook: {hook_type} in target: {repr(target_type)}" + ) @classmethod - def apply_for(cls, func: Callable) -> 'Hook': + def apply_for(cls, func: Callable) -> "Hook": if not hasattr(func, utils.EndpointAttr.hook): - raise ValueError(f'Hook type for function: {func} is not specified') + raise ValueError(f"Hook type for function: {func} is not specified") return cls( func, hook_type=getattr(func, utils.EndpointAttr.hook), hook_targets=getattr(func, cls.hook_type, None), hook_excludes=getattr(func, utils.EndpointAttr.excludes, None), - priority=getattr(func, 'priority', None) + priority=getattr(func, "priority", None), ) - def __init__(self, f: Callable, - hook_type: str, - hook_targets: list = None, - hook_excludes: list = None, - priority: int = None, - ): + def __init__( + self, + f: Callable, + hook_type: str, + hook_targets: list = None, + hook_excludes: list = None, + priority: int = None, + ): if not inspect.isfunction(f): - raise TypeError(f'Invalid endpoint function: {f}') + raise TypeError(f"Invalid endpoint function: {f}") self.f = f self.hook_type = hook_type @@ -68,7 +74,7 @@ def __init__(self, f: Callable, @property def hook_all(self): - return '*' in self.hook_targets + return "*" in self.hook_targets @property def error_hook(self): @@ -101,24 +107,27 @@ async 
def __call__(self, *args, **kwargs): class BeforeHook(Hook): hook_type = utils.EndpointAttr.before_hook - target_type = 'api' + target_type = "api" wrapper_cls = RequestContextWrapper # parse_params = False # already pared for request @classmethod - def apply_for(cls, func: Callable) -> 'BeforeHook': + def apply_for(cls, func: Callable) -> "BeforeHook": return cls( func, hook_targets=getattr(func, cls.hook_type), hook_excludes=getattr(func, utils.EndpointAttr.excludes, None), - priority=getattr(func, 'priority', None) + priority=getattr(func, "priority", None), ) - def __init__(self, f: Callable, - hook_targets: list = None, - hook_excludes: list = None, - priority: int = None): + def __init__( + self, + f: Callable, + hook_targets: list = None, + hook_excludes: list = None, + priority: int = None, + ): super().__init__( f, hook_type=utils.EndpointAttr.before_hook, @@ -132,7 +141,9 @@ def parse_request(self, request: Request): try: kwargs = dict(var.path_params.getter(request)) kwargs.update(self.wrapper.parse_context(request)) - return self.parser.parse_params((), kwargs, context=self.parser.options.make_context()) + return self.parser.parse_params( + (), kwargs, context=self.parser.options.make_context() + ) except utype.exc.ParseError as e: raise exceptions.BadRequest(str(e), detail=e.get_detail()) from e @@ -140,39 +151,44 @@ async def async_parse_request(self, request: Request): try: kwargs = dict(await var.path_params.getter(request)) kwargs.update(await self.wrapper.async_parse_context(request)) - return self.parser.parse_params((), kwargs, context=self.parser.options.make_context()) + return self.parser.parse_params( + (), kwargs, context=self.parser.options.make_context() + ) # in base Endpoint, args is not supported except utype.exc.ParseError as e: raise exceptions.BadRequest(str(e), detail=e.get_detail()) from e - def serve(self, api: 'API'): + def serve(self, api: "API"): args, kwargs = self.parse_request(api.request) return self(api, *args, **kwargs) - async def aserve(self, api: 'API'): + async def aserve(self, api: "API"): args, kwargs = await self.async_parse_request(api.request) return await self(api, *args, **kwargs) class AfterHook(Hook): hook_type = utils.EndpointAttr.after_hook - target_type = 'api' + target_type = "api" parse_params = True # parse_result = True @classmethod - def apply_for(cls, func: Callable) -> 'AfterHook': + def apply_for(cls, func: Callable) -> "AfterHook": return cls( func, hook_targets=getattr(func, cls.hook_type), hook_excludes=getattr(func, utils.EndpointAttr.excludes, None), - priority=getattr(func, 'priority', None) + priority=getattr(func, "priority", None), ) - def __init__(self, f: Callable, - hook_targets: list = None, - hook_excludes: list = None, - priority: int = None): + def __init__( + self, + f: Callable, + hook_targets: list = None, + hook_excludes: list = None, + priority: int = None, + ): super().__init__( f, hook_type=utils.EndpointAttr.after_hook, @@ -189,24 +205,27 @@ def prepare(self, api, *args, **kwargs): class ErrorHook(Hook): hook_type = utils.EndpointAttr.error_hook - target_type = 'api' + target_type = "api" parse_params = True @classmethod - def apply_for(cls, func: Callable) -> 'ErrorHook': + def apply_for(cls, func: Callable) -> "ErrorHook": return cls( func, hook_targets=getattr(func, cls.hook_type), hook_excludes=getattr(func, utils.EndpointAttr.excludes, None), hook_errors=getattr(func, utils.EndpointAttr.errors, None), - priority=getattr(func, 'priority', None) + priority=getattr(func, "priority", None), ) - 
def __init__(self, f: Callable, - hook_targets: list = None, - hook_excludes: list = None, - hook_errors: list = None, - priority: int = None): + def __init__( + self, + f: Callable, + hook_targets: list = None, + hook_excludes: list = None, + hook_errors: list = None, + priority: int = None, + ): super().__init__( f, hook_type=utils.EndpointAttr.error_hook, diff --git a/utilmeta/core/api/plugins/base.py b/utilmeta/core/api/plugins/base.py index d8ff4b2..c9094ba 100644 --- a/utilmeta/core/api/plugins/base.py +++ b/utilmeta/core/api/plugins/base.py @@ -3,9 +3,9 @@ from utilmeta.core.response import Response -process_response = PluginEvent('process_response', streamline_result=True) -process_request = PluginEvent('process_request', streamline_result=True) -handle_error = PluginEvent('handle_error') +process_response = PluginEvent("process_response", streamline_result=True) +process_request = PluginEvent("process_request", streamline_result=True) +handle_error = PluginEvent("handle_error") class APIPlugin(PluginBase): @@ -26,6 +26,7 @@ def handle_error(self, error: Error): def inject(self, target_class): # inject to the endpoints from ..base import API + if isinstance(target_class, type) and issubclass(target_class, API): for route in target_class._routes: self.inject(route.handler) diff --git a/utilmeta/core/api/plugins/cors.py b/utilmeta/core/api/plugins/cors.py index 02fa0cf..146c29f 100644 --- a/utilmeta/core/api/plugins/cors.py +++ b/utilmeta/core/api/plugins/cors.py @@ -10,21 +10,28 @@ class CORSPlugin(APIPlugin): - DEFAULT_ALLOW_HEADERS = ('content-type', 'content-length', 'accept', 'origin', 'user-agent') + DEFAULT_ALLOW_HEADERS = ( + "content-type", + "content-length", + "accept", + "origin", + "user-agent", + ) EXCLUDED_STATUS = (502, 503, 504) - def __init__(self, - allow_origin: Union[List[str], str] = None, - cors_max_age: Union[int, timedelta, float] = None, - allow_headers: List[str] = (), - allow_errors: List[Type[Exception]] = (Exception,), - expose_headers: List[str] = None, - csrf_exempt: bool = None, - exclude_statuses: List[int] = EXCLUDED_STATUS, - gen_csrf_token: bool = None, - options_200: bool = True, - override: bool = False, - ): + def __init__( + self, + allow_origin: Union[List[str], str] = None, + cors_max_age: Union[int, timedelta, float] = None, + allow_headers: List[str] = (), + allow_errors: List[Type[Exception]] = (Exception,), + expose_headers: List[str] = None, + csrf_exempt: bool = None, + exclude_statuses: List[int] = EXCLUDED_STATUS, + gen_csrf_token: bool = None, + options_200: bool = True, + override: bool = False, + ): super().__init__(locals()) self.csrf_exempt = csrf_exempt @@ -33,12 +40,16 @@ def __init__(self, if isinstance(allow_origin, str): allow_origin = [allow_origin] elif not multi(allow_origin): - raise TypeError(f'Request allow_origin must be None, "*" or a origin str / str list') + raise TypeError( + f'Request allow_origin must be None, "*" or a origin str / str list' + ) - if '*' in allow_origin: - self.allow_origins = ['*'] + if "*" in allow_origin: + self.allow_origins = ["*"] else: - self.allow_origins = [get_origin(origin) for origin in allow_origin if origin] + self.allow_origins = [ + get_origin(origin) for origin in allow_origin if origin + ] if self.allow_all_origin: if self.csrf_exempt is None: @@ -58,8 +69,17 @@ def __init__(self, if allow_errors and not multi(allow_errors): allow_errors = [allow_errors] - self.allow_errors = tuple([e for e in allow_errors if isinstance(e, type) and issubclass(e, Exception)]) \ - if 
allow_errors else None + self.allow_errors = ( + tuple( + [ + e + for e in allow_errors + if isinstance(e, type) and issubclass(e, Exception) + ] + ) + if allow_errors + else None + ) self.allow_headers: list = allow_headers or [] self.cors_max_age = cors_max_age self.expose_headers = expose_headers @@ -70,39 +90,44 @@ def __init__(self, @property def allow_all_origin(self): - return self.allow_origins and '*' in self.allow_origins + return self.allow_origins and "*" in self.allow_origins @property def allow_all_headers(self): - return self.allow_headers and '*' in self.allow_headers + return self.allow_headers and "*" in self.allow_headers def __call__(self, func, *args, **kwargs): from ..base import API + if inspect.isclass(func) and issubclass(func, API): - if not hasattr(func, 'response'): + if not hasattr(func, "response"): from utilmeta.core.response import Response + func.response = Response return super().__call__(func, *args, **kwargs) def process_request(self, request: Request): from utilmeta import service + if request.origin != service.origin: # origin cross settings is above all other request control settings # so that when request error occur, the cross-origin settings can take effect # so that client see a valid error message instead of a CORS error if not self.allow_origins: - raise exc.PermissionDenied(f'Invalid request origin: {request.origin}') + raise exc.PermissionDenied(f"Invalid request origin: {request.origin}") else: if not self.allow_all_origin: if request.origin not in self.allow_origins: - raise exc.PermissionDenied(f'Invalid request origin: {request.origin}') + raise exc.PermissionDenied( + f"Invalid request origin: {request.origin}" + ) if self.gen_csrf_token: request.adaptor.gen_csrf_token() elif not self.csrf_exempt: # only check csrf for from_api requests if not request.adaptor.check_csrf_token(): - raise exc.PermissionDenied(f'CSRF token missing or incorrect') + raise exc.PermissionDenied(f"CSRF token missing or incorrect") def cors_required(self, request: Request) -> bool: """ @@ -115,6 +140,7 @@ def cors_required(self, request: Request) -> bool: if request.is_options or self.allow_all_origin: return True from utilmeta import service + if request.origin != service.origin: return True return self.allow_origins and request.origin not in self.allow_origins @@ -134,24 +160,35 @@ def process_response(self, response: Response): if not self.override: return response if self.cors_required(request): - response.update_headers(**{ - Header.ALLOW_ORIGIN: request.origin or '*', - Header.ALLOW_CREDENTIALS: 'true', - Header.ALLOW_METHODS: ','.join(set([m.upper() for m in var.allow_methods.getter(request)])), - }) + response.update_headers( + **{ + Header.ALLOW_ORIGIN: request.origin or "*", + Header.ALLOW_CREDENTIALS: "true", + Header.ALLOW_METHODS: ",".join( + set([m.upper() for m in var.allow_methods.getter(request)]) + ), + } + ) if request.is_options: if self.allow_all_headers: - response.set_header(Header.ALLOW_HEADERS, '*') + response.set_header(Header.ALLOW_HEADERS, "*") else: # request_headers = [h.strip().lower() for h in # request.headers.get(Header.OPTIONS_HEADERS, '').split(',')] allow_headers = list(self.allow_headers or []) - allow_headers.extend([h.lower() for h in var.allow_headers.getter(request)]) + allow_headers.extend( + [h.lower() for h in var.allow_headers.getter(request)] + ) if allow_headers: - response.set_header(Header.ALLOW_HEADERS, ','.join(allow_headers)) + response.set_header( + Header.ALLOW_HEADERS, ",".join(allow_headers) + ) if 
self.expose_headers: - response.set_header(Header.EXPOSE_HEADERS, ','.join(set([h.lower() for h in self.expose_headers]))) + response.set_header( + Header.EXPOSE_HEADERS, + ",".join(set([h.lower() for h in self.expose_headers])), + ) if self.cors_max_age: response.set_header(Header.ACCESS_MAX_AGE, self.cors_max_age) return response @@ -163,7 +200,7 @@ def handle_error(self, error, api=None): return # if error is uncaught if api: - make_response = getattr(api, '_make_response', None) + make_response = getattr(api, "_make_response", None) # this is a rather ugly hack, maybe we will figure out something nicer or universal # because we need to postpone the response process if callable(make_response): @@ -172,12 +209,14 @@ def handle_error(self, error, api=None): @wraps(make_response) def _make_response(response, force: bool = False): return self.process_response(make_response(response, force)) + api._make_response = _make_response return # process with error hooks # response = api._make_response(api._handle_error(error)) # if error is raised here - return self.process_response((getattr(api, 'response', None) or Response)( - error=error, - request=error.request - )) + return self.process_response( + (getattr(api, "response", None) or Response)( + error=error, request=error.request + ) + )
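A usage sketch for the CORS plugin above (illustrative, not from this commit; the origin and header values are made up, and the keyword arguments are the ones defined in __init__ above). Since APIPlugin instances are callable on plugin targets, the plugin can be applied as a class decorator, which is how the __call__ override above receives the API class:

    from utilmeta.core.api.plugins.cors import CORSPlugin
    from utilmeta.core import api

    @CORSPlugin(
        allow_origin="https://example.com",  # or "*" to allow any origin
        cors_max_age=3600,                   # let browsers cache preflight for an hour
        expose_headers=["X-Request-Id"],     # illustrative header
    )
    class RootAPI(api.API):
        ...

diff --git a/utilmeta/core/api/plugins/retry.py b/utilmeta/core/api/plugins/retry.py index c84bef1..21f3160 100644 --- a/utilmeta/core/api/plugins/retry.py +++ b/utilmeta/core/api/plugins/retry.py @@ -3,7 +3,14 @@ from utype.types import * from utilmeta.utils.error import Error from utilmeta.utils import exceptions -from utilmeta.utils import multi, class_func, time_now, get_interval, awaitable, DEFAULT_RETRY_ON_STATUSES +from utilmeta.utils import ( + multi, + class_func, + time_now, + get_interval, + awaitable, + DEFAULT_RETRY_ON_STATUSES, +) from utype.parser.func import FunctionParser import random from utype.types import Float @@ -20,22 +27,25 @@ class RetryPlugin(APIPlugin): DEFAULT_RETRY_ON_ERRORS = (Exception,) DEFAULT_RETRY_AFTER_HEADERS = () - def __init__(self, - max_retries: int = 1, - max_retries_timeout: Union[float, int, timedelta] = None, - retry_interval: Union[float, int, timedelta, List[float], List[int], - List[timedelta], Callable] = None, - # a value: 1 / 15.5 / timedelta(seconds=3.5) - # a callable: will take 2 optional params (current_retries, max_retries) - # a list of values: [1, 3, 10, 15, 30], will be mapped to each retries - retry_timeout: Union[float, int, timedelta, List[float], List[int], - List[timedelta], Callable] = None, - retry_delta_ratio: float = None, - retry_on_errors: List[Type[Exception]] = None, - retry_on_statuses: List[int] = DEFAULT_RETRY_ON_STATUSES, - retry_on_idempotent_only: bool = None, - retry_after_headers: Union[str, List[str]] = None, - ): + def __init__( + self, + max_retries: int = 1, + max_retries_timeout: Union[float, int, timedelta] = None, + retry_interval: Union[ + float, int, timedelta, List[float], List[int], List[timedelta], Callable + ] = None, + # a value: 1 / 15.5 / timedelta(seconds=3.5) + # a callable: will take 2 optional params (current_retries, max_retries) + # a list of values: [1, 3, 10, 15, 30], will be mapped to each retries + retry_timeout: Union[ + float, int, timedelta, List[float], List[int], List[timedelta], Callable + ] = None, + retry_delta_ratio: float = None, + retry_on_errors: List[Type[Exception]] = None, + retry_on_statuses: List[int] = DEFAULT_RETRY_ON_STATUSES, + retry_on_idempotent_only: bool = None,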
retry_after_headers: Union[str, List[str]] = None, + ): super().__init__(locals()) self.max_retries = max_retries @@ -49,23 +59,29 @@ def __init__(self, self.retry_delta_ratio = retry_delta_ratio if retry_on_errors and not multi(retry_on_errors): retry_on_errors = [retry_on_errors] - self.retry_on_errors = retry_on_errors or self.DEFAULT_RETRY_ON_ERRORS # can be inherited + self.retry_on_errors = ( + retry_on_errors or self.DEFAULT_RETRY_ON_ERRORS + ) # can be inherited if retry_on_statuses and not multi(retry_on_statuses): retry_on_statuses = [retry_on_statuses] self.retry_on_statuses = retry_on_statuses or DEFAULT_RETRY_ON_STATUSES self.retry_on_idempotent_only = retry_on_idempotent_only if retry_after_headers and not multi(retry_after_headers): retry_after_headers = [retry_after_headers] - self.retry_after_headers = retry_after_headers or self.DEFAULT_RETRY_AFTER_HEADERS + self.retry_after_headers = ( + retry_after_headers or self.DEFAULT_RETRY_AFTER_HEADERS + ) # self.max_retry_after = max_retry_after - def whether_retry(self, request: Request = None, response: Response = None, error: Error = None): + def whether_retry( + self, request: Request = None, response: Response = None, error: Error = None + ): """ Can inherit and custom, like base on response's header values """ if self.retry_on_idempotent_only: if request: - idempotent = request.adaptor.get_context('idempotent') + idempotent = request.adaptor.get_context("idempotent") if not idempotent: return False else: @@ -81,10 +97,12 @@ def whether_retry(self, request: Request = None, response: Response = None, erro return False def process_request(self, request: Request): - current_retry = request.adaptor.get_context('retry_index') or 0 + current_retry = request.adaptor.get_context("retry_index") or 0 if current_retry >= self.max_retries: - raise self.max_retries_error_cls(f'{self.__class__}: max_retries: {self.max_retries} exceeded', - max_retries=self.max_retries) + raise self.max_retries_error_cls( + f"{self.__class__}: max_retries: {self.max_retries} exceeded", + max_retries=self.max_retries, + ) self.handle_max_retries_timeout(request, set_timeout=True) return request @@ -97,26 +115,34 @@ def handle_max_retries_timeout(self, request: Request, set_timeout: bool = False if delta <= 0: # max retries time exceeded raise self.max_retries_timeout_error_cls( - f'{self.__class__}: max_retries_timeout exceed for {abs(delta)} seconds', - max_retries_timeout=self.max_retries_timeout) + f"{self.__class__}: max_retries_timeout exceed for {abs(delta)} seconds", + max_retries_timeout=self.max_retries_timeout, + ) # reset request timeout if set_timeout: if self.retry_timeout: - current_retry = request.adaptor.get_context('retry_index') or 0 + current_retry = request.adaptor.get_context("retry_index") or 0 timeout = self.retry_timeout if callable(timeout): - timeout = timeout(current_retry, self.max_retries, self.max_retries_timeout) + timeout = timeout( + current_retry, self.max_retries, self.max_retries_timeout + ) if multi(timeout): timeout = timeout[min(len(timeout) - 1, current_retry)] timeout = get_interval(timeout, null=True) if timeout: if self.retry_delta_ratio: - timeout = timeout + (random.random() * 2 - 1) * self.retry_delta_ratio * timeout + timeout = ( + timeout + + (random.random() * 2 - 1) + * self.retry_delta_ratio + * timeout + ) request.adaptor.update_context(timeout=timeout) - to = request.adaptor.get_context('timeout') + to = request.adaptor.get_context("timeout") if not to or to > delta: 
request.adaptor.update_context(timeout=delta) @@ -124,7 +150,7 @@ def process_response(self, response: Response): request = response.request if not request: return response - current_retry = request.get_context('retry_index') or 0 + current_retry = request.get_context("retry_index") or 0 if current_retry + 1 >= self.max_retries: # cannot retry return response @@ -140,7 +166,7 @@ async def process_response(self, response: Response): request = response.request if not request: return response - current_retry = request.get_context('retry_index') or 0 + current_retry = request.get_context("retry_index") or 0 if current_retry + 1 >= self.max_retries: # cannot retry return response @@ -152,13 +178,16 @@ async def process_response(self, response: Response): return request # return request to make SDK retry this request def handle_error(self, e: Error): - if isinstance(e.exception, (self.max_retries_error_cls, self.max_retries_timeout_error_cls)): + if isinstance( + e.exception, + (self.max_retries_error_cls, self.max_retries_timeout_error_cls), + ): return request = e.request if not request: # raise error return - current_retry = request.adaptor.get_context('retry_index') or 0 + current_retry = request.adaptor.get_context("retry_index") or 0 if current_retry + 1 >= self.max_retries: # cannot retry return # proceed to handle error instead of raise @@ -171,13 +200,16 @@ def handle_error(self, e: Error): @awaitable(handle_error) async def handle_error(self, e: Error): - if isinstance(e.exception, (self.max_retries_error_cls, self.max_retries_timeout_error_cls)): + if isinstance( + e.exception, + (self.max_retries_error_cls, self.max_retries_timeout_error_cls), + ): return request = e.request if not request: # raise error return - current_retry = request.adaptor.get_context('retry_index') or 0 + current_retry = request.adaptor.get_context("retry_index") or 0 if current_retry + 1 >= self.max_retries: # cannot retry return # proceed to handle error instead of raise @@ -188,7 +220,9 @@ async def handle_error(self, e: Error): self.handle_max_retries_timeout(request, set_timeout=False) return request - def get_retry_after(self, request: Request, response: Response = None) -> Optional[float]: + def get_retry_after( + self, request: Request, response: Response = None + ) -> Optional[float]: """ Treat 429 specially, cause the headers/data may indicate the next_request_time """ @@ -198,7 +232,7 @@ def get_retry_after(self, request: Request, response: Response = None) -> Option current_time = time_now() start_time = request.time passed = (current_time - start_time).total_seconds() - current_retry = request.adaptor.get_context('retry_index') or 0 + current_retry = request.adaptor.get_context("retry_index") or 0 retry_after = None if response: @@ -226,14 +260,19 @@ def get_retry_after(self, request: Request, response: Response = None) -> Option return 0 if callable(retry_after): # function parser can consume - retry_after = retry_after(current_retry, self.max_retries, self.max_retries_timeout) + retry_after = retry_after( + current_retry, self.max_retries, self.max_retries_timeout + ) if multi(retry_after): retry_after = retry_after[min(len(retry_after) - 1, current_retry)] retry_after = get_interval(retry_after, null=True) if isinstance(retry_after, (int, float)): if self.retry_delta_ratio: - retry_after = retry_after + (random.random() * 2 - 1) * self.retry_delta_ratio * retry_after + retry_after = ( + retry_after + + (random.random() * 2 - 1) * self.retry_delta_ratio * retry_after + ) if 
self.max_retries_timeout: # cannot wait longer than the max timeout @@ -249,14 +288,18 @@ def handle_retry_after(self, request: Request, response: Response = None): return False if retry_after: import time + time.sleep(retry_after) return True - async def async_handle_retry_after(self, request: Request, response: Response = None): + async def async_handle_retry_after( + self, request: Request, response: Response = None + ): retry_after = self.get_retry_after(request, response) if retry_after is None: return False if retry_after: import asyncio + await asyncio.sleep(retry_after) return True diff --git a/utilmeta/core/api/route.py b/utilmeta/core/api/route.py index 1095c4f..68429e9 100644 --- a/utilmeta/core/api/route.py +++ b/utilmeta/core/api/route.py @@ -1,6 +1,15 @@ import re from typing import Union, Dict, Type, List, Optional, TYPE_CHECKING -from utilmeta.utils import awaitable, get_doc, regular, duplicate, pop, distinct_add, multi, PATH_REGEX +from utilmeta.utils import ( + awaitable, + get_doc, + regular, + duplicate, + pop, + distinct_add, + multi, + PATH_REGEX, +) import inspect from functools import partial @@ -14,25 +23,27 @@ class BaseRoute: - def __init__(self, - handler, - route: Union[str, tuple], - name: str, - parent=None, - before_hooks: List[BeforeHook] = (), - after_hooks: List[AfterHook] = (), - error_hooks: Dict[Type[Exception], ErrorHook] = None): + def __init__( + self, + handler, + route: Union[str, tuple], + name: str, + parent=None, + before_hooks: List[BeforeHook] = (), + after_hooks: List[AfterHook] = (), + error_hooks: Dict[Type[Exception], ErrorHook] = None, + ): self.name = name self.handler = handler self.parent = parent if isinstance(route, str): - route = str(route).strip('/') + route = str(route).strip("/") elif isinstance(route, tuple): route = self.from_routes(*route) else: - raise TypeError(f'Invalid route: {route}') + raise TypeError(f"Invalid route: {route}") self.route = route self.before_hooks = before_hooks or [] @@ -42,7 +53,7 @@ def __init__(self, @classmethod def from_routes(cls, *routes): # meant to be inherited - return '/'.join([str(v).strip('/') for v in routes]) + return "/".join([str(v).strip("/") for v in routes]) def match_targets(self, targets: list): if self.route in targets: @@ -95,14 +106,14 @@ def clone(self): parent=self.parent, before_hooks=list(self.before_hooks), after_hooks=list(self.after_hooks), - error_hooks=dict(self.error_hooks) + error_hooks=dict(self.error_hooks), ) @property def no_hooks(self): return not self.before_hooks and not self.after_hooks and not self.error_hooks - def merge_hooks(self, route: 'BaseRoute'): + def merge_hooks(self, route: "BaseRoute"): # self: near # route: far if not route or not isinstance(route, BaseRoute): @@ -120,38 +131,48 @@ def merge_hooks(self, route: 'BaseRoute'): class APIRoute(BaseRoute): PATH_REGEX = PATH_REGEX - DEFAULT_PATH_REGEX = '[^/]+' - - def __init__(self, - handler: Union[Type['API'], Endpoint], - route: Union[str, tuple], - name: str, - parent: Type['API'] = None, - summary: str = None, - description: str = None, - tags: list = None, - deprecated: bool = None, - private: bool = None, - priority: int = None, - before_hooks: List[BeforeHook] = (), - after_hooks: List[AfterHook] = (), - error_hooks: Dict[Type[Exception], ErrorHook] = None, **kwargs): + DEFAULT_PATH_REGEX = "[^/]+" + + def __init__( + self, + handler: Union[Type["API"], Endpoint], + route: Union[str, tuple], + name: str, + parent: Type["API"] = None, + summary: str = None, + description: str = None, + tags: 
list = None, + deprecated: bool = None, + private: bool = None, + priority: int = None, + before_hooks: List[BeforeHook] = (), + after_hooks: List[AfterHook] = (), + error_hooks: Dict[Type[Exception], ErrorHook] = None, + **kwargs, + ): self.method = None from .base import API + if isinstance(handler, Endpoint): self.method = handler.method if handler.is_method and route: - raise ValueError(f'Endpoint method: <{self.method}> (with HTTP method name) ' - f'cannot assign route: {repr(route)}, please use another function name') + raise ValueError( + f"Endpoint method: <{self.method}> (with HTTP method name) " + f"cannot assign route: {repr(route)}, please use another function name" + ) if not route: route = handler.route elif inspect.isclass(handler) and issubclass(handler, API): if not route: - raise ValueError(f'API handler: {handler} should specify a route, got empty') + raise ValueError( + f"API handler: {handler} should specify a route, got empty" + ) else: - raise TypeError(f'invalid api class or function: {handler}, must be a ' - f'Endpoint instance of subclass of API') + raise TypeError( + f"invalid api class or function: {handler}, must be a " + f"Endpoint instance of subclass of API" + ) super().__init__( handler, @@ -160,7 +181,7 @@ def __init__(self, parent=parent, before_hooks=before_hooks, after_hooks=after_hooks, - error_hooks=error_hooks + error_hooks=error_hooks, ) self.kwargs = kwargs @@ -168,7 +189,7 @@ def __init__(self, self.description = description or get_doc(handler) self.tags = tags self.deprecated = deprecated - self.private = private or handler.__name__.startswith('_') + self.private = private or handler.__name__.startswith("_") self.priority = priority self.regex_list = [] self.kwargs_regex = {} @@ -185,13 +206,15 @@ def init_headers(self): for key, val in self.handler._properties.items(): name = val.field.name.lower() - if getattr(val.prop.__in__, '__ident__', None) == 'header': + if getattr(val.prop.__in__, "__ident__", None) == "header": if name not in self.header_names: self.header_names.append(name) else: - headers = getattr(val.prop, 'headers', None) + headers = getattr(val.prop, "headers", None) if headers and multi(headers): - distinct_add(self.header_names, [str(v).lower() for v in headers]) + distinct_add( + self.header_names, [str(v).lower() for v in headers] + ) def get_field(self, name: str) -> Optional[Field]: if isinstance(self.handler, Endpoint): @@ -201,7 +224,7 @@ def get_field(self, name: str) -> Optional[Field]: return None field = getattr(self.handler, name, None) if isinstance(field, property): - field = getattr(field.fget, '__field__', None) + field = getattr(field.fget, "__field__", None) if isinstance(field, Field): return field elif isinstance(field, type) and issubclass(field, Field): @@ -221,15 +244,15 @@ def options(self): def get_patterns(self): pattern = self.route for arg in self.path_args: - pattern = pattern.replace('{%s}' % arg, self.DEFAULT_PATH_REGEX) + pattern = pattern.replace("{%s}" % arg, self.DEFAULT_PATH_REGEX) patterns = [pattern] if not self.is_endpoint: - patterns.append(f'{pattern}/.*') + patterns.append(f"{pattern}/.*") return patterns @classmethod def get_pattern(cls, path: str): - path = path.strip('/') + path = path.strip("/") params: List[str] = cls.PATH_REGEX.findall(path) if not params: @@ -237,13 +260,13 @@ def get_pattern(cls, path: str): pattern = path for arg in params: - pattern = pattern.replace('{%s}' % arg, cls.DEFAULT_PATH_REGEX) + pattern = pattern.replace("{%s}" % arg, cls.DEFAULT_PATH_REGEX) return 
re.compile(pattern) def compile_route(self): if not self.route: - self.regex_list = [re.compile('')] + self.regex_list = [re.compile("")] return regs = [] @@ -254,19 +277,21 @@ def compile_route(self): regs = [re.compile(self.route)] if not self.is_endpoint: # for API - regs.append(re.compile(f'{self.route}/(?P<_>.*)')) + regs.append(re.compile(f"{self.route}/(?P<_>.*)")) self.regex_list = regs return d = duplicate(params) - assert not d, f"Endpoint path: {repr(self.route)} shouldn't contains duplicate param {d}" + assert ( + not d + ), f"Endpoint path: {repr(self.route)} shouldn't contain duplicate param {d}" omit = None # required path params must come before optional ones # omit is the mark of the first optional path param beg = 0 divider = [] - suffix = '' + suffix = "" for i, p in enumerate(params): field = self.get_field(p) @@ -276,51 +301,57 @@ def compile_route(self): if field: break if not field: - raise ValueError(f'missing path name parameter: {repr(p)}') + raise ValueError(f"missing field for path parameter: {repr(p)}") if self.is_endpoint: if not field.required: if omit is None: omit = p elif omit is not None: - raise ValueError(f"Required path argument ({repr(p)}) is after a optional arg " - f"({omit}), which is invalid") + raise ValueError( + f"Required path argument ({repr(p)}) is after an optional arg " + f"({omit}), which is invalid" + ) - sub = '{%s}' % p + sub = "{%s}" % p end = self.route.find(sub) div = self.route[beg:end] if beg and not div: - raise ValueError(f"Endpoint path: {repr(self.route)} param {repr(sub)}" - f" should divide each other with string") + raise ValueError( + f"Endpoint path: {repr(self.route)} param {repr(sub)}" + f" must be separated from the previous param by a string literal" + ) divider.append((div, p)) beg = end + len(sub) if beg < len(self.route): suffix = self.route[beg:] - pattern = '' + pattern = "" for div, param in divider: path_field = self.get_field(param) # use duck-typed attribute here - path_regex = getattr(path_field, 'regex', self.DEFAULT_PATH_REGEX) + path_regex = getattr(path_field, "regex", self.DEFAULT_PATH_REGEX) div = regular(div) pattern += div - if self.is_endpoint and (omit and param == str(omit) or param != str(omit) and regs): + if self.is_endpoint and ( + omit and param == str(omit) or param != str(omit) and regs + ): # until omit param, do not add pattern # after omit param, every param should add a pattern - regs.append(re.compile(pattern.rstrip('/'))) + regs.append(re.compile(pattern.rstrip("/"))) # omit does not apply for API kwargs_reg[param] = path_regex - pattern += f'(?P<{param}>{path_regex})' + pattern += f"(?P<{param}>{path_regex})" pattern += suffix regs.append(re.compile(pattern)) if not self.is_endpoint: # for API - regs.append(re.compile(f'{pattern}/(?P<_>.*)')) + regs.append(re.compile(f"{pattern}/(?P<_>.*)")) regs.reverse() # reverse the reg list so the longest reg matches the path first, @@ -336,19 +367,23 @@ def is_endpoint(self): @property def ident(self): if self.method: - return f'{self.method}:{self.route}'.lower() + return f"{self.method}:{self.route}".lower() return self.route def make_property(self): # if with_hooks: # pass if self.is_endpoint: + + def getter(api_inst: "API"): return partial(self.handler, api_inst) + + else: + + def getter(api_inst: "API"): return self.handler(api_inst.request) # api._init_plugins() + + return property(getter) def match_route(self, request: Request): @@ -366,7 +401,7 @@ def match_route(self, request: Request): if not self.method: # only set path 
params if route is API # endpoints need to match for multiple methods - route_attr.set(pop(group, '_', '')) + route_attr.set(pop(group, "_", "")) if not self.method or self.method == request.method: # set path params for endpoint and API in every match @@ -385,7 +420,7 @@ def match_route(self, request: Request): return True return False - def serve(self, api: 'API'): + def serve(self, api: "API"): # --- names_var = var.operation_names.setup(api.request) names = names_var.get() or [] @@ -410,7 +445,7 @@ def serve(self, api: 'API'): return result - async def aserve(self, api: 'API'): + async def aserve(self, api: "API"): # --- names_var = var.operation_names.setup(api.request) names = await names_var.get() or [] @@ -435,9 +470,9 @@ async def aserve(self, api: 'API'): return result - def __call__(self, api: 'API'): + def __call__(self, api: "API"): return self.serve(api) @awaitable(__call__) - async def __call__(self, api: 'API'): + async def __call__(self, api: "API"): return await self.aserve(api) diff --git a/utilmeta/core/api/specs/base.py b/utilmeta/core/api/specs/base.py index 4d8170f..1cdf31d 100644 --- a/utilmeta/core/api/specs/base.py +++ b/utilmeta/core/api/specs/base.py @@ -8,7 +8,7 @@ class BaseAPISpec: spec = None __version__ = None - def __init__(self, service: 'UtilMeta'): + def __init__(self, service: "UtilMeta"): self.service = service self.format = format diff --git a/utilmeta/core/api/specs/openapi.py b/utilmeta/core/api/specs/openapi.py index 6f4fa4d..473f86e 100644 --- a/utilmeta/core/api/specs/openapi.py +++ b/utilmeta/core/api/specs/openapi.py @@ -34,37 +34,39 @@ if TYPE_CHECKING: from utilmeta import UtilMeta -MULTIPART = 'multipart/form-data' +MULTIPART = "multipart/form-data" def guess_content_type(schema: dict): if not schema: return JSON - type = schema.get('type') - format = schema.get('format') + type = schema.get("type") + format = schema.get("format") - if type in ('object', 'array'): + if type in ("object", "array"): return JSON - if schema.get('$ref'): + if schema.get("$ref"): return JSON - if format == 'binary': + if format == "binary": return OCTET_STREAM return PLAIN -def get_operation_id(method: str, path: str, excludes: list = (), attribute: bool = False): +def get_operation_id( + method: str, path: str, excludes: list = (), attribute: bool = False +): ident = f'{method.lower()}:{path.strip("/")}' if attribute: - ident = re.sub('[^A-Za-z0-9]+', '_', ident).strip('_') + ident = re.sub("[^A-Za-z0-9]+", "_", ident).strip("_") if excludes: i = 1 origin = ident while ident in excludes: - ident = f'{origin}_{i}' + ident = f"{origin}_{i}" i += 1 return ident @@ -81,10 +83,15 @@ def generate_for_field(self, f: ParserField, options=None): data.update(accept=t.accept) if isinstance(t, LogicalType) and f.discriminator_map: # not part of json-schema, but in OpenAPI - data.update(discriminator=dict( - propertyName=f.field.discriminator, - mapping={k: self.generate_for_type(v) for k, v in f.discriminator_map.items()} - )) + data.update( + discriminator=dict( + propertyName=f.field.discriminator, + mapping={ + k: self.generate_for_type(v) + for k, v in f.discriminator_map.items() + }, + ) + ) return data def get_ref_object(self, ref: str): @@ -98,7 +105,7 @@ def get_ref_schema(self, ref: str): def get_schema(self, schema: dict): if not schema or not isinstance(schema, dict): return None - ref = schema.get('$ref') + ref = schema.get("$ref") if ref: return self.get_ref_schema(ref) return schema @@ -106,39 +113,39 @@ def get_schema(self, schema: dict): def 
get_body_content_type(self, body_schema: dict): if not body_schema or not isinstance(body_schema, dict): return None - ref = body_schema.get('$ref') + ref = body_schema.get("$ref") if ref: body_schema = self.get_ref_schema(ref) if not body_schema or not isinstance(body_schema, dict): return None - if body_schema.get('type') == 'object': - for key, field in body_schema.get('properties', {}).items(): - if field.get('format') == 'binary': + if body_schema.get("type") == "object": + for key, field in body_schema.get("properties", {}).items(): + if field.get("format") == "binary": return MULTIPART - if field.get('type') == 'array': - if field.get('items', {}).get('format') == 'binary': + if field.get("type") == "array": + if field.get("items", {}).get("format") == "binary": return MULTIPART return JSON return guess_content_type(body_schema) def generate_for_response(self, response: Type[Response]): - parser = getattr(response, '__parser__', None) - result_field = parser.get_field('result') if parser else None - headers_field = parser.get_field('headers') if parser else None + parser = getattr(response, "__parser__", None) + result_field = parser.get_field("result") if parser else None + headers_field = parser.get_field("headers") if parser else None result_schema = self.generate_for_field(result_field) if result_field else None - headers_schema = self.__class__(headers_field.type, output=True)() \ - if headers_field and headers_field.type != Headers else {} + headers_schema = ( + self.__class__(headers_field.type, output=True)() + if headers_field and headers_field.type != Headers + else {} + ) # headers is different, doesn't need to generate $ref - headers_props = headers_schema.get('properties') or {} - headers_required = headers_schema.get('required') or [] + headers_props = headers_schema.get("properties") or {} + headers_required = headers_schema.get("required") or [] headers = {} for key, val_schema in headers_props.items(): - headers[key] = { - 'schema': val_schema, - 'required': key in headers_required - } + headers[key] = {"schema": val_schema, "required": key in headers_required} content_type = response.content_type # todo: headers wrapped @@ -147,36 +154,36 @@ def generate_for_response(self, response: Type[Response]): keys = {} if response.result_key: props[response.result_key] = result_schema - keys.update({'x-response-result-key': response.result_key}) + keys.update({"x-response-result-key": response.result_key}) if response.message_key: msg = dict(self.generate_for_type(str)) msg.update( - title='Message', - description='an error message of response', + title="Message", + description="an error message of response", ) props[response.message_key] = msg - keys.update({'x-response-message-key': response.message_key}) + keys.update({"x-response-message-key": response.message_key}) if response.state_key: state = dict(self.generate_for_type(str)) state.update( - title='State', - description='action state code of response', + title="State", + description="action state code of response", ) props[response.state_key] = state - keys.update({'x-response-state-key': response.state_key}) + keys.update({"x-response-state-key": response.state_key}) if response.count_key: cnt = dict(self.generate_for_type(int)) cnt.update( - title='Count', - description='a count of the total number of query result', + title="Count", + description="a count of the total number of query result", ) props[response.count_key] = cnt - keys.update({'x-response-count-key': response.count_key}) + 
keys.update({"x-response-count-key": response.count_key}) data_schema = { - 'type': 'object', - 'properties': props, - 'required': list(props) + "type": "object", + "properties": props, + "required": list(props), } if keys: data_schema.update(keys) @@ -187,16 +194,14 @@ def generate_for_response(self, response: Type[Response]): content_type = guess_content_type(data_schema) response_schema = dict( - content={content_type: { - 'schema': data_schema - }}, + content={content_type: {"schema": data_schema}}, ) if headers: response_schema.update(headers=headers) if response.description: response_schema.update(description=response.description) if response.name: - response_schema.update({'x-response-name': response.name}) + response_schema.update({"x-response-name": response.name}) return response_schema @@ -204,15 +209,15 @@ def generate_for_response(self, response: Type[Response]): class OpenAPIInfo(Schema): title: str version: str - description: str = Field(default='') - term_of_service: str = Field(alias='termsOfService', alias_from=['tos'], default='') + description: str = Field(default="") + term_of_service: str = Field(alias="termsOfService", alias_from=["tos"], default="") contact: dict = Field(default_factory=dict) license: dict = Field(default_factory=dict) class ServerSchema(Schema): url: str - description: str = Field(default='') + description: str = Field(default="") variables: dict = Field(default_factory=dict) @@ -245,21 +250,23 @@ class OpenAPISchema(Schema): class OpenAPI(BaseAPISpec): - spec = 'openapi' - __version__ = '3.1.0' + spec = "openapi" + __version__ = "3.1.0" generator_cls = OpenAPIGenerator schema_cls = OpenAPISchema - FORMATS = ['json', 'yaml'] - PARAMS_IN = ['path', 'query', 'header', 'cookie'] + FORMATS = ["json", "yaml"] + PARAMS_IN = ["path", "query", "header", "cookie"] URL_FETCH_TIMEOUT = 5 # None -> dict # json -> json string # yml -> yml string - def __init__(self, service: 'UtilMeta', - external_docs: Union[str, dict, Callable] = None, - base_url: str = None, - ): + def __init__( + self, + service: "UtilMeta", + external_docs: Union[str, dict, Callable] = None, + base_url: str = None, + ): super().__init__(service) self.defs = {} self.names = {} @@ -337,13 +344,13 @@ def merge_openapi_docs(self, *docs: dict) -> OpenAPISchema: security.extend(doc.security) for tag in doc.tags: - tag_name = tag.get('name') if isinstance(tag, dict) else str(tag) + tag_name = tag.get("name") if isinstance(tag, dict) else str(tag) if not tag_name: continue if tag_name in tag_names: continue tags.append(tag_name) - tags.append(tag if isinstance(tag, dict) else {'name': tag_name}) + tags.append(tag if isinstance(tag, dict) else {"name": tag_name}) for key, val in doc.items(): if key not in self.schema_cls.__parser__.fields: additions[key] = val @@ -355,7 +362,7 @@ def merge_openapi_docs(self, *docs: dict) -> OpenAPISchema: components=components, security=security, tags=tags, - **additions + **additions, ) def get_external_docs(self, external_docs) -> List[OpenAPISchema]: @@ -366,7 +373,9 @@ def get_external_docs(self, external_docs) -> List[OpenAPISchema]: try: docs = docs(self.service) except Exception as e: - warnings.warn(f'call external docs function: {external_docs} failed: {e}') + warnings.warn( + f"call external docs function: {external_docs} failed: {e}" + ) return [] if multi(docs): @@ -387,29 +396,32 @@ def get_external_docs(self, external_docs) -> List[OpenAPISchema]: docs = docs.decode() if file and file.filename and isinstance(docs, str): - if 
file.filename.endswith('.yaml') or file.filename.endswith('.yml'): + if file.filename.endswith(".yaml") or file.filename.endswith(".yml"): import yaml + docs = yaml.safe_load(docs) if isinstance(docs, dict): try: return [OpenAPISchema(docs)] except utype.exc.ParseError as e: - warnings.warn(f'parse external docs object failed: {e}') + warnings.warn(f"parse external docs object failed: {e}") return [] if isinstance(docs, str): if valid_url(docs): from urllib.request import urlopen from http.client import HTTPResponse + try: resp: HTTPResponse = urlopen(docs, timeout=self.URL_FETCH_TIMEOUT) except Exception as e: - warnings.warn(f'parse external docs url: {docs} failed: {e}') + warnings.warn(f"parse external docs url: {docs} failed: {e}") return [] if resp.status == 200: - content_type = resp.getheader('Content-Type') or '' - if 'yaml' in content_type or 'json' in content_type: + content_type = resp.getheader("Content-Type") or "" + if "yaml" in content_type or "json" in content_type: import yaml + obj = yaml.safe_load(resp.read()) else: obj = json.loads(resp.read()) @@ -418,12 +430,13 @@ def get_external_docs(self, external_docs) -> List[OpenAPISchema]: resp.close() elif os.path.exists(docs): try: - docs_content = open(docs, 'r', errors='ignore').read() + docs_content = open(docs, "r", errors="ignore").read() except Exception as e: - warnings.warn(f'parse external docs file: {docs} failed: {e}') + warnings.warn(f"parse external docs file: {docs} failed: {e}") return [] - if docs.endswith('.yaml') or docs.endswith('.yml'): + if docs.endswith(".yaml") or docs.endswith(".yml"): import yaml + obj = yaml.safe_load(docs_content) else: obj = json.loads(docs_content) @@ -434,15 +447,18 @@ def get_external_docs(self, external_docs) -> List[OpenAPISchema]: except json.JSONDecodeError: try: import yaml + obj = yaml.safe_load(docs) except Exception as e: - warnings.warn(f'parse external docs content failed with error: {e}') + warnings.warn( + f"parse external docs content failed with error: {e}" + ) return [] if obj: try: return [OpenAPISchema(obj)] except utype.exc.ParseError as e: - warnings.warn(f'parse external docs failed: {e}') + warnings.warn(f"parse external docs failed: {e}") return [] return [] @@ -453,51 +469,57 @@ def get_rel_paths(cls, paths: dict, current_base_url: str, base_url: str) -> dic # 1, base_url: http://127.0.0.1:8000 # 2, base_url: http://new.location.com/some/route # 3, base_url: http://new.location.com - if not current_base_url or not base_url or current_base_url == base_url or not paths: + if ( + not current_base_url + or not base_url + or current_base_url == base_url + or not paths + ): return paths - prefix = '' + prefix = "" prefix_strip = False # only support prefix if current_base_url.startswith(base_url): - prefix = current_base_url[len(base_url):] + prefix = current_base_url[len(base_url) :] elif base_url.startswith(current_base_url): - prefix = base_url[len(current_base_url):] + prefix = base_url[len(current_base_url) :] prefix_strip = True else: from urllib.parse import urlparse + current_parsed = urlparse(current_base_url) url_parsed = urlparse(base_url) if current_parsed.path.startswith(url_parsed.path): - prefix = current_parsed.path[len(url_parsed.path):] + prefix = current_parsed.path[len(url_parsed.path) :] elif url_parsed.path.startswith(current_parsed.path): - prefix = url_parsed.path[len(current_parsed.path):] + prefix = url_parsed.path[len(current_parsed.path) :] prefix_strip = True elif current_parsed.path: # todo: deal with this situation prefix = 
current_parsed.path - prefix = prefix.strip('/') + prefix = prefix.strip("/") if not prefix: return paths - prefix = '/' + prefix + prefix = "/" + prefix new_paths = {} for key, path in paths.items(): if prefix_strip: - key = '/' + str(key).lstrip('/') - if key == prefix or key.startswith(prefix + '/'): + key = "/" + str(key).lstrip("/") + if key == prefix or key.startswith(prefix + "/"): # prefix: /api # key: /api/articles -> /articles # /api/ ----------> / # /api -----------> / # /static --------> none - new_path = '/' + key[len(prefix):].lstrip('/') + new_path = "/" + key[len(prefix) :].lstrip("/") else: continue else: - if key.strip('/'): - new_path = prefix + '/' + str(key).lstrip('/') + if key.strip("/"): + new_path = prefix + "/" + str(key).lstrip("/") else: new_path = prefix new_paths[new_path] = path @@ -514,13 +536,21 @@ def __call__(self): adaptor_docs = OpenAPISchema(docs) if not adaptor_docs.servers: adaptor_docs.servers = [ - ServerSchema(url=url_join(get_origin(self.base_url), self.service.adaptor.root_path)) + ServerSchema( + url=url_join( + get_origin(self.base_url), + self.service.adaptor.root_path, + ) + ) ] except NotImplementedError: adaptor_docs = None except Exception as e: - warnings.warn(f'generate OpenAPI docs for [{self.service.backend_name}] failed: {e}') + warnings.warn( + f"generate OpenAPI docs for [{self.service.backend_name}] failed: {e}" + ) from utilmeta.utils import Error + err = Error(e) err.setup() print(err.full_info) @@ -537,7 +567,7 @@ def __call__(self): components=self.components, paths=paths, tags=list(self.tags.values()), - servers=[self.server] + servers=[self.server], ) docs = [utilmeta_docs] # even of no paths: some adaptor generate no server.url @@ -561,7 +591,7 @@ def components(self): return dict( schemas=self.get_defs(), responses=self.get_responses(), - securitySchemes=self.security_schemas + securitySchemes=self.security_schemas, ) @property @@ -574,12 +604,13 @@ def save(self, file: str): @classmethod def save_to(cls, schema, file: str): - if file.endswith('.yaml') or file.endswith('.yml'): + if file.endswith(".yaml") or file.endswith(".yml"): import yaml # requires pyyaml + content = yaml.dump(schema) else: content = json_dumps(schema, indent=4) - with open(file, mode='w', encoding='utf-8') as f: + with open(file, mode="w", encoding="utf-8") as f: f.write(content) if not os.path.isabs(file): @@ -615,17 +646,19 @@ def get(self): # _path = self.request.path if path: file_path = os.path.join(service.project_dir, path) - if path.endswith('.yml'): + if path.endswith(".yml"): import yaml # requires pyyaml + content = yaml.dump(_generated_document) else: content = json_dumps(_generated_document) - with open(file_path, 'w') as f: + with open(file_path, "w") as f: f.write(content) return content else: - if '.yaml' in self.request.path or '.yml' in self.request.path: + if ".yaml" in self.request.path or ".yml" in self.request.path: import yaml # requires pyyaml + content = yaml.dump(_generated_document) return content @@ -635,13 +668,13 @@ def get(self): @classmethod def _path_join(cls, *routes): - return '/' + '/'.join([str(r or '').strip('/') for r in routes]).rstrip('/') + return "/" + "/".join([str(r or "").strip("/") for r in routes]).rstrip("/") def generate_info(self) -> OpenAPIInfo: data = dict( title=self.service.title or self.service.name, - description=self.service.description or self.service.title or '', - version=self.service.version_str + description=self.service.description or self.service.title or "", + 
version=self.service.version_str, ) if self.service.info: data.update(self.service.info) @@ -650,7 +683,7 @@ def generate_info(self) -> OpenAPIInfo: def generate_paths(self): api = self.service.resolve() if not issubclass(api, API): - raise TypeError(f'Invalid root_api: {api}') + raise TypeError(f"Invalid root_api: {api}") # return self.from_api(api, path=self.service.root_url) return self.from_api(api) @@ -691,7 +724,7 @@ def get_response_name(self, response: Type[Response], names: list = ()): return k names = list(names) names.append(response.name or get_obj_name(response)) - return re.sub('[^A-Za-z0-9]+', '_', '_'.join(names)).strip('_') + return re.sub("[^A-Za-z0-9]+", "_", "_".join(names)).strip("_") def set_response(self, response: Type[Response], names: list = ()): name = self.get_response_name(response, names=names) @@ -709,20 +742,22 @@ def set_response(self, response: Type[Response], names: list = ()): # exact data return name # de-duplicate name - name += '_1' + name += "_1" self.responses[response] = data self.response_names[name] = response return name - def parse_properties(self, props: Dict[str, ParserProperty]) -> Tuple[list, dict, list]: + def parse_properties( + self, props: Dict[str, ParserProperty] + ) -> Tuple[list, dict, list]: params = [] media_types = {} body_params = {} body_form = False body_params_required = [] body_required = False - body_description = '' + body_description = "" auth_requirements = [] for key, prop_holder in props.items(): @@ -735,7 +770,7 @@ def parse_properties(self, props: Dict[str, ParserProperty]) -> Tuple[list, dict auth = None scope = [] if isinstance(prop, User): - scope = ['login'] + scope = ["login"] auth = prop.authentication if not prop.required: auth_requirements.append({}) @@ -763,29 +798,29 @@ def parse_properties(self, props: Dict[str, ParserProperty]) -> Tuple[list, dict else: _in = str(prop.__in__) - if _in == 'body': + if _in == "body": if field.is_required(generator.options): body_params_required.append(name) body_params[name] = field_schema if field_schema: - if field_schema.get('type') == 'array': - if field_schema.get('items', {}).get('format') == 'binary': + if field_schema.get("type") == "array": + if field_schema.get("items", {}).get("format") == "binary": body_form = True - elif field_schema.get('format') == 'binary': + elif field_schema.get("format") == "binary": body_form = True elif _in in self.PARAMS_IN: data = { - 'in': _in, - 'name': name, - 'required': field.required, + "in": _in, + "name": name, + "required": field.required, # prop may be injected - 'schema': field_schema, + "schema": field_schema, } if prop.description: - data['description'] = prop.description + data["description"] = prop.description if prop.deprecated: - data['deprecated'] = True + data["deprecated"] = True if isinstance(field.field, properties.RequestParam): if field.field.style: @@ -795,16 +830,14 @@ def parse_properties(self, props: Dict[str, ParserProperty]) -> Tuple[list, dict params.append(data) - elif prop.__ident__ == 'body': + elif prop.__ident__ == "body": schema = field_schema # treat differently - content_type = getattr(prop, 'content_type', None) + content_type = getattr(prop, "content_type", None) if not content_type: # guess content_type = generator.get_body_content_type(schema) or PLAIN - media_types[content_type] = { - 'schema': schema - } + media_types[content_type] = {"schema": schema} body_description = prop.description body_required = prop.required @@ -813,37 +846,41 @@ def parse_properties(self, props: Dict[str, 
ParserProperty]) -> Tuple[list, dict # should ex schema = field_schema prop_schema = generator.get_schema(schema) or {} - schema_type = prop_schema.get('type') + schema_type = prop_schema.get("type") - if not prop_schema or schema_type != 'object': - raise TypeError(f'Invalid object type: {field.type} for request property: ' - f'{repr(prop.__ident__)}, must be a object type, got {repr(schema_type)}') + if not prop_schema or schema_type != "object": + raise TypeError( + f"Invalid object type: {field.type} for request property: " + f"{repr(prop.__ident__)}, must be a object type, got {repr(schema_type)}" + ) - props = prop_schema.get('properties') or {} - required = prop_schema.get('required') or [] + props = prop_schema.get("properties") or {} + required = prop_schema.get("required") or [] for prop_name, value in props.items(): - params.append({ - 'in': prop.__ident__, - 'name': prop_name, - 'schema': value, - 'required': prop_name in required, - # 'style': 'form', - # 'explode': True - }) + params.append( + { + "in": prop.__ident__, + "name": prop_name, + "schema": value, + "required": prop_name in required, + # 'style': 'form', + # 'explode': True + } + ) if media_types: if body_params: generator = self.get_generator(None) for ct in list(media_types): - schema: dict = media_types[ct].get('schema') + schema: dict = media_types[ct].get("schema") if not schema: continue body_schema = dict(generator.get_schema(schema)) - body_props = body_schema.get('properties') or {} + body_props = body_schema.get("properties") or {} body_props.update(body_params) - body_schema['properties'] = body_props - media_types[ct]['schema'] = body_schema + body_schema["properties"] = body_props + media_types[ct]["schema"] = body_schema if body_form and ct != MULTIPART: media_types[MULTIPART] = media_types.pop(ct) @@ -853,43 +890,42 @@ def parse_properties(self, props: Dict[str, ParserProperty]) -> Tuple[list, dict content_type = MULTIPART if body_form else JSON media_types = { content_type: { - 'schema': { - 'type': 'object', - 'properties': body_params, - 'required': body_params_required + "schema": { + "type": "object", + "properties": body_params, + "required": body_params_required, } } } body = None if media_types: - body = dict( - content=media_types, - required=body_required - ) + body = dict(content=media_types, required=body_required) if body_description: body.update(description=body_description) return params, body, auth_requirements @property def default_status(self): - return str(self.pref.default_response_status or 'default') - - def from_endpoint(self, endpoint: Endpoint, - tags: list = (), - extra_params: list = None, - extra_body: dict = None, - response_cls: Type[Response] = None, - extra_responses: dict = None, - extra_requires: list = None - ) -> dict: + return str(self.pref.default_response_status or "default") + + def from_endpoint( + self, + endpoint: Endpoint, + tags: list = (), + extra_params: list = None, + extra_body: dict = None, + response_cls: Type[Response] = None, + extra_responses: dict = None, + extra_requires: list = None, + ) -> dict: # https://spec.openapis.org/oas/v3.1.0#operationObject operation_names = list(tags) + [endpoint.name] operation_id = endpoint.operation_id if not operation_id or operation_id in self.operations: - operation_id = '_'.join(operation_names) + operation_id = "_".join(operation_names) if operation_id in self.operations: - operation_id = endpoint.ref.replace('.', '_') + operation_id = endpoint.ref.replace(".", "_") self.operations.add(operation_id) # tags 
----- @@ -911,12 +947,16 @@ def from_endpoint(self, endpoint: Endpoint, for resp in endpoint.response_types: resp_name = self.set_response(resp, names=operation_names) - responses[str(resp.status or self.default_status)] = {'$ref': f'#/components/responses/{resp_name}'} + responses[str(resp.status or self.default_status)] = { + "$ref": f"#/components/responses/{resp_name}" + } if response_cls and response_cls != Response: resp_name = self.set_response(response_cls, names=operation_names) - responses.setdefault(str(response_cls.status or self.default_status), - {'$ref': f'#/components/responses/{resp_name}'}) + responses.setdefault( + str(response_cls.status or self.default_status), + {"$ref": f"#/components/responses/{resp_name}"}, + ) if extra_params: # _params = dict(extra_params) @@ -931,16 +971,16 @@ def from_endpoint(self, endpoint: Endpoint, operationId=operation_id, tags=self.add_tags(tags), responses=dict(sorted(responses.items())), - security=self.merge_requires(extra_requires, requires) + security=self.merge_requires(extra_requires, requires), ) if params: operation.update(parameters=params) if body and endpoint.method in HAS_BODY_METHODS: operation.update(requestBody=body) if endpoint.idempotent is not None: - operation.update({'x-idempotent': endpoint.idempotent}) + operation.update({"x-idempotent": endpoint.idempotent}) if endpoint.ref: - operation.update({'x-ref': endpoint.ref}) + operation.update({"x-ref": endpoint.ref}) extension = endpoint.openapi_extension if extension: operation.update(extension) @@ -957,39 +997,52 @@ def add_tags(self, tags: list): if isinstance(tag, str): tag_name = tag elif isinstance(tag, dict): - tag_name = tag.get('name') + tag_name = tag.get("name") if not tag_name: continue tag_names.append(tag_name) if tag_name not in self.tags: - self.tags[tag_name] = tag if isinstance(tag, dict) else {'name': tag_name} + self.tags[tag_name] = ( + tag if isinstance(tag, dict) else {"name": tag_name} + ) return tag_names - def from_route(self, route: APIRoute, - *routes: str, - tags: list = (), - params: list = None, - response_cls: Type[Response] = None, - responses: dict = None, - requires: list = None) -> dict: + def from_route( + self, + route: APIRoute, + *routes: str, + tags: list = (), + params: list = None, + response_cls: Type[Response] = None, + responses: dict = None, + requires: list = None, + ) -> dict: # https://spec.openapis.org/oas/v3.1.0#pathItemObject new_routes = [*routes, route.route] if route.route else list(routes) new_tags = [*tags, route.name] if route.name else list(tags) # route_tags = route.get_tags() path = self._path_join(*new_routes) - route_data = {k: v for k, v in dict( - summary=route.summary, - description=route.description, - deprecated=route.deprecated - ).items() if v is not None} + route_data = { + k: v + for k, v in dict( + summary=route.summary, + description=route.description, + deprecated=route.deprecated, + ).items() + if v is not None + } extra_body = None extra_params = [] extra_requires = [] - extra_responses = dict(responses or {}) # the deeper (close to the api response) is prior + extra_responses = dict( + responses or {} + ) # the deeper (close to the api response) is prior # before hooks for before in route.before_hooks: - prop_params, body, before_requires = self.parse_properties(before.wrapper.properties) + prop_params, body, before_requires = self.parse_properties( + before.wrapper.properties + ) if body and not extra_body: extra_body = body extra_params.extend(prop_params) @@ -998,13 +1051,17 @@ def 
from_route(self, route: APIRoute, for after in route.after_hooks: for rt in after.response_types: resp_name = self.set_response(rt, names=list(tags)) - extra_responses[str(rt.status or self.default_status)] = {'$ref': f'#/components/responses/{resp_name}'} + extra_responses[str(rt.status or self.default_status)] = { + "$ref": f"#/components/responses/{resp_name}" + } for error, hook in route.error_hooks.items(): for rt in hook.response_types: resp_name = self.set_response(rt, names=list(tags)) - status = rt.status or getattr(error, 'status', None) or 'default' - extra_responses.setdefault(str(status), {'$ref': f'#/components/responses/{resp_name}'}) + status = rt.status or getattr(error, "status", None) or "default" + extra_responses.setdefault( + str(status), {"$ref": f"#/components/responses/{resp_name}"} + ) # set default. because error hooks is not triggered by default path_data = {} @@ -1018,7 +1075,7 @@ def from_route(self, route: APIRoute, extra_body=extra_body, response_cls=response_cls, extra_responses=extra_responses, - extra_requires=extra_requires + extra_requires=extra_requires, ) # inject data in the endpoint, not the route with probably other endpoints endpoint_data.update(route_data) @@ -1036,12 +1093,13 @@ def from_route(self, route: APIRoute, common_params = list(params or []) common_params.extend(extra_params) core_data = self.from_api( - route.handler, *new_routes, + route.handler, + *new_routes, tags=new_tags, params=common_params, response_cls=response_cls, responses=extra_responses, - requires=requires + requires=requires, ) if core_data: core_data.update(route_data) @@ -1049,12 +1107,16 @@ def from_route(self, route: APIRoute, return path_data - def from_api(self, api: Type[API], *routes, - tags: list = (), - params: list = None, - response_cls: Type[Response] = None, - responses: dict = None, - requires: list = None) -> Optional[dict]: + def from_api( + self, + api: Type[API], + *routes, + tags: list = (), + params: list = None, + response_cls: Type[Response] = None, + responses: dict = None, + requires: list = None, + ) -> Optional[dict]: if api.__external__: # external APIs will not participate in docs return None @@ -1063,7 +1125,7 @@ def from_api(self, api: Type[API], *routes, prop_params, body, prop_requires = self.parse_properties(api._properties) extra_params.extend(prop_params) - api_response = getattr(api, 'response', None) + api_response = getattr(api, "response", None) if Response.is_cls(api_response) and api_response != Response: # set response self.set_response(api_response, names=list(tags)) @@ -1072,12 +1134,13 @@ def from_api(self, api: Type[API], *routes, if api_route.private: continue route_paths = self.from_route( - api_route, *routes, + api_route, + *routes, tags=tags, params=extra_params, response_cls=api_response or response_cls, responses=responses, - requires=self.merge_requires(requires, prop_requires) + requires=self.merge_requires(requires, prop_requires), ) if not api_route.route and api_route.method: # core api methods diff --git a/utilmeta/core/auth/basic.py b/utilmeta/core/auth/basic.py index 4c1e64d..a712910 100644 --- a/utilmeta/core/auth/basic.py +++ b/utilmeta/core/auth/basic.py @@ -4,10 +4,8 @@ class BasicAuth(BaseAuthentication): - name = 'basic' - headers = [ - 'authorization' - ] + name = "basic" + headers = ["authorization"] @classmethod def getter(cls, request: Request, field=None): @@ -15,16 +13,16 @@ def getter(cls, request: Request, field=None): if not token: return decoded = base64.decodebytes(token.encode()) - lst = 
decoded.decode().split(':') + lst = decoded.decode().split(":") if len(lst) > 1: - username, password = lst[0], ':'.join(lst[1:]) + username, password = lst[0], ":".join(lst[1:]) else: username, password = lst[0], None - return {'username': username, 'password': password} + return {"username": username, "password": password} def openapi_scheme(self) -> dict: return { - 'type': 'http', - 'scheme': 'basic', - 'description': self.description or '', + "type": "http", + "scheme": "basic", + "description": self.description or "", } diff --git a/utilmeta/core/auth/jwt.py b/utilmeta/core/auth/jwt.py index dac7087..dfa7738 100644 --- a/utilmeta/core/auth/jwt.py +++ b/utilmeta/core/auth/jwt.py @@ -7,13 +7,11 @@ class JsonWebToken(BaseAuthentication): - name = 'jwt' - jwt_var = var.RequestContextVar('_jwt_token') - headers = [ - 'authorization' - ] + name = "jwt" + jwt_var = var.RequestContextVar("_jwt_token") + headers = ["authorization"] - def getter(self, request: Request, field = None): + def getter(self, request: Request, field=None): token_type, token = request.authorization if not token: return {} @@ -21,6 +19,7 @@ def getter(self, request: Request, field = None): from jwt import JWT # noqa from jwt.exceptions import JWTDecodeError # noqa from jwt.jwk import OctetJWK # noqa + jwt = JWT() key = None if self.secret_key: @@ -29,31 +28,33 @@ def getter(self, request: Request, field = None): # jwt 1.7 import jwt # noqa from jwt.exceptions import DecodeError as JWTDecodeError # noqa + key = self.secret_key try: jwt_params = jwt.decode(token, key, self.algorithm) # noqa except JWTDecodeError: - raise exceptions.BadRequest(f'invalid jwt token') + raise exceptions.BadRequest(f"invalid jwt token") if self.audience: - aud = jwt_params.get('aud') + aud = jwt_params.get("aud") if aud != self.audience: - raise exceptions.PermissionDenied(f'Invalid audience: {repr(aud)}') + raise exceptions.PermissionDenied(f"Invalid audience: {repr(aud)}") return jwt_params - def __init__(self, - secret_key: Union[str, Any], - algorithm: str = 'HS256', - # jwk: Union[str, dict] = None, - # jwk json string / dict - # jwk file path - # jwk url - audience: str = None, - required: bool = False, - user_token_field: str = None - ): + def __init__( + self, + secret_key: Union[str, Any], + algorithm: str = "HS256", + # jwk: Union[str, dict] = None, + # jwk json string / dict + # jwk file path + # jwk url + audience: str = None, + required: bool = False, + user_token_field: str = None, + ): super().__init__(required=required) if not secret_key: - raise ValueError('Authentication config error: JWT secret key is required') + raise ValueError("Authentication config error: JWT secret key is required") self.algorithm = algorithm self.secret_key = secret_key # self.jwk = jwk @@ -62,30 +63,30 @@ def __init__(self, def apply_user_model(self, user_model: ModelAdaptor): if self.user_token_field and not isinstance(self.user_token_field, str): - self.user_token_field = user_model.field_adaptor_cls(self.user_token_field).name + self.user_token_field = user_model.field_adaptor_cls( + self.user_token_field + ).name - def login(self, request: Request, key: str = 'uid', expiry_age: int = None): + def login(self, request: Request, key: str = "uid", expiry_age: int = None): user = var.user.getter(request) if not user: return import time from utilmeta import service + iat = time.time() inv = expiry_age - token_dict = { - 'iat': iat, - 'iss': service.origin, - key: user.pk - } + token_dict = {"iat": iat, "iss": service.origin, key: user.pk} if 
self.audience: - token_dict['aud'] = self.audience + token_dict["aud"] = self.audience if inv: - token_dict['exp'] = iat + inv + token_dict["exp"] = iat + inv try: # python-jwt # pip install jwt from jwt import JWT # noqa from jwt.jwk import OctetJWK # noqa + jwt = JWT() jwt_key = None if self.secret_key: @@ -96,15 +97,19 @@ def login(self, request: Request, key: str = 'uid', expiry_age: int = None): # pip install pyjwt # jwt 1.7 import jwt # noqa + jwt_token = jwt.encode( # noqa - token_dict, self.secret_key, - algorithm=self.algorithm + token_dict, self.secret_key, algorithm=self.algorithm ) if isinstance(jwt_token, bytes): # jwt > 2.0 gives the str - jwt_token = jwt_token.decode('ascii') + jwt_token = jwt_token.decode("ascii") self.jwt_var.setter(request, jwt_token) - return {self.user_token_field: jwt_token} if isinstance(self.user_token_field, str) else None + return ( + {self.user_token_field: jwt_token} + if isinstance(self.user_token_field, str) + else None + ) # if conf.jwt_token_field: # setattr(user, conf.jwt_token_field, jwt_token) # user.save(update_fields=[conf.jwt_token_field]) @@ -113,8 +118,8 @@ def login(self, request: Request, key: str = 'uid', expiry_age: int = None): def openapi_scheme(self) -> dict: return dict( - type='http', - scheme='bearer', - description=self.description or '', - bearerFormat='JWT', + type="http", + scheme="bearer", + description=self.description or "", + bearerFormat="JWT", ) diff --git a/utilmeta/core/auth/oauth2.py b/utilmeta/core/auth/oauth2.py index 35af949..8251813 100644 --- a/utilmeta/core/auth/oauth2.py +++ b/utilmeta/core/auth/oauth2.py @@ -16,12 +16,8 @@ class Auth0JWTBearerTokenValidator(JWTBearerTokenValidator): def __init__(self, domain, audience): issuer = f"https://{domain}/" jsonurl = urlopen(f"{issuer}.well-known/jwks.json") - public_key = JsonWebKey.import_key_set( - json.loads(jsonurl.read()) - ) - super(Auth0JWTBearerTokenValidator, self).__init__( - public_key - ) + public_key = JsonWebKey.import_key_set(json.loads(jsonurl.read())) + super(Auth0JWTBearerTokenValidator, self).__init__(public_key) self.claims_options = { "exp": {"essential": True}, "aud": {"essential": True, "value": audience}, @@ -30,11 +26,9 @@ def __init__(self, domain, audience): class OAuth2(BaseAuthentication): - name = 'oauth2' + name = "oauth2" protector_cls = ResourceProtector - headers = [ - 'authorization' - ] + headers = ["authorization"] def __init__(self, *validators: BearerTokenValidator, scopes_key=None): self.protector = self.protector_cls() @@ -54,7 +48,7 @@ def acquire_token(self, request: Request, scopes=None): method=request.method, uri=request.url, data=request.body, - headers=request.headers + headers=request.headers, ) req.req = request if isinstance(scopes, str): @@ -98,10 +92,10 @@ def openapi_scheme(self) -> dict: # bearerFormat='JWT', # ) return { - 'type': 'oauth2', - 'flows': { + "type": "oauth2", + "flows": { # todo # https://spec.openapis.org/oas/v3.1.0#oauthFlowObject }, - 'description': self.description or '', + "description": self.description or "", } diff --git a/utilmeta/core/auth/plugins/require.py b/utilmeta/core/auth/plugins/require.py index 67a8d3c..a746b3c 100644 --- a/utilmeta/core/auth/plugins/require.py +++ b/utilmeta/core/auth/plugins/require.py @@ -14,12 +14,12 @@ class AuthValidatorPlugin(APIPlugin): user_var = var.user user_id_var = var.user_id scopes_var = var.scopes - __all = '*' + __all = "*" @staticmethod def login(user): if not user: - raise exceptions.Unauthorized('login required') + raise 
exceptions.Unauthorized("login required") return True def __init__(self, *scope_or_fns, name: str = None): @@ -28,7 +28,7 @@ def __init__(self, *scope_or_fns, name: str = None): super().__init__(scopes=self.scopes, functions=self.functions, name=name) if not scope_or_fns: self.functions = [self.login] - name = name or 'login' + name = name or "login" self.name = name # self.readonly = readonly # self.login = login @@ -60,10 +60,10 @@ def validate_scopes(self, request: Request): return if not set(scopes or []).issuperset(self.scopes): raise exceptions.PermissionDenied( - 'insufficient scope', + "insufficient scope", scope=scopes, required_scope=self.scopes, - name=self.name + name=self.name, ) def validate_functions(self, request: Request): @@ -75,8 +75,7 @@ def validate_functions(self, request: Request): v = func(user) if not v: raise exceptions.PermissionDenied( - f'{self.name or "permission"} required', - name=self.name + f'{self.name or "permission"} required', name=self.name ) async def async_validate_functions(self, request: Request): @@ -90,6 +89,5 @@ async def async_validate_functions(self, request: Request): v = await v if not v: raise exceptions.PermissionDenied( - f'{self.name or "permission"} required', - name=self.name + f'{self.name or "permission"} required', name=self.name ) diff --git a/utilmeta/core/auth/properties.py b/utilmeta/core/auth/properties.py index beb391b..0c32db6 100644 --- a/utilmeta/core/auth/properties.py +++ b/utilmeta/core/auth/properties.py @@ -45,7 +45,9 @@ def get_user(self, request: Request): if inst is not None: # user.set(inst) if self.scopes_field: - self.scopes_context_var.setter(request, getattr(inst, self.scopes_field, [])) + self.scopes_context_var.setter( + request, getattr(inst, self.scopes_field, []) + ) self.context_var.setter(request, inst) return inst return None @@ -58,7 +60,9 @@ async def get_user(self, request: Request): if inst is not None: # user.set(inst) if self.scopes_field: - self.scopes_context_var.setter(request, getattr(inst, self.scopes_field, [])) + self.scopes_context_var.setter( + request, getattr(inst, self.scopes_field, []) + ) self.context_var.setter(request, inst) return inst return None @@ -95,12 +99,14 @@ async def getter(self, request: Request, field: ParserField = None): raise exc.Unauthorized return unprovided return user + # ------------------------------------------------------- def init(self, field: ParserField): if not self.user_model: if field.type and not isinstance(None, field.type): from utilmeta.core.orm.backends.base import ModelAdaptor + self.user_model = ModelAdaptor.dispatch(field.type) self.prepare_fields() return super().init(field) @@ -112,29 +118,34 @@ def init(self, field: ParserField): # self.parser_field = field # TODO: validate fields existent in user model - def __init__(self, - user_model=None, *, - authentication: BaseAuthentication, - key: str = '_user_id', - field: str = 'id', - scopes_field=None, - login_fields=None, - login_time_field=None, - login_ip_field=None, - password_field=None, - default=unprovided, - required: bool = None, - # context var - context_var=None, - id_context_var=None, - scopes_context_var=None, - ): + def __init__( + self, + user_model=None, + *, + authentication: BaseAuthentication, + key: str = "_user_id", + field: str = "id", + scopes_field=None, + login_fields=None, + login_time_field=None, + login_ip_field=None, + password_field=None, + default=unprovided, + required: bool = None, + # context var + context_var=None, + id_context_var=None, + 
scopes_context_var=None, + ): super().__init__(default=default, required=required) if not isinstance(authentication, BaseAuthentication): - raise TypeError(f'Invalid authentication, must be instance of BaseAuthentication subclasses') + raise TypeError( + f"Invalid authentication, must be instance of BaseAuthentication subclasses" + ) from utilmeta.core.orm import ModelAdaptor + self.user_model: ModelAdaptor = ModelAdaptor.dispatch(user_model) self.authentication = authentication @@ -211,6 +222,7 @@ def query_login_user(self, ident: str): if len(self.login_fields) == 1: return self.query_user(**{self.login_fields[0]: ident}) from utilmeta.core.orm.backends.django.expressions import Q + q = Q() for f in self.login_fields: q |= Q(**{f: ident}) @@ -224,6 +236,7 @@ async def aquery_login_user(self, ident: str): if len(self.login_fields) == 1: return await self.aquery_user(**{self.login_fields[0]: ident}) from utilmeta.core.orm.backends.django.expressions import Q + q = Q() for f in self.login_fields: q |= Q(**{f: ident}) @@ -231,7 +244,9 @@ async def aquery_login_user(self, ident: str): else: return None - def login(self, request: Request, ident: str, password: str, expiry_age: int = None) -> Optional[Any]: + def login( + self, request: Request, ident: str, password: str, expiry_age: int = None + ) -> Optional[Any]: user = self.query_login_user(ident) if not user: return None @@ -242,7 +257,9 @@ def login(self, request: Request, ident: str, password: str, expiry_age: int = N return user # @awaitable(login) - async def alogin(self, request: Request, ident: str, password: str, expiry_age: int = None) -> Optional[Any]: + async def alogin( + self, request: Request, ident: str, password: str, expiry_age: int = None + ) -> Optional[Any]: user = await self.aquery_login_user(ident) if not user: return None @@ -252,9 +269,17 @@ async def alogin(self, request: Request, ident: str, password: str, expiry_age: await self.alogin_user(request, user, expiry_age=expiry_age) return user - def login_user(self, request: Request, user, expiry_age: int = None, ignore_updates: bool = False) -> None: + def login_user( + self, + request: Request, + user, + expiry_age: int = None, + ignore_updates: bool = False, + ) -> None: self.context_var.setter(request, user) - self.id_context_var.setter(request, getattr(user, 'pk', None) or getattr(user, 'id', None)) + self.id_context_var.setter( + request, getattr(user, "pk", None) or getattr(user, "id", None) + ) try: data = self.authentication.login(request, self.key, expiry_age) except NotImplementedError: @@ -265,9 +290,17 @@ def login_user(self, request: Request, user, expiry_age: int = None, ignore_upda self.update_fields(request, user, data) # @awaitable(login_user) - async def alogin_user(self, request: Request, user, expiry_age: int = None, ignore_updates: bool = False) -> None: + async def alogin_user( + self, + request: Request, + user, + expiry_age: int = None, + ignore_updates: bool = False, + ) -> None: self.context_var.setter(request, user) - self.id_context_var.setter(request, getattr(user, 'pk', None) or getattr(user, 'id', None)) + self.id_context_var.setter( + request, getattr(user, "pk", None) or getattr(user, "id", None) + ) try: data = self.authentication.login(request, self.key, expiry_age) if inspect.isawaitable(data): @@ -302,4 +335,5 @@ async def aupdate_fields(self, request: Request, user, data=None) -> None: def check_password(cls, password: str, encoded: str): # you can override this method from django.contrib.auth.hashers import check_password + 
return check_password(password, encoded) diff --git a/utilmeta/core/auth/session/base.py b/utilmeta/core/auth/session/base.py index df0b3a6..7796918 100644 --- a/utilmeta/core/auth/session/base.py +++ b/utilmeta/core/auth/session/base.py @@ -11,13 +11,11 @@ class BaseSession(BaseAuthentication): __private__ = True - name = 'session' + name = "session" Cookie = Cookie - DEFAULT_CONTEXT_VAR = var.RequestContextVar('_session', cached=True) + DEFAULT_CONTEXT_VAR = var.RequestContextVar("_session", cached=True) DEFAULT_ENGINE = None - headers = [ - 'cookie' - ] + headers = ["cookie"] def get_session(self, request: Request, engine=None): session_key = request.cookies.get(self.cookie_name) @@ -49,7 +47,7 @@ def get_engine(self, field: ParserField): if isinstance(None, engine) or not engine: engine = self.engine if not callable(engine): - raise TypeError('No engine specified') + raise TypeError("No engine specified") return engine def init(self, field: ParserField): @@ -86,16 +84,17 @@ async def process_response(self, response: Response, api=None): return SessionPlugin(_session_self) - def __init__(self, - engine: Union[str, type] = None, - expire_at_browser_close: bool = False, - save_every_request: bool = False, - cycle_key_at_login: bool = True, - allow_localhost: bool = False, - interrupted: Literal['override', 'cycle', 'error'] = 'override', - cookie: Cookie = Cookie(http_only=True), - context_var=None, - ): + def __init__( + self, + engine: Union[str, type] = None, + expire_at_browser_close: bool = False, + save_every_request: bool = False, + cycle_key_at_login: bool = True, + allow_localhost: bool = False, + interrupted: Literal["override", "cycle", "error"] = "override", + cookie: Cookie = Cookie(http_only=True), + context_var=None, + ): super().__init__() assert isinstance(cookie, Cookie) if isinstance(engine, str): @@ -103,9 +102,9 @@ def __init__(self, self.engine = engine or self.DEFAULT_ENGINE # self.engine = import_obj(engine) if isinstance(engine, str) else engine self.cookie = cookie - self.cookie_name = cookie.name or 'sessionid' + self.cookie_name = cookie.name or "sessionid" if not self.cookie.http_only: - warnings.warn(f'Session using cookie should turn http_only=True') + warnings.warn(f"Session using cookie should turn http_only=True") self.context_var = context_var or self.DEFAULT_CONTEXT_VAR # self.cache_alias = cache_alias @@ -180,7 +179,13 @@ def save_session(self, response: Response, session): async def save_session(self, response: Response, session): raise NotImplementedError - def _set_cookie(self, response: Response, session_key: str, max_age: int = None, expires: str = None): + def _set_cookie( + self, + response: Response, + session_key: str, + max_age: int = None, + expires: str = None, + ): cookie_domain = self.cookie.domain secure = self.cookie.secure or None same_site = self.cookie.same_site @@ -189,7 +194,7 @@ def _set_cookie(self, response: Response, session_key: str, max_age: int = None, secure = None cookie_domain = None if not localhost(response.request.host): - same_site = 'None' + same_site = "None" secure = True # secure is required to use SameSite=None response.set_cookie( @@ -206,8 +211,8 @@ def _set_cookie(self, response: Response, session_key: str, max_age: int = None, def openapi_scheme(self) -> dict: return { - 'type': 'apiKey', - 'name': self.cookie_name, - 'in': 'cookie', - 'description': self.description or '', + "type": "apiKey", + "name": self.cookie_name, + "in": "cookie", + "description": self.description or "", } diff --git 
a/utilmeta/core/auth/session/cache.py b/utilmeta/core/auth/session/cache.py index af0938f..b2b7109 100644 --- a/utilmeta/core/auth/session/cache.py +++ b/utilmeta/core/auth/session/cache.py @@ -1,5 +1,11 @@ -from .schema import BaseSessionSchema, SchemaSession, SessionCreateError, SessionUpdateError +from .schema import ( + BaseSessionSchema, + SchemaSession, + SessionCreateError, + SessionUpdateError, +) from utilmeta.core.cache import CacheConnections, Cache + # from utilmeta.utils import awaitable from typing import Type from utype import Field @@ -7,7 +13,7 @@ class CacheSessionSchema(BaseSessionSchema): __connections_cls__ = CacheConnections - _config: 'CacheSession' + _config: "CacheSession" def get_cache(self) -> Cache: return self.__connections_cls__.get(self._config.cache_alias) @@ -15,10 +21,10 @@ def get_cache(self) -> Cache: @property @Field(no_output=True) def cache_key_prefix(self): - common_prefix = f'{CacheSessionSchema.__module__}' + common_prefix = f"{CacheSessionSchema.__module__}" # not self.__class__, cause this class may be inherited by different schema classes - key_prefix = self._config.key_prefix or '' - return f'{common_prefix}{key_prefix}' + key_prefix = self._config.key_prefix or "" + return f"{common_prefix}{key_prefix}" @property @Field(no_output=True) @@ -67,10 +73,10 @@ def save(self, must_create: bool = False): cache = self.get_cache() if not must_create: - if self._config.interrupted != 'override': + if self._config.interrupted != "override": if cache.get(self.get_key()) is None: # old session data is deleted - if self._config.interrupted == 'cycle': + if self._config.interrupted == "cycle": self._session_key = self._get_new_session_key() else: raise SessionUpdateError @@ -91,10 +97,10 @@ async def asave(self, must_create: bool = False): cache = self.get_cache() if not must_create: - if self._config.interrupted != 'override': + if self._config.interrupted != "override": if await cache.aget(self.get_key()) is None: # old session data is deleted - if self._config.interrupted == 'cycle': + if self._config.interrupted == "cycle": self._session_key = await self._aget_new_session_key() else: raise SessionUpdateError @@ -180,7 +186,13 @@ class CacheSession(SchemaSession): DEFAULT_ENGINE = CacheSessionSchema schema = CacheSessionSchema - def __init__(self, engine=None, cache_alias: str = 'default', key_prefix: str = None, **kwargs): + def __init__( + self, + engine=None, + cache_alias: str = "default", + key_prefix: str = None, + **kwargs, + ): super().__init__(engine, **kwargs) self.cache_alias = cache_alias self.key_prefix = key_prefix @@ -188,6 +200,8 @@ def __init__(self, engine=None, cache_alias: str = 'default', key_prefix: str = def init(self, field): # check cache exists if not CacheConnections.get(self.cache_alias): - raise ValueError(f'{self.__class__.__name__}: cache_alias ({repr(self.cache_alias)}) ' - f'not defined in {CacheConnections}') + raise ValueError( + f"{self.__class__.__name__}: cache_alias ({repr(self.cache_alias)}) " + f"not defined in {CacheConnections}" + ) return super().init(field) diff --git a/utilmeta/core/auth/session/cached_db.py b/utilmeta/core/auth/session/cached_db.py index 9abff31..c73abd6 100644 --- a/utilmeta/core/auth/session/cached_db.py +++ b/utilmeta/core/auth/session/cached_db.py @@ -3,7 +3,7 @@ from .db import DBSessionSchema from utilmeta.core.orm import ModelAdaptor -__all__ = ['CachedDBSessionSchema', 'CachedDBSession'] +__all__ = ["CachedDBSessionSchema", "CachedDBSession"] class 
CachedDBSessionSchema(CacheSessionSchema, DBSessionSchema): @@ -30,7 +30,7 @@ def save(self, must_create: bool = False): try: super().save(must_create) except Exception as e: - print(f'Save with error: {e}') + print(f"Save with error: {e}") # ignore cache failed async def asave(self, must_create: bool = False): @@ -40,7 +40,7 @@ async def asave(self, must_create: bool = False): try: await super().asave(must_create) except Exception as e: - print(f'Save with error: {e}') + print(f"Save with error: {e}") # ignore cache failed def delete(self, session_key: str = None): @@ -49,7 +49,7 @@ def delete(self, session_key: str = None): try: super().delete(session_key) except Exception as e: - print(f'Delete with error: {e}') + print(f"Delete with error: {e}") # ignore cache failed # @awaitable(delete) @@ -59,7 +59,7 @@ async def adelete(self, session_key: str = None): try: await super().adelete(session_key) except Exception as e: - print(f'Delete with error: {e}') + print(f"Delete with error: {e}") # ignore cache failed def load_data(self): @@ -67,8 +67,7 @@ def load_data(self): return None # to be inherited session = self._model_cls.get_instance( - session_key=self._session_key, - deleted_time=None + session_key=self._session_key, deleted_time=None ) if session: try: @@ -78,7 +77,7 @@ def load_data(self): timeout=self.timeout, ) except Exception as e: - print(f'Sync to cache failed: {e}') + print(f"Sync to cache failed: {e}") # ignore return self.decode(session.encoded_data) return None @@ -89,8 +88,7 @@ async def aload_data(self): return None # to be inherited session = await self._model_cls.aget_instance( - session_key=self._session_key, - deleted_time=None + session_key=self._session_key, deleted_time=None ) if session: try: @@ -100,7 +98,7 @@ async def aload_data(self): timeout=self.timeout, ) except Exception as e: - print(f'Sync to cache failed: {e}') + print(f"Sync to cache failed: {e}") # ignore return self.decode(session.encoded_data) return None diff --git a/utilmeta/core/auth/session/db.py b/utilmeta/core/auth/session/db.py index 768277b..c4c9b0c 100644 --- a/utilmeta/core/auth/session/db.py +++ b/utilmeta/core/auth/session/db.py @@ -12,7 +12,7 @@ class DBSessionSchema(BaseSessionSchema): - _config: 'DBSession' + _config: "DBSession" # SESSION_ID_KEY: ClassVar = '$session_id' # CLIENT_IP_KEY: ClassVar = '$client_ip' # CLIENT_UA_KEY: ClassVar = '$client_ua' @@ -70,7 +70,7 @@ def get_session_data(self): encoded_data=self.encode(dict(self)), expiry_time=time_now() + timedelta(seconds=self.timeout), last_activity=time_now(), - created_time=self._request.time if self._request else time_now() + created_time=self._request.time if self._request else time_now(), ) def load_object(self, must_create: bool = False): @@ -83,10 +83,7 @@ def load_object(self, must_create: bool = False): elif not must_create: obj = self._model_cls.get_instance(session_key=self.session_key) session_id = obj.pk if obj else None - return self._model_cls.init_instance( - id=session_id, - **self.get_session_data() - ) + return self._model_cls.init_instance(id=session_id, **self.get_session_data()) async def aload_object(self, must_create: bool = False): """ @@ -101,10 +98,7 @@ async def aload_object(self, must_create: bool = False): data = self.get_session_data() if inspect.isawaitable(data): data = await data - return self._model_cls.init_instance( - id=session_id, - **data - ) + return self._model_cls.init_instance(id=session_id, **data) def db_save(self, must_create=False, force: bool = True): if self.session_key is None: 
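For orientation, a minimal wiring sketch of the DB-backed session configured by the classes above; this is assumed usage, not taken from this patch: `Session` is a hypothetical model exposing the AbstractSession fields that DBSessionSchema reads and writes (session_key, encoded_data, expiry_time, last_activity, deleted_time), and `interrupted` is the BaseSession option declared earlier in this patch.

    # a minimal sketch, assuming a hypothetical Session model implementing AbstractSession
    from utilmeta.core.auth.session.db import DBSession

    session_config = DBSession(
        session_model=Session,  # hypothetical model; must carry session_key,
                                # encoded_data, expiry_time, last_activity, deleted_time
        interrupted="cycle",    # BaseSession option: re-key the session instead of
                                # raising when the stored data has been deleted
    )
    # Note the soft-delete pattern used above: db_delete()/adb_delete() only set
    # deleted_time, so load()/aload() filter on deleted_time=None instead of
    # removing rows.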
@@ -121,7 +115,9 @@ async def adb_save(self, must_create=False, force: bool = True): obj = await self.aload_object(must_create) if not obj.pk: must_create = True - await obj.asave(force_insert=must_create, force_update=not must_create and force) + await obj.asave( + force_insert=must_create, force_update=not must_create and force + ) def save(self, must_create: bool = False): return self.db_save(must_create) @@ -135,10 +131,9 @@ def db_delete(self, session_key=None): if self.session_key is None: return session_key = self.session_key - self._model_cls.get_queryset( - session_key=session_key, - deleted_time=None - ).update(deleted_time=time_now()) + self._model_cls.get_queryset(session_key=session_key, deleted_time=None).update( + deleted_time=time_now() + ) # @awaitable(db_delete) async def adb_delete(self, session_key=None): @@ -147,8 +142,7 @@ async def adb_delete(self, session_key=None): return session_key = self.session_key await self._model_cls.get_queryset( - session_key=session_key, - deleted_time=None + session_key=session_key, deleted_time=None ).aupdate(deleted_time=time_now()) def delete(self, session_key: str = None): @@ -175,8 +169,7 @@ async def aload(self): return None # to be inherited session = await self._model_cls.aget_instance( - session_key=self._session_key, - deleted_time=None + session_key=self._session_key, deleted_time=None ) if session: return self.decode(session.encoded_data) @@ -187,6 +180,6 @@ class DBSession(SchemaSession): DEFAULT_ENGINE = DBSessionSchema schema = DBSessionSchema - def __init__(self, session_model: Type['AbstractSession'], **kwargs): + def __init__(self, session_model: Type["AbstractSession"], **kwargs): super().__init__(**kwargs) self.session_model = ModelAdaptor.dispatch(session_model) diff --git a/utilmeta/core/auth/session/django.py b/utilmeta/core/auth/session/django.py index c74fe8c..a8624ea 100644 --- a/utilmeta/core/auth/session/django.py +++ b/utilmeta/core/auth/session/django.py @@ -10,13 +10,14 @@ class DjangoSession(BaseSession): - def __init__(self, - engine: Union[str, Callable] = None, - cache_alias: str = 'default', - file_path: str = None, - serializer: str = None, - **kwargs - ): + def __init__( + self, + engine: Union[str, Callable] = None, + cache_alias: str = "default", + file_path: str = None, + serializer: str = None, + **kwargs, + ): super().__init__(engine, **kwargs) self.cache_alias = cache_alias self.file_path = file_path @@ -31,26 +32,27 @@ def get_engine(self, field: ParserField): if isinstance(None, engine): engine = self.engine if not issubclass(engine, SessionBase): - raise TypeError(f'Invalid django engine: {engine}') + raise TypeError(f"Invalid django engine: {engine}") return engine def init(self, field: ParserField): from utilmeta.core.server.backends.django import DjangoSettings + dj_settings = DjangoSettings.config() if dj_settings: dj_settings.register(self) else: - warnings.warn('No DjangoSettings is used in service') + warnings.warn("No DjangoSettings is used in service") super().init(field) def as_django(self): return { - 'SESSION_CACHE_ALIAS': self.cache_alias, - 'SESSION_SERIALIZER': self.serializer, - 'SESSION_ENGINE': self.engine, - 'SESSION_EXPIRE_AT_BROWSER_CLOSE': self.expire_at_browser_close, - 'SESSION_FILE_PATH': self.file_path, - **self.cookie.as_django(prefix='SESSION') + "SESSION_CACHE_ALIAS": self.cache_alias, + "SESSION_SERIALIZER": self.serializer, + "SESSION_ENGINE": self.engine, + "SESSION_EXPIRE_AT_BROWSER_CLOSE": self.expire_at_browser_close, + "SESSION_FILE_PATH": self.file_path, 
+ **self.cookie.as_django(prefix="SESSION"), } def login(self, request, key: str, expiry_age: int = None, user_id_var=var.user_id): diff --git a/utilmeta/core/auth/session/schema.py b/utilmeta/core/auth/session/schema.py index 21a7780..8516c24 100644 --- a/utilmeta/core/auth/session/schema.py +++ b/utilmeta/core/auth/session/schema.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from utilmeta.core.request import Request -T = TypeVar('T') +T = TypeVar("T") EPOCH = datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -26,6 +26,7 @@ class SessionCreateError(SessionError): Used internally as a consistent exception type to catch from save (see the docstring for SessionBase.save() for details). """ + pass @@ -33,6 +34,7 @@ class SessionUpdateError(SessionError): """ Occurs if Django tries to update a session that was deleted. """ + pass @@ -42,27 +44,27 @@ class BaseSessionSchema(Schema): 1. support both sync and async session 2. use schema to define session fields """ + __options__ = Options(addition=True, ignore_required=True) _serializer_cls: ClassVar = JSONSerializer - _config: 'SchemaSession' + _config: "SchemaSession" _session_key = None - _request: 'Request' = None + _request: "Request" = None _modified = False # inner fields expiry: Optional[datetime] = Field( - required=False, default=None, - defer_default=True, alias='_session_expiry' + required=False, default=None, defer_default=True, alias="_session_expiry" ) - key_salt: ClassVar[str] = 'utilmeta.core.auth.session.schema' + key_salt: ClassVar[str] = "utilmeta.core.auth.session.schema" # not based __class__, because it may change for different APIs @classmethod - def init_from(cls: Type[T], session_key: str, config: 'SchemaSession') -> T: + def init_from(cls: Type[T], session_key: str, config: "SchemaSession") -> T: self = cls.__new__(cls) if not isinstance(config, SchemaSession): - raise TypeError(f'Invalid session config: {config}') + raise TypeError(f"Invalid session config: {config}") self._config = config if session_key: if not isinstance(session_key, str): @@ -77,10 +79,10 @@ def init_from(cls: Type[T], session_key: str, config: 'SchemaSession') -> T: return self @classmethod - async def ainit_from(cls: Type[T], session_key: str, config: 'SchemaSession') -> T: + async def ainit_from(cls: Type[T], session_key: str, config: "SchemaSession") -> T: self = cls.__new__(cls) if not isinstance(config, SchemaSession): - raise TypeError(f'Invalid session config: {config}') + raise TypeError(f"Invalid session config: {config}") self._config = config if session_key: if not isinstance(session_key, str): @@ -95,10 +97,10 @@ async def ainit_from(cls: Type[T], session_key: str, config: 'SchemaSession') -> return self @classmethod - def init(cls: Type[T], request: Request, config: 'SchemaSession') -> T: + def init(cls: Type[T], request: Request, config: "SchemaSession") -> T: # for developer to directly call if not isinstance(config, SchemaSession): - raise TypeError(f'Invalid session config: {config}') + raise TypeError(f"Invalid session config: {config}") cvar = config.context_var.setup(request) if cvar.contains(): data: BaseSessionSchema = cvar.get() @@ -121,9 +123,9 @@ def init(cls: Type[T], request: Request, config: 'SchemaSession') -> T: return session @classmethod - async def ainit(cls: Type[T], request: Request, config: 'SchemaSession') -> T: + async def ainit(cls: Type[T], request: Request, config: "SchemaSession") -> T: if not isinstance(config, SchemaSession): - raise TypeError(f'Invalid session config: {config}') + raise TypeError(f"Invalid 
session config: {config}") cvar = config.context_var.setup(request) if cvar.contains(): data: BaseSessionSchema = await cvar.get() @@ -192,7 +194,7 @@ def pop(self, key, default=unprovided): args = () if unprovided(default) else (default,) return super().pop(key, *args) - def setdefault(self, key, default): # noqa + def setdefault(self, key, default): # noqa if key in self: return self[key] else: @@ -202,19 +204,25 @@ def setdefault(self, key, default): # noqa def encode(self, session_dict): from django.core import signing + return signing.dumps( - session_dict, salt=self.key_salt, serializer=self._serializer_cls, + session_dict, + salt=self.key_salt, + serializer=self._serializer_cls, compress=True, ) def decode(self, session_data): from django.core import signing + try: - return signing.loads(session_data, salt=self.key_salt, serializer=self._serializer_cls) + return signing.loads( + session_data, salt=self.key_salt, serializer=self._serializer_cls + ) except Exception as e: # ValueError, unpickling exceptions. If any of these happen, just # return an empty dictionary (an empty session). - warnings.warn(f'Session data corrupted: {str(e)}') + warnings.warn(f"Session data corrupted: {str(e)}") return {} def update(self, __m=None, **kwargs): @@ -289,7 +297,7 @@ def expiry_age(self) -> int: arguments specifying the modification and expiry of the session. """ expiry = self.expiry - if not expiry: # Checks both None and 0 cases + if not expiry: # Checks both None and 0 cases return self._config.cookie.age return max(0, int((expiry - time_now(expiry)).total_seconds())) @@ -299,14 +307,18 @@ def exists(self, session_key): """ Return True if the given session_key already exists. """ - raise NotImplementedError('subclasses of SessionBase must provide an exists() method') + raise NotImplementedError( + "subclasses of SessionBase must provide an exists() method" + ) # @awaitable(exists) async def aexists(self, session_key): """ Return True if the given session_key already exists. """ - raise NotImplementedError('subclasses of SessionBase must provide an aexists() method') + raise NotImplementedError( + "subclasses of SessionBase must provide an aexists() method" + ) def create(self): """ @@ -314,7 +326,9 @@ def create(self): a unique key and will have saved the result once (with empty data) before the method returns. """ - raise NotImplementedError('subclasses of SessionBase must provide a create() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a create() method" + ) def acreate(self): """ @@ -322,7 +336,9 @@ def acreate(self): a unique key and will have saved the result once (with empty data) before the method returns. """ - raise NotImplementedError('subclasses of SessionBase must provide a acreate() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a acreate() method" + ) def save(self, must_create=False): """ @@ -330,32 +346,44 @@ def save(self, must_create=False): object (or raise CreateError). Otherwise, only update an existing object and don't create one (raise UpdateError if needed). 
""" - raise NotImplementedError('subclasses of SessionBase must provide a save() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a save() method" + ) # @awaitable(save) async def asave(self, must_create=False): - raise NotImplementedError('subclasses of SessionBase must provide a asave() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a asave() method" + ) def delete(self, session_key=None): """ Delete the session data under this key. If the key is None, use the current session key value. """ - raise NotImplementedError('subclasses of SessionBase must provide a delete() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a delete() method" + ) # @awaitable(delete) async def adelete(self, session_key=None): - raise NotImplementedError('subclasses of SessionBase must provide a adelete() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a adelete() method" + ) def load(self): """ Load the session data and return a dictionary. """ - raise NotImplementedError('subclasses of SessionBase must provide a load() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a load() method" + ) # @awaitable(load) async def aload(self): - raise NotImplementedError('subclasses of SessionBase must provide a aload() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a aload() method" + ) class SchemaSession(BaseSession): @@ -367,7 +395,9 @@ def __init__(self, engine=None, **kwargs): super().__init__(engine=engine, **kwargs) @self - class schema(self.engine or self.DEFAULT_ENGINE): pass + class schema(self.engine or self.DEFAULT_ENGINE): + pass + schema._config = self self.schema = schema @@ -380,7 +410,7 @@ def get_engine(self, field): if isinstance(None, engine): engine = self.engine if not issubclass(engine, self.DEFAULT_ENGINE): - raise TypeError(f'Invalid SchemaSession engine: {engine}') + raise TypeError(f"Invalid SchemaSession engine: {engine}") return engine def get_session(self, request: Request, engine=None): @@ -469,7 +499,9 @@ def login(self, request, key: str, expiry_age: int = None, user_id_var=var.user_ session.expiry = time_now() + timedelta(seconds=expiry_age) @awaitable(login) - async def login(self, request, key: str, expiry_age: int = None, user_id_var=var.user_id): + async def login( + self, request, key: str, expiry_age: int = None, user_id_var=var.user_id + ): new_user_id = await user_id_var.getter(request) if new_user_id is None: return @@ -500,8 +532,11 @@ def save_session(self, response: Response, session: BaseSessionSchema): if (session.modified or self.save_every_request) and not session.is_empty: if response.status != 500: expiry = session.expiry - expire_at_browser_close = self.expire_at_browser_close if expiry is None \ + expire_at_browser_close = ( + self.expire_at_browser_close + if expiry is None else expiry.timestamp() == 0 + ) if expire_at_browser_close: max_age = None expires = None @@ -524,8 +559,11 @@ async def save_session(self, response: Response, session: BaseSessionSchema): if (session.modified or self.save_every_request) and not session.is_empty: if response.status != 500: expiry = session.expiry - expire_at_browser_close = self.expire_at_browser_close if expiry is None \ - else str(expiry).startswith('1970-01-01 00:00:00') + expire_at_browser_close = ( + self.expire_at_browser_close + if expiry is None + else str(expiry).startswith("1970-01-01 00:00:00") + ) if expire_at_browser_close: max_age = 
None expires = None diff --git a/utilmeta/core/auth/signature.py b/utilmeta/core/auth/signature.py index c1c278f..cf552a2 100644 --- a/utilmeta/core/auth/signature.py +++ b/utilmeta/core/auth/signature.py @@ -5,6 +5,7 @@ from utype.parser.field import ParserField from utilmeta.utils import exceptions as exc from utilmeta.utils import get_interval, awaitable + # from utilmeta.adapt.orm.base import ModelAdaptor from utype.types import Datetime from utype.utils.datastructures import unprovided @@ -12,15 +13,19 @@ class SignatureAccess(BaseAuthentication): - name = 'signature' + name = "signature" user_context_var = var.user scopes_context_var = var.scopes def init(self, field: ParserField): from utilmeta.core.orm.backends.base import ModelAdaptor - self.access_models.extend([ - ModelAdaptor.dispatch(m) for m in field.input_origins - if m and not isinstance(None, m)] + + self.access_models.extend( + [ + ModelAdaptor.dispatch(m) + for m in field.input_origins + if m and not isinstance(None, m) + ] ) if not self.access_models: pass @@ -34,10 +39,10 @@ def _get_pre_data(self, request: Request): return None sig = self.get_request_signature(request) if not sig: - raise exc.BadRequest(f'{self.__class__}: {self.signature_header} required') + raise exc.BadRequest(f"{self.__class__}: {self.signature_header} required") ts = self.get_request_timestamp(request) if not ts: - raise exc.BadRequest(f'{self.__class__}: {self.timestamp_header} required') + raise exc.BadRequest(f"{self.__class__}: {self.timestamp_header} required") return ak, ts, sig def _validate_post_data(self, request: Request, access, ts, sig): @@ -45,9 +50,9 @@ def _validate_post_data(self, request: Request, access, ts, sig): if timeout is not None: dt = Datetime(ts) if abs((request.time - dt).total_seconds()) > timeout: - raise exc.PermissionDenied('Timestamp expired') + raise exc.PermissionDenied("Timestamp expired") if not access: - raise exc.PermissionDenied('invalid Access Key') + raise exc.PermissionDenied("invalid Access Key") sk = getattr(access, self.secret_key_field) if self.user_field: user = getattr(access, self.user_field) @@ -57,7 +62,7 @@ def _validate_post_data(self, request: Request, access, ts, sig): self.scopes_context_var.set(getattr(access, self.scopes_field, [])) gen_sig = self.get_signature(request, timestamp=ts, secret_key=sk) if sig != gen_sig: - raise exc.PermissionDenied('Invalid Signature') + raise exc.PermissionDenied("Invalid Signature") def getter(self, request: Request, field: ParserField = None): r = self._get_pre_data(request) @@ -81,18 +86,19 @@ async def getter(self, request: Request, field: ParserField = None): def __init__( self, *access_models: type, - access_key_field: Union[str, Any] = 'access_key', - secret_key_field: Union[str, Any] = 'secret_key', + access_key_field: Union[str, Any] = "access_key", + secret_key_field: Union[str, Any] = "secret_key", user_field: str = None, scopes_field: str = None, - access_key_header: str = 'X-Access-Key', - signature_header: str = 'X-Signature', - timestamp_header: str = 'X-Timestamp', + access_key_header: str = "X-Access-Key", + signature_header: str = "X-Signature", + timestamp_header: str = "X-Timestamp", timestamp_timeout: Union[timedelta, int, Callable] = 30, required: bool = False, - digest_mode: str = 'SHA256' + digest_mode: str = "SHA256", ): from utilmeta.core.orm.backends.base import ModelAdaptor + self.access_models = [ModelAdaptor.dispatch(m) for m in access_models] self.access_key_field = access_key_field @@ -102,8 +108,11 @@ def __init__( 
self.access_key_header = access_key_header self.signature_header = signature_header self.timestamp_header = timestamp_header - self.timestamp_timeout = get_interval(timestamp_timeout, null=True) if not \ - callable(timestamp_timeout) else timestamp_timeout + self.timestamp_timeout = ( + get_interval(timestamp_timeout, null=True) + if not callable(timestamp_timeout) + else timestamp_timeout + ) self.digest_mode = digest_mode super().__init__(required=required, default=unprovided if required else None) @@ -125,9 +134,11 @@ def get_signature(self, request: Request, timestamp: str, secret_key: str) -> st """ Can be override """ - tag = f'{timestamp}{request.adaptor.request_method}{request.url}'.encode() - tag += request.body or b'' - return hmac.new(key=secret_key.encode(), msg=tag, digestmod=self.digest_mode).hexdigest() + tag = f"{timestamp}{request.adaptor.request_method}{request.url}".encode() + tag += request.body or b"" + return hmac.new( + key=secret_key.encode(), msg=tag, digestmod=self.digest_mode + ).hexdigest() def get_access_instance(self, access_key: str): for model in self.access_models: @@ -153,10 +164,10 @@ def openapi_scheme(self) -> dict: # bearerFormat='JWT', # ) return { - 'type': 'apiKey', - 'name': self.access_key_header, - 'in': 'header', - 'description': self.description or '', + "type": "apiKey", + "name": self.access_key_header, + "in": "header", + "description": self.description or "", } @property @@ -164,5 +175,5 @@ def headers(self): return [ self.access_key_header.lower(), self.timestamp_header.lower(), - self.signature_header.lower() + self.signature_header.lower(), ] diff --git a/utilmeta/core/cache/backends/django.py b/utilmeta/core/cache/backends/django.py index 54a1b6c..56df720 100644 --- a/utilmeta/core/cache/backends/django.py +++ b/utilmeta/core/cache/backends/django.py @@ -6,21 +6,22 @@ class DjangoCacheAdaptor(BaseCacheAdaptor): - LOCMEM: ClassVar = 'django.core.cache.backends.locmem.LocMemCache' - MEMCACHED: ClassVar = 'django.core.cache.backends.memcached.MemcachedCache' - PYLIBMC: ClassVar = 'django.core.cache.backends.memcached.PyLibMCCache' - REDIS: ClassVar = 'django.core.cache.backends.redis.RedisCache' + LOCMEM: ClassVar = "django.core.cache.backends.locmem.LocMemCache" + MEMCACHED: ClassVar = "django.core.cache.backends.memcached.MemcachedCache" + PYLIBMC: ClassVar = "django.core.cache.backends.memcached.PyLibMCCache" + REDIS: ClassVar = "django.core.cache.backends.redis.RedisCache" DEFAULT_ENGINES = { - 'locmem': LOCMEM, - 'memcached': MEMCACHED, - 'pylibmc': PYLIBMC, - 'redis': REDIS + "locmem": LOCMEM, + "memcached": MEMCACHED, + "pylibmc": PYLIBMC, + "redis": REDIS, } @property def cache(self): from django.core.cache import caches, BaseCache + cache: BaseCache = caches[self.alias] return cache @@ -28,12 +29,16 @@ def check(self): try: import django except (ModuleNotFoundError, ImportError) as e: - raise e.__class__(f'{self.__class__} as database adaptor requires to install django') from e + raise e.__class__( + f"{self.__class__} as cache adaptor requires django to be installed" + ) from e def get(self, key: str, default=None): return self.cache.get(key, default) - def fetch(self, args=None, *keys: str, named: bool = False) -> Union[list, Dict[str, Any]]: + def fetch( + self, args=None, *keys: str, named: bool = False + ) -> Union[list, Dict[str, Any]]: # get many keys = keys_or_args(args, *keys) data = self.cache.get_many(keys) @@ -42,8 +47,15 @@ def fetch(self, args=None, *keys: str, named: bool = False) -> Union[list, Dict[ else: return 
[data.get(key) for key in keys] - def set(self, key: str, value, *, timeout: Union[int, timedelta, datetime] = None, - exists_only: bool = False, not_exists_only: bool = False): + def set( + self, + key: str, + value, + *, + timeout: Union[int, timedelta, datetime] = None, + exists_only: bool = False, + not_exists_only: bool = False, + ): if exists_only: if not self.exists(key): return @@ -75,7 +87,9 @@ def expire(self, *keys: str, timeout: float): for key in keys: return self.cache.touch(key, timeout=timeout) - def alter(self, key: str, amount: Union[int, float], limit: int = None) -> Optional[Union[int, float]]: + def alter( + self, key: str, amount: Union[int, float], limit: int = None + ) -> Optional[Union[int, float]]: if not amount: return self.get(key) try: diff --git a/utilmeta/core/cache/backends/redis/aioredis.py b/utilmeta/core/cache/backends/redis/aioredis.py index ceb66c6..abdfd04 100644 --- a/utilmeta/core/cache/backends/redis/aioredis.py +++ b/utilmeta/core/cache/backends/redis/aioredis.py @@ -8,7 +8,7 @@ class AioredisAdaptor(BaseCacheAdaptor): asynchronous = True - def __init__(self, config: 'Cache', alias: str = None): + def __init__(self, config: "Cache", alias: str = None): super().__init__(config, alias=alias) self._cache = None @@ -16,7 +16,10 @@ def get_cache(self): if self._cache: return self._cache import aioredis - rd = aioredis.from_url(self.config.get_location(), encoding="utf-8", decode_responses=True) + + rd = aioredis.from_url( + self.config.get_location(), encoding="utf-8", decode_responses=True + ) self._cache = rd return rd @@ -24,14 +27,18 @@ def check(self): try: import aioredis except (ModuleNotFoundError, ImportError) as e: - raise e.__class__(f'{self.__class__} as database adaptor requires to install caches. ' - f'use pip install aioredis') from e + raise e.__class__( + f"{self.__class__} as cache adaptor requires aioredis to be installed. " + f"Use pip install aioredis" + ) from e async def get(self, key: str, default=None): cache = self.get_cache() return await cache.get(key) - async def fetch(self, args=None, *keys: str, named: bool = False) -> Union[list, Dict[str, Any]]: + async def fetch( + self, args=None, *keys: str, named: bool = False + ) -> Union[list, Dict[str, Any]]: # get many keys = keys_or_args(args, *keys) cache = self.get_cache() @@ -41,10 +48,19 @@ async def fetch(self, args=None, *keys: str, named: bool = False) -> Union[list, else: return result - async def set(self, key: str, value, *, timeout: Union[int, timedelta, datetime] = None, - exists_only: bool = False, not_exists_only: bool = False): + async def set( + self, + key: str, + value, + *, + timeout: Union[int, timedelta, datetime] = None, + exists_only: bool = False, + not_exists_only: bool = False, + ): cache = self.get_cache() - return await cache.set(key, value, ex=timeout, nx=not_exists_only, xx=exists_only) + return await cache.set( + key, value, ex=timeout, nx=not_exists_only, xx=exists_only + ) async def update(self, data: Dict[str, Any]): # set many @@ -72,7 +88,9 @@ async def expire(self, *keys: str, timeout: float): for key in keys: await cache.expire(key, timeout) - async def alter(self, key: str, amount: Union[int, float], limit: int = None) -> Optional[Union[int, float]]: + async def alter( + self, key: str, amount: Union[int, float], limit: int = None + ) -> Optional[Union[int, float]]: if not amount: return await self.get(key) cache = self.get_cache() diff --git a/utilmeta/core/cache/backends/redis/config.py b/utilmeta/core/cache/backends/redis/config.py index 967c6b4..287d8c3 100644 --- a/utilmeta/core/cache/backends/redis/config.py +++ b/utilmeta/core/cache/backends/redis/config.py @@ -11,50 +11,54 @@ class RedisCache(Cache): username: Optional[str] = None password: Optional[str] = None db: int = 0 - scheme: str = 'redis' + scheme: str = "redis" def __init__( - self, *, + self, + *, username: Optional[str] = None, password: Optional[str] = None, - scheme: str = 'redis', + scheme: str = "redis", db: int = 0, - **kwargs + **kwargs, ): super().__init__( - engine='redis', + engine="redis", username=username, password=password, scheme=scheme, db=db, - **kwargs + **kwargs, ) @property def type(self) -> str: - return 'redis' + return "redis" def get_location(self): if self.location: return self.location if not self.password: - return f'{self.scheme}://{self.host}:{self.port}/{self.db}' + return f"{self.scheme}://{self.host}:{self.port}/{self.db}" else: return f'{self.scheme}://{self.username or ""}:{self.password}@{self.host}:{self.port}/{self.db}' @property def con(self): from redis import Redis + return Redis.from_url(self.get_location()) @property def async_con(self): from aioredis.client import Redis + cli: Redis = self.get_adaptor(True).get_cache() return cli def info(self): from redis.exceptions import ConnectionError + try: return self.con.info() except ConnectionError: diff --git a/utilmeta/core/cache/backends/redis/entity.py b/utilmeta/core/cache/backends/redis/entity.py index 0fd91cd..afa4854 100644 --- a/utilmeta/core/cache/backends/redis/entity.py +++ b/utilmeta/core/cache/backends/redis/entity.py @@ -12,7 +12,7 @@ class RedisCacheEntity(CacheEntity): - backend_name = 'redis' + backend_name = "redis" @property def con(self) -> Redis: @@ -30,8 +30,10 @@ def reset_stats(self): def keys(self): if not self.src.trace_keys: - raise NotImplementedError(f'Cache.keys not implemented, ' f'please set trace_keys=True to enable this 
method') + raise NotImplementedError( + f"Cache.keys not implemented, " + f"please set trace_keys=True to enable this method" + ) tot_keys: List[bytes] = self.con.zrange(self.update_key, 0, -1) if not tot_keys: return [] @@ -65,8 +67,10 @@ def get_latest_update(self): def last_modified(self, *keys: str) -> Optional[datetime]: if not self.src.trace_keys: - raise NotImplementedError(f'Cache.last_modified not implemented, ' - f'please set trace_keys=True to enable this method') + raise NotImplementedError( + f"Cache.last_modified not implemented, " + f"please set trace_keys=True to enable this method" + ) times = [] for key in keys: sc = self.con.zscore(self.update_key, key) @@ -78,8 +82,10 @@ def last_modified(self, *keys: str) -> Optional[datetime]: def clear(self): if not self.src.trace_keys: - raise NotImplementedError(f'Cache.clear not implemented, ' - f'please set trace_keys=True to enable this method') + raise NotImplementedError( + f"Cache.clear not implemented, " + f"please set trace_keys=True to enable this method" + ) keys = [v.decode() for v in self.con.zrange(self.update_key, 0, -1)] # do not need to validate exists now self.con.delete(*keys, self.requests_key, self.update_key, self.hits_key) @@ -92,19 +98,25 @@ def delete(self, *keys: str): if not keys: return if self.readonly: - raise RuntimeError(f'Attempt to delete key ({keys}) at a readonly cache') + raise RuntimeError(f"Attempt to delete key ({keys}) at a readonly cache") # upd_keys = self.last_update_key(keys) self.con.delete(*keys) if self.src.trace_keys: - self.con.zrem(self.hits_key, *keys) # remove deleted keys from hit statistics - self.con.zrem(self.update_key, *keys) # remove deleted keys from last update + self.con.zrem( + self.hits_key, *keys + ) # remove deleted keys from hit statistics + self.con.zrem( + self.update_key, *keys + ) # remove deleted keys from last update def count(self) -> int: if not self.src.trace_keys: - raise NotImplementedError(f'Cache.keys not implemented, ' - f'please set trace_keys=True to enable this method') + raise NotImplementedError( + f"Cache.keys not implemented, " + f"please set trace_keys=True to enable this method" + ) # return self.con.zcard(self.total_key) # consider key timeout, we cannot cat the accurate exists metrics from total key only tot_keys = self.con.zrange(self.update_key, 0, -1) @@ -113,7 +125,9 @@ def count(self) -> int: def exists(self, *keys: str) -> int: return self.con.exists(*keys) - def alter(self, key: str, amount: Union[int, float], limit: Union[int, float] = None): + def alter( + self, key: str, amount: Union[int, float], limit: Union[int, float] = None + ): if self.readonly: return None self.prepare(key) @@ -132,22 +146,25 @@ def alter(self, key: str, amount: Union[int, float], limit: Union[int, float] = if isinstance(result, NUM_TYPES): return result if isinstance(result, bytes): - result = result.decode() # noqa - return get_number(result) # noqa + result = result.decode() # noqa + return get_number(result) # noqa def lock(self, *keys: str, block: bool = False): return RedisLocker( - self.con, *keys, + self.con, + *keys, block=block, timeout=self.src.lock_timeout, - blocking_timeout=self.src.lock_blocking_timeout + blocking_timeout=self.src.lock_blocking_timeout, ) def lpush(self, key: str, *values): if not values: return if self.readonly: - raise PermissionError(f'Attempt to lpush ({key} -> {values}) to a readonly cache') + raise PermissionError( + f"Attempt to lpush ({key} -> {values}) to a readonly cache" + ) res = self.con.lpush(key, *values) if 
self.src.trace_keys: self.con.zadd(self.update_key, {key: utc_ms_ts()}) @@ -157,7 +174,9 @@ def rpush(self, key: str, *values): if not values: return if self.readonly: - raise PermissionError(f'Attempt to rpush ({key} -> {values}) to a readonly cache') + raise PermissionError( + f"Attempt to rpush ({key} -> {values}) to a readonly cache" + ) res = self.con.rpush(key, *values) if self.src.trace_keys: self.con.zadd(self.update_key, {key: utc_ms_ts()}) @@ -165,7 +184,7 @@ def rpush(self, key: str, *values): def lpop(self, key: str): if self.readonly: - raise PermissionError(f'Attempt to lpop ({key}) to a readonly cache') + raise PermissionError(f"Attempt to lpop ({key}) to a readonly cache") res = self.con.lpop(key) if self.src.trace_keys: self.con.zadd(self.update_key, {key: utc_ms_ts()}) @@ -173,7 +192,7 @@ def lpop(self, key: str): def rpop(self, key: str): if self.readonly: - raise PermissionError(f'Attempt to rpop ({key}) to a readonly cache') + raise PermissionError(f"Attempt to rpop ({key}) to a readonly cache") res = self.con.rpop(key) if self.src.trace_keys: self.con.zadd(self.update_key, {key: utc_ms_ts()}) @@ -231,7 +250,9 @@ def prepare(self, *keys: str): current_keys = self.con.zrange(self.update_key, 0, -1) excess: int = self.exists(*current_keys, *keys) - self.src.max_entries - if excess > self.src.max_entries_tolerance: # default to 0, but can set a throttle value + if ( + excess > self.src.max_entries_tolerance + ): # default to 0, but can set a throttle value # total_key >= max_entries # delete the least frequently hit key target_key = None @@ -246,7 +267,9 @@ def prepare(self, *keys: str): if target_key: try: - del_keys = [items[0] for items in self.con.zpopmin(target_key, excess)] + del_keys = [ + items[0] for items in self.con.zpopmin(target_key, excess) + ] except (ResponseError, *COMMON_ERRORS): # old version of redis or windows not support this command, downgrade for k in self.con.zrange(target_key, 0, -1): @@ -263,7 +286,9 @@ def prepare(self, *keys: str): if self.src.max_variants: excess = self.con.zcard(self.vary_hits_key) - self.src.max_variants - if excess > self.src.max_variants_tolerance: # default to 0, but can set a throttle value + if ( + excess > self.src.max_variants_tolerance + ): # default to 0, but can set a throttle value # total_key >= max_entries # delete the least frequently hit key target_key = None @@ -278,7 +303,10 @@ def prepare(self, *keys: str): if target_key: try: - del_keys = [items[0] for items in self.con.zpopmin(target_key, excess)] + del_keys = [ + items[0] + for items in self.con.zpopmin(target_key, excess) + ] except (ResponseError, *COMMON_ERRORS): # old version of redis or windows not support this command, downgrade for k in self.con.zrange(target_key, 0, -1): @@ -293,7 +321,7 @@ def prepare(self, *keys: str): def update(self, data: dict, timeout: float = None): if self.readonly: - raise PermissionError(f'Attempt to set val {data} to a readonly cache') + raise PermissionError(f"Attempt to set val {data} to a readonly cache") if timeout == 0: # will expire ASAP return @@ -304,10 +332,18 @@ def update(self, data: dict, timeout: float = None): # for incrby / decrby / incrbyfloat work fine at lua script number typed data will not be dump self.con.mset(dumped) - def set(self, key: str, val, timeout: float = None, - exists_only: bool = False, not_exists_only: bool = False): + def set( + self, + key: str, + val, + timeout: float = None, + exists_only: bool = False, + not_exists_only: bool = False, + ): if self.readonly: - raise 
PermissionError(f'Attempt to set val ({self.keys} -> {repr(val)}) to a readonly cache') + raise PermissionError( + f"Attempt to set val ({self.keys} -> {repr(val)}) to a readonly cache" + ) if timeout == 0: # will expire ASAP return diff --git a/utilmeta/core/cache/backends/redis/lock.py b/utilmeta/core/cache/backends/redis/lock.py index 56d5aeb..d5152a4 100644 --- a/utilmeta/core/cache/backends/redis/lock.py +++ b/utilmeta/core/cache/backends/redis/lock.py @@ -8,17 +8,19 @@ class RedisLocker(BaseLocker): def __init__(self, con: Redis, *keys, **kwargs): super().__init__(*keys, **kwargs) from redis.lock import Lock + self.con = con self.locks: List[Lock] = [] def __enter__(self): import time + start = time.time() for key in self.scope: lock = self.con.lock( name=self.key_func(key), blocking_timeout=self.blocking_timeout, - timeout=self.timeout + timeout=self.timeout, ) if lock.acquire(blocking=self.block, token=gen_key(32, alnum=True)): self.targets.append(key) @@ -26,11 +28,12 @@ def __enter__(self): end = time.time() if self.timeout: if (end - start) > self.timeout: - raise TimeoutError(f'Locker acquire keys: {self.scope} timeout') + raise TimeoutError(f"Locker acquire keys: {self.scope} timeout") return self def __exit__(self, exc_type, exc_val, exc_tb): from redis.exceptions import LockError + for lock in self.locks: try: lock.release() diff --git a/utilmeta/core/cache/backends/redis/scripts/__init__.py b/utilmeta/core/cache/backends/redis/scripts/__init__.py index 7e9a852..a2a2a5f 100644 --- a/utilmeta/core/cache/backends/redis/scripts/__init__.py +++ b/utilmeta/core/cache/backends/redis/scripts/__init__.py @@ -2,8 +2,8 @@ script_path = os.path.dirname(__file__) -BATCH_RETRIEVE_LUA = open(os.path.join(script_path, 'batch_retrieve.lua')).read() -BATCH_EXISTS_LUA = open(os.path.join(script_path, 'batch_exists.lua')).read() -BATCH_RELATES_LUA = open(os.path.join(script_path, 'batch_relates.lua')).read() -BATCH_COUNT_LUA = open(os.path.join(script_path, 'batch_count.lua')).read() -ALTER_AMOUNT_LUA = open(os.path.join(script_path, 'alter_amount.lua')).read() +BATCH_RETRIEVE_LUA = open(os.path.join(script_path, "batch_retrieve.lua")).read() +BATCH_EXISTS_LUA = open(os.path.join(script_path, "batch_exists.lua")).read() +BATCH_RELATES_LUA = open(os.path.join(script_path, "batch_relates.lua")).read() +BATCH_COUNT_LUA = open(os.path.join(script_path, "batch_count.lua")).read() +ALTER_AMOUNT_LUA = open(os.path.join(script_path, "alter_amount.lua")).read() diff --git a/utilmeta/core/cache/base.py b/utilmeta/core/cache/base.py index c40a294..f9a5bd3 100644 --- a/utilmeta/core/cache/base.py +++ b/utilmeta/core/cache/base.py @@ -15,7 +15,7 @@ def get_cache(self): return def get_engine(self): - if '.' in self.config.engine: + if "." 
in self.config.engine: return self.config.engine if self.config.engine.lower() in self.DEFAULT_ENGINES: return self.DEFAULT_ENGINES[self.config.engine.lower()] @@ -30,12 +30,21 @@ def exec(self, command: str): def get(self, key: str, default=None): raise NotImplementedError - def fetch(self, args=None, *keys: str, named: bool = False) -> Union[list, Dict[str, Any]]: + def fetch( + self, args=None, *keys: str, named: bool = False + ) -> Union[list, Dict[str, Any]]: # get many raise NotImplementedError - def set(self, key: str, value, *, timeout: Union[int, timedelta, datetime] = None, - exists_only: bool = False, not_exists_only: bool = False): + def set( + self, + key: str, + value, + *, + timeout: Union[int, timedelta, datetime] = None, + exists_only: bool = False, + not_exists_only: bool = False + ): raise NotImplementedError def update(self, data: Dict[str, Any]): @@ -54,5 +63,7 @@ def exists(self, args=None, *keys) -> int: def expire(self, *keys: str, timeout: float): raise NotImplementedError - def alter(self, key: str, amount: Union[int, float], limit: int = None) -> Optional[Union[int, float]]: + def alter( + self, key: str, amount: Union[int, float], limit: int = None + ) -> Optional[Union[int, float]]: raise NotImplementedError diff --git a/utilmeta/core/cache/config.py b/utilmeta/core/cache/config.py index 97569aa..3345998 100644 --- a/utilmeta/core/cache/config.py +++ b/utilmeta/core/cache/config.py @@ -14,12 +14,9 @@ class Cache(Config): This is just a declaration interface for database the real implementation is database adaptor """ - DEFAULT_HOST: ClassVar = '127.0.0.1' - DEFAULT_PORTS: ClassVar = { - 'redis': 6379, - 'mcache': 11211, - 'memcache': 11211 - } + + DEFAULT_HOST: ClassVar = "127.0.0.1" + DEFAULT_PORTS: ClassVar = {"redis": 6379, "mcache": 11211, "memcache": 11211} sync_adaptor_cls = None async_adaptor_cls = None @@ -29,24 +26,26 @@ class Cache(Config): host: Optional[str] = None port: int = 0 timeout: int = 300 - location: Union[str, List[str]] = '' + location: Union[str, List[str]] = "" prefix: Optional[str] = None max_entries: Optional[int] = None key_function: Optional[Callable] = None options: Optional[dict] = None - def __init__(self, *, - engine: str, # 'redis' / 'memcached' / 'locmem' - host: Optional[str] = None, - port: int = 0, - timeout: int = 300, - location: Union[str, List[str]] = '', - prefix: Optional[str] = None, - max_entries: Optional[int] = None, - key_function: Optional[Callable] = None, - options: Optional[dict] = None, - **kwargs - ): + def __init__( + self, + *, + engine: str, # 'redis' / 'memcached' / 'locmem' + host: Optional[str] = None, + port: int = 0, + timeout: int = 300, + location: Union[str, List[str]] = "", + prefix: Optional[str] = None, + max_entries: Optional[int] = None, + key_function: Optional[Callable] = None, + options: Optional[dict] = None, + **kwargs, + ): kwargs.update(locals()) super().__init__(kwargs) self.host = self.host or self.DEFAULT_HOST @@ -61,21 +60,21 @@ def __init__(self, *, @property def type(self) -> str: - if 'redis' in self.engine.lower(): - return 'redis' - elif 'memcached' in self.engine.lower(): - return 'memcached' - elif 'locmem' in self.engine.lower(): - return 'locmem' - elif 'file' in self.engine.lower(): - return 'file' - elif 'database' in self.engine.lower() or 'db' in self.engine.lower(): - return 'db' - return 'memory' + if "redis" in self.engine.lower(): + return "redis" + elif "memcached" in self.engine.lower(): + return "memcached" + elif "locmem" in self.engine.lower(): + return 
"locmem" + elif "file" in self.engine.lower(): + return "file" + elif "database" in self.engine.lower() or "db" in self.engine.lower(): + return "db" + return "memory" @property def is_memory(self) -> bool: - return self.type in ['locmem', 'memory'] + return self.type in ["locmem", "memory"] @property def local(self): @@ -93,34 +92,40 @@ def apply(self, alias: str, asynchronous: bool = None): self.adaptor = self.async_adaptor_cls(self, alias) else: from .backends.django import DjangoCacheAdaptor + self.adaptor = DjangoCacheAdaptor(self, alias) else: if self.sync_adaptor_cls: self.adaptor = self.sync_adaptor_cls(self, alias) else: from .backends.django import DjangoCacheAdaptor + self.adaptor = DjangoCacheAdaptor(self, alias) self.asynchronous = asynchronous self.adaptor.check() self._applied = True - def get_adaptor(self, asynchronous: bool = False) -> 'BaseCacheAdaptor': + def get_adaptor(self, asynchronous: bool = False) -> "BaseCacheAdaptor": if not self._applied: - self.apply('default', asynchronous) + self.apply("default", asynchronous) if self.adaptor and self.adaptor.asynchronous == asynchronous: return self.adaptor if asynchronous: if not self.async_adaptor_cls: - raise exceptions.SettingNotConfigured(self.__class__, item='async_adaptor_cls') + raise exceptions.SettingNotConfigured( + self.__class__, item="async_adaptor_cls" + ) return self.async_adaptor_cls(self, self.adaptor.alias) if not self.sync_adaptor_cls: - raise exceptions.SettingNotConfigured(self.__class__, item='sync_adaptor_cls') + raise exceptions.SettingNotConfigured( + self.__class__, item="sync_adaptor_cls" + ) return self.sync_adaptor_cls(self, self.adaptor.alias) def get_location(self): if self.location: return self.location - return f'{self.host}:{self.port}' + return f"{self.host}:{self.port}" def get(self, key: str, default=None): return self.get_adaptor(False).get(key, default) @@ -128,30 +133,50 @@ def get(self, key: str, default=None): async def aget(self, key: str, default=None): return await self.get_adaptor(True).get(key, default) - def fetch(self, args=None, *keys: str, named: bool = False) -> Union[list, Dict[str, Any]]: + def fetch( + self, args=None, *keys: str, named: bool = False + ) -> Union[list, Dict[str, Any]]: # get many return self.get_adaptor(False).fetch(args, *keys, named=named) - async def afetch(self, args=None, *keys: str, named: bool = False) -> Union[list, Dict[str, Any]]: + async def afetch( + self, args=None, *keys: str, named: bool = False + ) -> Union[list, Dict[str, Any]]: # get many return await self.get_adaptor(True).fetch(args, *keys, named=named) - def set(self, key: str, value, *, timeout: Union[int, timedelta, datetime] = None, - exists_only: bool = False, not_exists_only: bool = False): + def set( + self, + key: str, + value, + *, + timeout: Union[int, timedelta, datetime] = None, + exists_only: bool = False, + not_exists_only: bool = False, + ): return self.get_adaptor(False).set( - key, value, + key, + value, timeout=timeout, exists_only=exists_only, - not_exists_only=not_exists_only + not_exists_only=not_exists_only, ) - async def aset(self, key: str, value, *, timeout: Union[int, timedelta, datetime] = None, - exists_only: bool = False, not_exists_only: bool = False): + async def aset( + self, + key: str, + value, + *, + timeout: Union[int, timedelta, datetime] = None, + exists_only: bool = False, + not_exists_only: bool = False, + ): return await self.get_adaptor(True).set( - key, value, + key, + value, timeout=timeout, exists_only=exists_only, - 
not_exists_only=not_exists_only + not_exists_only=not_exists_only, ) def update(self, data: Dict[str, Any]): @@ -187,65 +212,81 @@ def expire(self, *keys: str, timeout: float): async def aexpire(self, *keys: str, timeout: float): return await self.get_adaptor(True).expire(*keys, timeout=timeout) - def alter(self, key: str, amount: Union[int, float], limit: int = None) -> Optional[Union[int, float]]: + def alter( + self, key: str, amount: Union[int, float], limit: int = None + ) -> Optional[Union[int, float]]: return self.get_adaptor(False).alter(key, amount, limit=limit) - async def aalter(self, key: str, amount: Union[int, float], limit: int = None) -> Optional[Union[int, float]]: + async def aalter( + self, key: str, amount: Union[int, float], limit: int = None + ) -> Optional[Union[int, float]]: return await self.get_adaptor(True).alter(key, amount, limit=limit) # deprecate in the future @awaitable(get) async def get(self, key: str, default=None): - warnings.warn(f'Deprecated in future, please use aget()', DeprecationWarning) + warnings.warn(f"Deprecated in future, please use aget()", DeprecationWarning) return await self.aget(key, default) @awaitable(fetch) - async def fetch(self, args=None, *keys: str, named: bool = False) -> Union[list, Dict[str, Any]]: + async def fetch( + self, args=None, *keys: str, named: bool = False + ) -> Union[list, Dict[str, Any]]: # get many - warnings.warn(f'Deprecated in future, please use afetch()', DeprecationWarning) + warnings.warn(f"Deprecated in future, please use afetch()", DeprecationWarning) return await self.afetch(args, *keys, named=named) @awaitable(set) - async def set(self, key: str, value, *, timeout: Union[int, timedelta, datetime] = None, - exists_only: bool = False, not_exists_only: bool = False): - warnings.warn(f'Deprecated in future, please use aset()', DeprecationWarning) + async def set( + self, + key: str, + value, + *, + timeout: Union[int, timedelta, datetime] = None, + exists_only: bool = False, + not_exists_only: bool = False, + ): + warnings.warn(f"Deprecated in future, please use aset()", DeprecationWarning) return await self.aset( - key, value, + key, + value, timeout=timeout, exists_only=exists_only, - not_exists_only=not_exists_only + not_exists_only=not_exists_only, ) @awaitable(update) async def update(self, data: Dict[str, Any]): - warnings.warn(f'Deprecated in future, please use aupdate()', DeprecationWarning) + warnings.warn(f"Deprecated in future, please use aupdate()", DeprecationWarning) # set many return await self.aupdate(data) @awaitable(pop) async def pop(self, key: str): - warnings.warn(f'Deprecated in future, please use apop()', DeprecationWarning) + warnings.warn(f"Deprecated in future, please use apop()", DeprecationWarning) # set many return await self.apop(key) @awaitable(delete) async def delete(self, args=None, *keys): - warnings.warn(f'Deprecated in future, please use adelete()', DeprecationWarning) + warnings.warn(f"Deprecated in future, please use adelete()", DeprecationWarning) return await self.get_adaptor(True).delete(args, *keys) @awaitable(exists) async def exists(self, args=None, *keys) -> int: - warnings.warn(f'Deprecated in future, please use aexists()', DeprecationWarning) + warnings.warn(f"Deprecated in future, please use aexists()", DeprecationWarning) return await self.aexists(args, *keys) @awaitable(expire) async def expire(self, *keys: str, timeout: float): - warnings.warn(f'Deprecated in future, please use aexpire()', DeprecationWarning) + warnings.warn(f"Deprecated in future, please 
use aexpire()", DeprecationWarning) return await self.aexpire(*keys, timeout=timeout) @awaitable(alter) - async def alter(self, key: str, amount: Union[int, float], limit: int = None) -> Optional[Union[int, float]]: - warnings.warn(f'Deprecated in future, please use aalter()', DeprecationWarning) + async def alter( + self, key: str, amount: Union[int, float], limit: int = None + ) -> Optional[Union[int, float]]: + warnings.warn(f"Deprecated in future, please use aalter()", DeprecationWarning) return await self.aalter(key, amount, limit=limit) @@ -270,7 +311,7 @@ def add_cache(self, service: UtilMeta, alias: str, cache: Cache): self.caches.setdefault(alias, cache) @classmethod - def get(cls, alias: str = 'default', default=unprovided) -> Cache: + def get(cls, alias: str = "default", default=unprovided) -> Cache: config = cls.config() if not config: if unprovided(default): @@ -298,5 +339,3 @@ def items(self): # async def on_shutdown(self, service): # for key, value in self.caches.items(): # await value.disconnect() - - diff --git a/utilmeta/core/cache/lock.py b/utilmeta/core/cache/lock.py index 5507f84..c0d66ea 100644 --- a/utilmeta/core/cache/lock.py +++ b/utilmeta/core/cache/lock.py @@ -2,8 +2,15 @@ class BaseLocker: - def __init__(self, *keys, key_func: Callable[[str], str] = lambda x: x + '!', block: bool = False, - timeout: int = None, blocking_timeout: int = None, sleep: int = 0.1): + def __init__( + self, + *keys, + key_func: Callable[[str], str] = lambda x: x + "!", + block: bool = False, + timeout: int = None, + blocking_timeout: int = None, + sleep: int = 0.1 + ): self.key_func = key_func self.timeout = timeout self.blocking_timeout = blocking_timeout @@ -12,7 +19,7 @@ def __init__(self, *keys, key_func: Callable[[str], str] = lambda x: x + '!', bl self.targets = [] self.block = block - def __enter__(self) -> 'BaseLocker': + def __enter__(self) -> "BaseLocker": raise NotImplementedError def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/utilmeta/core/cache/plugins/api.py b/utilmeta/core/cache/plugins/api.py index f1711b1..4fe0025 100644 --- a/utilmeta/core/cache/plugins/api.py +++ b/utilmeta/core/cache/plugins/api.py @@ -1,4 +1,12 @@ -from utilmeta.utils import pop, get_interval, Header, time_now, http_time, COMMON_ERRORS, fast_digest +from utilmeta.utils import ( + pop, + get_interval, + Header, + time_now, + http_time, + COMMON_ERRORS, + fast_digest, +) from utype import type_transform from utype.parser.func import FunctionParser from utilmeta.utils import exceptions as exc @@ -15,7 +23,7 @@ NUM = Union[int, float] VAL = Union[str, bytes, list, tuple, dict] -__all__ = ['ServerCache'] +__all__ = ["ServerCache"] NOT_MODIFIED_KEEP_HEADERS = ( @@ -30,23 +38,23 @@ class ServerCache(BaseCacheInterface): - NO_CACHE = 'no-cache' - NO_STORE = 'no-store' - NO_TRANSFORM = 'no-transform' - PUBLIC = 'public' - MAX_AGE = 'max-age' - MAX_STALE = 'max-stale' - MAX_FRESH = 'max-fresh' - PRIVATE = 'private' - MUST_REVALIDATE = 'must-revalidate' + NO_CACHE = "no-cache" + NO_STORE = "no-store" + NO_TRANSFORM = "no-transform" + PUBLIC = "public" + MAX_AGE = "max-age" + MAX_STALE = "max-stale" + MAX_FRESH = "max-fresh" + PRIVATE = "private" + MUST_REVALIDATE = "must-revalidate" # volatile strategy - OBSOLETE_LRU = 'LRU' # least recently updated - OBSOLETE_LFU = 'LFU' # least frequently used - OBSOLETE_RANDOM = 'RANDOM' + OBSOLETE_LRU = "LRU" # least recently updated + OBSOLETE_LFU = "LFU" # least frequently used + OBSOLETE_RANDOM = "RANDOM" # omit - OMIT_ARG_NAMES = ('self', 'cls', 'mcs') + 
OMIT_ARG_NAMES = ("self", "cls", "mcs") FOREVER_TIMEDELTA = timedelta(days=365 * 50) # standard vary @@ -76,7 +84,7 @@ def expires_next_minute(cls): month=current_time.month, day=current_time.day, hour=current_time.hour, - minute=current_time.minute + minute=current_time.minute, ) + timedelta(minutes=1) @classmethod @@ -86,7 +94,7 @@ def expires_next_hour(cls): year=current_time.year, month=current_time.month, day=current_time.day, - hour=current_time.hour + hour=current_time.hour, ) + timedelta(hours=1) @classmethod @@ -96,7 +104,7 @@ def expires_next_day(cls): year=current_time.year, month=current_time.month, day=current_time.day, - hour=0 + hour=0, ) + timedelta(days=1) @classmethod @@ -107,7 +115,7 @@ def expires_next_week(cls): year=current_time.year, month=current_time.month, day=current_time.day, - hour=0 + hour=0, ) + timedelta(days=delta_days) @classmethod @@ -119,32 +127,27 @@ def expires_next_month(cls): else: year = current_time.year month = current_time.month + 1 - return datetime( - year=year, - month=month, - day=0 - ) + return datetime(year=year, month=month, day=0) @classmethod def expires_next_year(cls): current_time = time_now() - return datetime( - year=current_time.year + 1, - month=1, - day=0 - ) + return datetime(year=current_time.year + 1, month=1, day=0) @classmethod def expires_next_utc_day(cls): current_time = time_now().astimezone(timezone.utc) - return (datetime( - year=current_time.year, - month=current_time.month, - day=current_time.day, - hour=0 - ) + timedelta(days=1)).replace(tzinfo=timezone.utc) - - entity_cls: Type[CacheEntity] # can be override + return ( + datetime( + year=current_time.year, + month=current_time.month, + day=current_time.day, + hour=0, + ) + + timedelta(days=1) + ).replace(tzinfo=timezone.utc) + + entity_cls: Type[CacheEntity] # can be override cache_alias: str # expiry_time: Union[int, datetime, timedelta, Callable] scope_prefix: str @@ -152,31 +155,33 @@ def expires_next_utc_day(cls): vary_header: Union[str, List[str]] # vary_function: Callable[['Request'], str] - def __init__(self, cache_alias: str = 'default', - scope_prefix: str = None, - # user can manually assign, and allow two cache instance have the same scope key - cache_control: str = None, - cache_response: bool = False, - etag_response: bool = False, - vary_header: Union[str, List[str]] = None, - vary_function=None, - expiry_time: Union[int, datetime, timedelta, Callable, None] = 0, - # normalizer, take the request and return the normalized result - max_entries: int = None, # None means unlimited - max_entries_policy: str = OBSOLETE_LFU, - max_entries_tolerance: int = 0, - # if vary is specified, max_entries is relative to each variant - max_variants: int = None, - max_variants_policy: str = OBSOLETE_LFU, - max_variants_tolerance: int = 0, - trace_keys: bool = None, - default_timeout: Union[int, float, timedelta] = None, - entity_cls: Type[CacheEntity] = None, - document: str = None, - ): + def __init__( + self, + cache_alias: str = "default", + scope_prefix: str = None, + # user can manually assign, and allow two cache instance have the same scope key + cache_control: str = None, + cache_response: bool = False, + etag_response: bool = False, + vary_header: Union[str, List[str]] = None, + vary_function=None, + expiry_time: Union[int, datetime, timedelta, Callable, None] = 0, + # normalizer, take the request and return the normalized result + max_entries: int = None, # None means unlimited + max_entries_policy: str = OBSOLETE_LFU, + max_entries_tolerance: int = 0, + # if 
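# Note on the expiry helpers above: datetime() requires day >= 1, so
# datetime(year=year, month=month, day=0) in expires_next_month() and
# expires_next_year() raises ValueError at runtime; black preserves that
# pre-existing behavior unchanged. A month/year boundary computation that
# works uses day=1:

from datetime import datetime

def next_month_start(now: datetime) -> datetime:
    # first instant of the following month
    if now.month == 12:
        return datetime(year=now.year + 1, month=1, day=1)
    return datetime(year=now.year, month=now.month + 1, day=1)

def next_year_start(now: datetime) -> datetime:
    return datetime(year=now.year + 1, month=1, day=1)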
vary is specified, max_entries is relative to each variant + max_variants: int = None, + max_variants_policy: str = OBSOLETE_LFU, + max_variants_tolerance: int = 0, + trace_keys: bool = None, + default_timeout: Union[int, float, timedelta] = None, + entity_cls: Type[CacheEntity] = None, + document: str = None, + ): _locals = dict(locals()) - pop(_locals, 'self') + pop(_locals, "self") super().__init__(**_locals) # from utilmeta.conf import config @@ -184,8 +189,11 @@ def __init__(self, cache_alias: str = 'default', if expiry_time: if isinstance(expiry_time, (classmethod, staticmethod)): expiry_time = expiry_time.__func__ - assert isinstance(expiry_time, (int, float, datetime, timedelta)) or callable(expiry_time), \ - f'Invalid Cache expiry_time: {expiry_time}, must be instance of int/datetime/timedelta or a callable' + assert isinstance( + expiry_time, (int, float, datetime, timedelta) + ) or callable( + expiry_time + ), f"Invalid Cache expiry_time: {expiry_time}, must be instance of int/datetime/timedelta or a callable" if callable(expiry_time): expiry_time = self.function_parser_cls.apply_for(expiry_time) @@ -199,20 +207,27 @@ def __init__(self, cache_alias: str = 'default', # default function: vary to cookie vary_header = self.VARY_COOKIE else: - warnings.warn(f'Cache with vary_function ({vary_function}) should specify a vary_header' - f' to generate response["Vary"] header, like use vary_header="Cookie" ' - f'when you vary to user_id / session_id, because it is derived from cookie') - - assert callable(vary_function), f'Cache.vary_function must be a callable, got {vary_function}' + warnings.warn( + f"Cache with vary_function ({vary_function}) should specify a vary_header" + f' to generate response["Vary"] header, like use vary_header="Cookie" ' + f"when you vary to user_id / session_id, because it is derived from cookie" + ) + + assert callable( + vary_function + ), f"Cache.vary_function must be a callable, got {vary_function}" # if config.preference.validate_request_functions: from utilmeta.core.request import Request + _res = vary_function(Request()) self.vary_header = vary_header self.vary_function = vary_function if not self.varied and self.max_variants: - raise ValueError(f'Cache with max_variants: {self.max_variants} got no vary_header or vary_function') + raise ValueError( + f"Cache with max_variants: {self.max_variants} got no vary_header or vary_function" + ) # private properties self._expiry_time = expiry_time @@ -222,7 +237,7 @@ def __init__(self, cache_alias: str = 'default', self._max_stale = None self._stale = False self._max_fresh = None - self._expiry_datetime = None # runtime + self._expiry_datetime = None # runtime self._if_modified_since: Optional[datetime] = None self._if_none_match: Optional[str] = None self._if_unmodified_since: Optional[datetime] = None @@ -242,7 +257,7 @@ def varied(self) -> bool: def variant(self) -> str: return self._variant - def init_expiry(self, request: 'Request') -> Optional[datetime]: + def init_expiry(self, request: "Request") -> Optional[datetime]: expiry = self._expiry_time if callable(expiry): expiry = expiry(request) @@ -273,7 +288,7 @@ def set_expiry(self, expires: Union[datetime, timedelta, int, float]): self._expiry_datetime = time_now() + timedelta(seconds=inv) self._max_age = int(inv) - def get_variant(self, request: 'Request') -> Optional[str]: + def get_variant(self, request: "Request") -> Optional[str]: if self.vary_function: # try not to contains sensitive data in vary_function results key = self.vary_function(request) @@ 
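# Sketch: ServerCache above takes expiry_time as seconds (int/float), a
# datetime, a timedelta, or a request-taking callable (the expires_next_*
# classmethods are provided as pluggable strategies). A hedged configuration
# example; every value below is illustrative, not a recommendation:

server_cache = ServerCache(
    cache_alias="default",
    cache_response=True,
    etag_response=True,
    vary_header="Cookie",                         # emitted as response["Vary"]
    expiry_time=3600,                             # one hour, in seconds
    max_entries=1000,                             # None means unlimited
    max_entries_policy=ServerCache.OBSOLETE_LFU,  # evict least frequently used
)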
-290,6 +305,7 @@ def get_variant(self, request: 'Request') -> Optional[str]: def etag_function(cls, value): # can be inherit from utilmeta.utils import etag + return etag(value) @property @@ -314,7 +330,10 @@ def no_cache(self): but run the function and get the un-cached response we may store it to cache if max_entries is not 0 and not full (unlike no-store) """ - return self.request_cache_control in (self.NO_CACHE, self.NO_STORE) or not self.has_entries + return ( + self.request_cache_control in (self.NO_CACHE, self.NO_STORE) + or not self.has_entries + ) @property def last_modified(self): @@ -334,8 +353,12 @@ def etag(self): def etag(self, val): self._etag = self.etag_function(val) - def check_modified(self, last_modified: Union[datetime, int, float, str] = None, - resource=None, etag: str = None): + def check_modified( + self, + last_modified: Union[datetime, int, float, str] = None, + resource=None, + etag: str = None, + ): if last_modified: self.last_modified = last_modified if etag or resource: @@ -351,11 +374,17 @@ def check_modified(self, last_modified: Union[datetime, int, float, str] = None, return False - def check_precondition(self, last_modified: Union[datetime, int, float, str] = None, - resource=None, etag: str = None): + def check_precondition( + self, + last_modified: Union[datetime, int, float, str] = None, + resource=None, + etag: str = None, + ): if not self._if_unmodified_since and not self._if_match: - raise exc.PreconditionRequired(f'Request should provide precondition headers like ' - f'{Header.IF_MATCH} or {Header.IF_UNMODIFIED_SINCE}') + raise exc.PreconditionRequired( + f"Request should provide precondition headers like " + f"{Header.IF_MATCH} or {Header.IF_UNMODIFIED_SINCE}" + ) if last_modified: self.last_modified = last_modified @@ -364,10 +393,14 @@ def check_precondition(self, last_modified: Union[datetime, int, float, str] = N if self.etag and self._if_match: if self.etag != self._if_match: - raise exc.PreconditionFailed(f'Resource is modified: not match {self._if_match}') + raise exc.PreconditionFailed( + f"Resource is modified: not match {self._if_match}" + ) if self.last_modified and self._if_unmodified_since: if self.last_modified > self._if_unmodified_since: - raise exc.PreconditionFailed(f'Resource has been modified since {self._if_unmodified_since}') + raise exc.PreconditionFailed( + f"Resource has been modified since {self._if_unmodified_since}" + ) return True @@ -394,7 +427,7 @@ def response_cache_control(self): return self.NO_CACHE max_age = self.max_age if max_age: - return f'{self.MAX_AGE}={max_age}' + return f"{self.MAX_AGE}={max_age}" return None @response_cache_control.setter @@ -419,13 +452,17 @@ def if_unmodified_since(self): @property def headers(self) -> dict: - return {k: v for k, v in { - Header.ETAG: self.etag, - Header.LAST_MODIFIED: http_time(self.last_modified), - Header.CACHE_CONTROL: self.response_cache_control, - Header.EXPIRES: http_time(self.expiry_datetime), - Header.VARY: self.vary_header - }.items() if v} + return { + k: v + for k, v in { + Header.ETAG: self.etag, + Header.LAST_MODIFIED: http_time(self.last_modified), + Header.CACHE_CONTROL: self.response_cache_control, + Header.EXPIRES: http_time(self.expiry_datetime), + Header.VARY: self.vary_header, + }.items() + if v + } @property def expiry_datetime(self) -> datetime: @@ -437,12 +474,14 @@ def expiry_datetime(self, dt: datetime): def make_from_params(self, result, **func_params): if not self._context_func: - warnings.warn('Cache: no context func is set') + 
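# Sketch: check_precondition() above implements optimistic concurrency --
# missing If-Match/If-Unmodified-Since yields 428 Precondition Required, and
# a stale ETag or a later last_modified yields 412 Precondition Failed.
# Hedged usage inside an update handler; load_article() and its fields are
# hypothetical stand-ins, not part of this codebase:

def update_article(cache, article_id: int, payload: dict):
    article = load_article(article_id)     # hypothetical loader
    cache.check_precondition(
        last_modified=article.updated_at,  # hypothetical field
        etag=article.version_tag,          # hypothetical field
    )                                      # raises 428/412 on conflict
    article.apply(payload)                 # safe: precondition held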
warnings.warn("Cache: no context func is set") return - dumped_key = f'{self._context_func.__name__}(%s)' % self.dump_kwargs(**func_params) + dumped_key = f"{self._context_func.__name__}(%s)" % self.dump_kwargs( + **func_params + ) # use func.__name__ cause when cache is define in API and not in Unit # __ref__ will not contains func name - key = self.encode(dumped_key, ':') + key = self.encode(dumped_key, ":") return self.make_response(result, response_key=key) def make_response(self, result, response_key=None): @@ -452,7 +491,7 @@ def make_response(self, result, response_key=None): from django.http.response import HttpResponseBase - if hasattr(result, '__next__'): + if hasattr(result, "__next__"): return result if isinstance(result, HttpResponseBase): @@ -491,10 +530,11 @@ def make_response(self, result, response_key=None): return result @classmethod - def _normalize_func_kwargs(cls, func, /, *args, **kwargs): # noqa + def _normalize_func_kwargs(cls, func, /, *args, **kwargs): # noqa if not args: return kwargs import inspect + try: # make all args kwargs with it's arg name params = inspect.getcallargs(func, *args, **kwargs) @@ -517,7 +557,7 @@ def _normalize_func_kwargs(cls, func, /, *args, **kwargs): # noqa # pop(params, p) return params - def encode(self, key: VAL, _con: str = '-', _variant=None): + def encode(self, key: VAL, _con: str = "-", _variant=None): return super().encode(key, _con, _variant=_variant or self._variant) def decode(self, key: VAL, _variant=None): @@ -541,16 +581,18 @@ def md(self, name: str, version: str, lang: str): if you just want to cache a function with Cache use self.cache.apply(func, *args, **kwargs) or @api.cache to decorate the target function """ - assert callable(func), f'Cache.apply must apply to a callable function, got {func}' + assert callable( + func + ), f"Cache.apply must apply to a callable function, got {func}" self.check_modified() key = None self._context_func = func if self.cache_response and not self.no_cache: params = self._normalize_func_kwargs(func, *args, **kwargs) - dumped_key = f'{func.__name__}(%s)' % self.dump_kwargs(**params) + dumped_key = f"{func.__name__}(%s)" % self.dump_kwargs(**params) # use func.__name__ cause when cache is define in API and not in Unit # __ref__ will not contains func name - key = self.encode(dumped_key, ':') + key = self.encode(dumped_key, ":") if not self._response_key: self._response_key = key # use another connector @@ -565,7 +607,10 @@ def make_cache(self, request): # keep in the same scope from utilmeta.core.request import Request - assert isinstance(request, Request), f'Invalid request: {request}, must be Request object' + + assert isinstance( + request, Request + ), f"Invalid request: {request}, must be Request object" headers = request.headers modified_since = headers.get(Header.IF_MODIFIED_SINCE) unmodified_since = headers.get(Header.IF_UNMODIFIED_SINCE) @@ -573,23 +618,27 @@ def make_cache(self, request): try: cache._if_modified_since = type_transform(modified_since, datetime) except COMMON_ERRORS as e: - warnings.warn(f'Cache: transform {Header.IF_MODIFIED_SINCE} failed with: {e}') + warnings.warn( + f"Cache: transform {Header.IF_MODIFIED_SINCE} failed with: {e}" + ) if unmodified_since: try: cache._if_unmodified_since = type_transform(unmodified_since, datetime) except COMMON_ERRORS as e: - warnings.warn(f'Cache: transform {Header.IF_UNMODIFIED_SINCE} failed with: {e}') + warnings.warn( + f"Cache: transform {Header.IF_UNMODIFIED_SINCE} failed with: {e}" + ) cache._if_none_match = 
headers.get(Header.IF_NONE_MATCH) cache._if_match = headers.get(Header.IF_MATCH) cache._variant = cache.get_variant(request) cache._cache_control = headers.get(Header.CACHE_CONTROL) if cache._cache_control: - for derivative in [v.strip() for v in cache._cache_control.split(',')]: + for derivative in [v.strip() for v in cache._cache_control.split(",")]: if derivative.startswith(self.MAX_FRESH): - cache._max_fresh = int(derivative.split('=')[1]) + cache._max_fresh = int(derivative.split("=")[1]) elif derivative.startswith(self.MAX_STALE): - if '=' in derivative: - cache._max_stale = int(derivative.split('=')[1]) + if "=" in derivative: + cache._max_stale = int(derivative.split("=")[1]) else: cache._stale = True cache.init_expiry(request) diff --git a/utilmeta/core/cache/plugins/base.py b/utilmeta/core/cache/plugins/base.py index 97f03c3..cbf043a 100644 --- a/utilmeta/core/cache/plugins/base.py +++ b/utilmeta/core/cache/plugins/base.py @@ -1,5 +1,13 @@ from utilmeta.utils.plugin import PluginBase -from utilmeta.utils import multi, get_interval, map_dict, keys_or_args, time_now, fast_digest, COMMON_TYPES +from utilmeta.utils import ( + multi, + get_interval, + map_dict, + keys_or_args, + time_now, + fast_digest, + COMMON_TYPES, +) from typing import Union, Optional, Dict, Any, List, Type, TYPE_CHECKING from datetime import datetime, timedelta from ..lock import BaseLocker @@ -12,40 +20,46 @@ NUM = Union[int, float] VAL = Union[str, bytes, list, tuple, dict] -__all__ = ['BaseCacheInterface'] +__all__ = ["BaseCacheInterface"] class BaseCacheInterface(PluginBase): - OBSOLETE_LRU = 'LRU' # least recently updated - OBSOLETE_LFU = 'LFU' # least frequently used - OBSOLETE_RANDOM = 'RANDOM' - - def __init__(self, cache_alias: str = 'default', *, - scope_prefix: str = None, - max_entries: int = None, # None means unlimited - max_entries_policy: str = OBSOLETE_LFU, - max_entries_tolerance: int = 0, - # if vary is specified, max_entries is relative to each variant - max_variants: int = None, - max_variants_policy: str = OBSOLETE_LFU, - max_variants_tolerance: int = 0, - trace_keys: bool = None, - default_timeout: Union[int, float, timedelta] = None, - lock_timeout: Union[int, float, timedelta] = None, - lock_blocking_timeout: Union[int, float, timedelta] = None, - entity_cls: Type['CacheEntity'] = None, - document: str = None, - **kwargs - ): + OBSOLETE_LRU = "LRU" # least recently updated + OBSOLETE_LFU = "LFU" # least frequently used + OBSOLETE_RANDOM = "RANDOM" + + def __init__( + self, + cache_alias: str = "default", + *, + scope_prefix: str = None, + max_entries: int = None, # None means unlimited + max_entries_policy: str = OBSOLETE_LFU, + max_entries_tolerance: int = 0, + # if vary is specified, max_entries is relative to each variant + max_variants: int = None, + max_variants_policy: str = OBSOLETE_LFU, + max_variants_tolerance: int = 0, + trace_keys: bool = None, + default_timeout: Union[int, float, timedelta] = None, + lock_timeout: Union[int, float, timedelta] = None, + lock_blocking_timeout: Union[int, float, timedelta] = None, + entity_cls: Type["CacheEntity"] = None, + document: str = None, + **kwargs, + ): super().__init__(locals()) if max_entries: - assert isinstance(max_entries, int) and max_entries >= 0, \ - f'Invalid expose Cache max_entries: {max_entries}' \ - f', must be an int >= 0 (max_entries=0 means no store)' + assert isinstance(max_entries, int) and max_entries >= 0, ( + f"Invalid expose Cache max_entries: {max_entries}" + f", must be an int >= 0 (max_entries=0 means no store)" 
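# Sketch: the loop above parses request Cache-Control directives of the form
# "max-stale[=N]" and "max-fresh=N" (RFC 9111 names the request directive
# "min-fresh"; "max-fresh" here is project-specific and kept as-is). The same
# parsing in isolation:

def parse_directives(cache_control: str) -> dict:
    result = {}
    for directive in (v.strip() for v in cache_control.split(",")):
        if "=" in directive:
            name, _, value = directive.partition("=")
            result[name.strip()] = int(value)
        else:
            result[directive] = True  # bare directive such as max-stale
    return result

assert parse_directives("max-stale=30, no-transform") == {
    "max-stale": 30, "no-transform": True,
}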
+ ) if scope_prefix is not None: - assert isinstance(scope_prefix, str), f'Cache.scope_prefix must be a str, got {scope_prefix}' + assert isinstance( + scope_prefix, str + ), f"Cache.scope_prefix must be a str, got {scope_prefix}" if max_entries or max_variants: # in order to implement max_entries and max_variants, we must enable trace_keys @@ -61,16 +75,18 @@ def __init__(self, cache_alias: str = 'default', *, self.max_variants_tolerance = max_variants_tolerance from .entity import CacheEntity + if entity_cls: - assert issubclass(entity_cls, CacheEntity), \ - f'Cache.entity_cls must inherit from CacheEntity, got {entity_cls}' + assert issubclass( + entity_cls, CacheEntity + ), f"Cache.entity_cls must inherit from CacheEntity, got {entity_cls}" self.entity_cls = entity_cls or CacheEntity self.default_timeout = get_interval(default_timeout, null=True) self.lock_timeout = get_interval(lock_timeout, null=True) self.lock_blocking_timeout = get_interval(lock_blocking_timeout, null=True) self._scope_prefix = scope_prefix - self._service_prefix = '' + self._service_prefix = "" self._cache_alias = cache_alias @property @@ -92,7 +108,7 @@ def scope_prefix(self) -> str: if self._scope_prefix is not None: # can take empty scope prefix, which will gain access to the global scope return self._scope_prefix - return self.__ref__ or '' + return self.__ref__ or "" @property def variant(self): @@ -105,9 +121,9 @@ def varied(self) -> bool: @property def base_key_prefix(self): # not varied - return '-'.join([v for v in [self._service_prefix, self.scope_prefix] if v]) + return "-".join([v for v in [self._service_prefix, self.scope_prefix] if v]) - def encode(self, key: VAL, _con: str = '-', _variant=None): + def encode(self, key: VAL, _con: str = "-", _variant=None): """ Within the same scope, you can use different connector (1 length) to identify some sub-key-domains currently using @@ -124,14 +140,14 @@ def encode(self, key: VAL, _con: str = '-', _variant=None): else: key = str(key) if not isinstance(_con, str) or len(_con) != 1: - _con = '-' + _con = "-" if _variant: - prefix = f'{self.base_key_prefix}-{_variant}' + prefix = f"{self.base_key_prefix}-{_variant}" else: prefix = self.base_key_prefix if key.startswith(prefix): return key - return f'{prefix}{_con}{key}' + return f"{prefix}{_con}{key}" def decode(self, key: VAL, _variant=None): if isinstance(key, bytes): @@ -143,11 +159,11 @@ def decode(self, key: VAL, _variant=None): else: key = str(key) if _variant: - prefix = f'{self.base_key_prefix}-{_variant}' + prefix = f"{self.base_key_prefix}-{_variant}" else: prefix = self.base_key_prefix if key.startswith(prefix): - return key[len(prefix) + 1:] + return key[len(prefix) + 1 :] return key def __getitem__(self, item): @@ -156,7 +172,7 @@ def __getitem__(self, item): def __setitem__(self, key, value): return self.set(key, value) - def __eq__(self, other: 'BaseCacheInterface'): + def __eq__(self, other: "BaseCacheInterface"): if not isinstance(other, BaseCacheInterface): return False return self.scope_prefix == other.scope_prefix @@ -168,7 +184,9 @@ def get(self, key, default=None): return default return v - def fetch(self, args=None, *keys, named: bool = False) -> Union[list, Dict[str, Any]]: + def fetch( + self, args=None, *keys, named: bool = False + ) -> Union[list, Dict[str, Any]]: keys = keys_or_args(args, *keys) entity = self.get_entity() values = entity.get(*self.encode(keys)) @@ -179,14 +197,26 @@ def get_last_modified(self, args=None, *keys) -> Optional[datetime]: entity = self.get_entity() return 
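# Sketch: encode()/decode() above namespace raw keys as
#     "<service_prefix>-<scope_prefix>[-<variant>]<connector><key>"
# so different cache interfaces (and vary variants) never collide. An
# illustrative round trip; the prefix values are assumed:
#
#     base_key_prefix == "svc-api.UserAPI"
#     encode("page=1", ":")                  -> "svc-api.UserAPI:page=1"
#     encode("page=1", ":", _variant="u42")  -> "svc-api.UserAPI-u42:page=1"
#     decode("svc-api.UserAPI:page=1")       -> "page=1"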
entity.last_modified(*self.encode(keys)) - def get_entity(self, readonly: bool = False, variant=None) -> 'CacheEntity': + def get_entity(self, readonly: bool = False, variant=None) -> "CacheEntity": return self.entity_cls(self, readonly=readonly, variant=variant) - def set(self, key, value, *, timeout: Union[int, timedelta, datetime] = ..., - exists_only: bool = False, not_exists_only: bool = False): + def set( + self, + key, + value, + *, + timeout: Union[int, timedelta, datetime] = ..., + exists_only: bool = False, + not_exists_only: bool = False, + ): entity = self.get_entity() - entity.set(self.encode(key), value, timeout=self.get_timeout(timeout), - exists_only=exists_only, not_exists_only=not_exists_only) + entity.set( + self.encode(key), + value, + timeout=self.get_timeout(timeout), + exists_only=exists_only, + not_exists_only=not_exists_only, + ) def pop(self, key): val = self[key] @@ -269,7 +299,9 @@ def expire(self, *keys: str, timeout: float): entity = self.get_entity() entity.expire(*self.encode(keys), timeout=self.get_timeout(timeout)) - def update(self, data: Dict[str, Any], timeout: Union[int, timedelta, datetime] = ...): + def update( + self, data: Dict[str, Any], timeout: Union[int, timedelta, datetime] = ... + ): if not data: return entity = self.get_entity() @@ -283,21 +315,23 @@ def incr(self, key: str, amount: NUM = 1, upper_bound: int = None) -> Optional[N if not amount: return self[key] if amount < 0: - raise ValueError(f'Cache incr amount should > 0, got {amount}') + raise ValueError(f"Cache incr amount should > 0, got {amount}") return self.alter(key=key, amount=amount, limit=upper_bound) def decr(self, key: str, amount: NUM = 1, lower_bound: int = None) -> Optional[NUM]: """ - Decrease key by amount > 0, you can set a low_bound number - return the altered number (float/int) if successfully modified, elsewhere return None + Decrease key by amount > 0, you can set a low_bound number + return the altered number (float/int) if successfully modified, elsewhere return None """ if not amount: return self[key] if amount < 0: - raise ValueError(f'Cache decr amount should > 0, got {amount}') + raise ValueError(f"Cache decr amount should > 0, got {amount}") return self.alter(key=key, amount=-amount, limit=lower_bound) - def alter(self, key: str, amount: Union[int, float], limit: int = None) -> Optional[NUM]: + def alter( + self, key: str, amount: Union[int, float], limit: int = None + ) -> Optional[NUM]: entity = self.get_entity() return entity.alter(self.encode(key), amount=amount, limit=limit) @@ -316,8 +350,10 @@ def get_stats(self): :return: """ if not self.trace_keys: - raise NotImplementedError(f'Cache.get_stats not implemented, ' - f'please set trace_keys=True to enable this method') + raise NotImplementedError( + f"Cache.get_stats not implemented, " + f"please set trace_keys=True to enable this method" + ) requests = 0 total_hits = 0 @@ -338,13 +374,15 @@ def get_stats(self): total_keys += size if last_mod: last_modified_list.append(last_mod) - variant_data.append({ - 'variant': variant, - 'requests': req, - 'hits': hits, - 'size': size, - 'last_modified': last_mod - }) + variant_data.append( + { + "variant": variant, + "requests": req, + "hits": hits, + "size": size, + "last_modified": last_mod, + } + ) if last_modified_list: last_modified = max(last_modified_list) else: @@ -355,14 +393,14 @@ def get_stats(self): last_modified = entity.get_latest_update() return { - 'scope_prefix': self.scope_prefix, - 'max_entries': self.max_entries, - 'max_variants': 
self.max_variants, - 'requests': requests, - 'hits': total_hits, - 'size': total_keys, - 'last_modified': last_modified, - 'variants': variant_data + "scope_prefix": self.scope_prefix, + "max_entries": self.max_entries, + "max_variants": self.max_variants, + "requests": requests, + "hits": total_hits, + "size": total_keys, + "last_modified": last_modified, + "variants": variant_data, } @classmethod @@ -388,6 +426,7 @@ def dump_kwargs(cls, **kwargs) -> str: dump the args & kwargs of a callable to a comparable string use for API cache """ + def dump_data(data) -> str: if multi(data): if isinstance(data, set): @@ -400,15 +439,16 @@ def dump_data(data) -> str: lst = [] for d in data: lst.append(dump_data(d)) - return '[%s]' % ','.join(lst) + return "[%s]" % ",".join(lst) elif isinstance(data, dict): lst = [] for k in sorted(data.keys()): - lst.append(f'{repr(k)}:{dump_data(data[k])}') - return '{%s}' % ','.join(lst) + lst.append(f"{repr(k)}:{dump_data(data[k])}") + return "{%s}" % ",".join(lst) elif isinstance(data, COMMON_TYPES): return repr(data) return str(data) + # even if args and kwargs are empty, still get a equal length key return fast_digest( dump_data(kwargs), diff --git a/utilmeta/core/cache/plugins/entity.py b/utilmeta/core/cache/plugins/entity.py index 6280cb2..a81c0fe 100644 --- a/utilmeta/core/cache/plugins/entity.py +++ b/utilmeta/core/cache/plugins/entity.py @@ -10,47 +10,49 @@ NUM = Union[int, float] VAL = Union[str, bytes, list, tuple, dict] -__all__ = ['CacheEntity'] +__all__ = ["CacheEntity"] class CacheEntity: backend_name = None @property - def requests_key(self): # total request keys - return self.src.encode('requests', '@', self._assigned_variant) + def requests_key(self): # total request keys + return self.src.encode("requests", "@", self._assigned_variant) # cannot perform atomic update sync in common cache backend # hit ratio = sum(hits) / total @property def hits_key(self): - return self.src.encode('hits', '@', self._assigned_variant) + return self.src.encode("hits", "@", self._assigned_variant) # cannot perform atomic update sync in common cache backend @property def update_key(self): # this store every keys (update in set) - return self.src.encode('update', '@', self._assigned_variant) + return self.src.encode("update", "@", self._assigned_variant) @property def vary_hits_key(self): - return f'{self.src.base_key_prefix}@vary_hits' + return f"{self.src.base_key_prefix}@vary_hits" @property def vary_update_key(self): - return f'{self.src.base_key_prefix}@vary_updates' + return f"{self.src.base_key_prefix}@vary_updates" - def __init__(self, src: 'BaseCacheInterface', variant=None, readonly: bool = False): + def __init__(self, src: "BaseCacheInterface", variant=None, readonly: bool = False): assert isinstance(src, BaseCacheInterface) self.src = src self.readonly = readonly self._assigned_variant = variant def reset_stats(self): - self.cache.fetch({ - self.requests_key: 0, - self.hits_key: {}, - }) + self.cache.fetch( + { + self.requests_key: 0, + self.hits_key: {}, + } + ) @property def variant(self): @@ -83,7 +85,11 @@ def get_latest_update(self): return max(updates.values()) if updates else None def _get_key_data(self, vary: bool = False) -> Tuple[dict, dict]: - tk, uk = (self.vary_hits_key, self.vary_update_key) if vary else (self.hits_key, self.update_key) + tk, uk = ( + (self.vary_hits_key, self.vary_update_key) + if vary + else (self.hits_key, self.update_key) + ) _data = self.cache.fetch(tk, uk) _counts = _data.get(tk) _updates = _data.get(uk) @@ -98,11 
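# Sketch: dump_data() above canonicalizes call arguments before hashing, so
# logically equal calls (sets and dict keys sorted, sequence order kept)
# digest to the same cache key. A standalone equivalent of the idea, assuming
# JSON-serializable arguments (the real dump_data does not require that):

import json

def canonical_key(**kwargs) -> str:
    return json.dumps(kwargs, sort_keys=True, separators=(",", ":"))

# both orderings map to one cache entry
assert canonical_key(a=2, b=1) == canonical_key(b=1, a=2)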
+104,12 @@ def _set_key_data(self, counts: dict, updates: dict, vary: bool = False): counts = {} if not isinstance(updates, dict): updates = {} - tk, uk = (self.vary_hits_key, self.vary_update_key) if vary else (self.hits_key, self.update_key) - self.cache.update({ - tk: counts, - uk: updates - }) + tk, uk = ( + (self.vary_hits_key, self.vary_update_key) + if vary + else (self.hits_key, self.update_key) + ) + self.cache.update({tk: counts, uk: updates}) @classmethod def pop_min(cls, data: dict, count: int): @@ -141,7 +148,9 @@ def prepare(self, *keys: str): # counts already set keys excess: int = self.exists(*updates) - self.src.max_entries - if excess > self.src.max_entries_tolerance: # default to 0, but can set a throttle value + if ( + excess > self.src.max_entries_tolerance + ): # default to 0, but can set a throttle value # total_key >= max_entries # delete the least frequently hit key del_keys = [] @@ -174,7 +183,9 @@ def prepare(self, *keys: str): if self.src.max_variants: excess = len(vary_updates) - self.src.max_variants - if excess > self.src.max_variants_tolerance: # default to 0, but can set a throttle value + if ( + excess > self.src.max_variants_tolerance + ): # default to 0, but can set a throttle value # total_key >= max_entries # delete the least frequently hit key del_keys = [] @@ -191,7 +202,7 @@ def update(self, data: Dict[str, Any], timeout: float = unprovided): if not data: return if self.readonly: - warnings.warn(f'Attempt to set val ({data}) to a readonly cache') + warnings.warn(f"Attempt to set val ({data}) to a readonly cache") return if timeout == 0: # will expire ASAP @@ -201,11 +212,20 @@ def update(self, data: Dict[str, Any], timeout: float = unprovided): if not unprovided(timeout): self.cache.expire(*data, timeout=timeout) - def set(self, key: str, val, *, timeout: float = None, - exists_only: bool = False, not_exists_only: bool = False): + def set( + self, + key: str, + val, + *, + timeout: float = None, + exists_only: bool = False, + not_exists_only: bool = False, + ): if self.readonly: - warnings.warn(f'Attempt to set val ({key} -> {repr(val)}) to a readonly cache, ' - f'(maybe from other scope), ignoring...') + warnings.warn( + f"Attempt to set val ({key} -> {repr(val)}) to a readonly cache, " + f"(maybe from other scope), ignoring..." 
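# Sketch: prepare() above enforces max_entries with the configured policy --
# once live keys exceed max_entries by more than max_entries_tolerance, the
# least frequently hit keys are dropped first (pop_min over the hit counts).
# The selection step in isolation:

def least_hit_keys(hits: dict, excess: int) -> list:
    # the `excess` keys with the smallest hit counts are evicted first
    return sorted(hits, key=hits.get)[:excess]

assert least_hit_keys({"a": 9, "b": 1, "c": 4}, 2) == ["b", "c"]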
+ ) return if timeout == 0: # will expire ASAP @@ -267,8 +287,10 @@ def get(self, *keys: str, single: bool = False): def last_modified(self, *keys: str) -> Optional[datetime]: if not self.src.trace_keys: - raise NotImplementedError(f'Cache.last_modified not implemented, ' - f'please set trace_keys=True to enable this method') + raise NotImplementedError( + f"Cache.last_modified not implemented, " + f"please set trace_keys=True to enable this method" + ) updates = self.cache.get(self.update_key) if isinstance(updates, dict): times = [] @@ -283,8 +305,10 @@ def last_modified(self, *keys: str) -> Optional[datetime]: def keys(self): if not self.src.trace_keys: - raise NotImplementedError(f'Cache.keys not implemented, ' - f'please set trace_keys=True to enable this method') + raise NotImplementedError( + f"Cache.keys not implemented, " + f"please set trace_keys=True to enable this method" + ) keys = self.cache.get(self.update_key) if isinstance(keys, dict): misses = [] @@ -306,8 +330,10 @@ def keys(self): def clear(self): if not self.src.trace_keys: - raise NotImplementedError(f'Cache.clear not implemented, ' - f'please set trace_keys=True to enable this method') + raise NotImplementedError( + f"Cache.clear not implemented, " + f"please set trace_keys=True to enable this method" + ) keys = self.cache.get(self.update_key) or {} self.delete(*keys, self.requests_key, self.hits_key, self.update_key) if self.variant: @@ -321,7 +347,7 @@ def delete(self, *keys: str): if not keys: return if self.readonly: - raise RuntimeError(f'Attempt to delete key ({keys}) at a readonly cache') + raise RuntimeError(f"Attempt to delete key ({keys}) at a readonly cache") self.cache.delete(*keys) # update key metrics if self.src.trace_keys: @@ -337,8 +363,10 @@ def expire(self, *keys: str, timeout: float): def count(self) -> int: if not self.src.trace_keys: - raise NotImplementedError(f'Cache.count not implemented, ' - f'please set trace_keys=True to enable this method') + raise NotImplementedError( + f"Cache.count not implemented, " + f"please set trace_keys=True to enable this method" + ) keys = self.cache.get(self.update_key) if isinstance(keys, dict): # consider timeout @@ -356,7 +384,7 @@ def variants(self): counts = self.cache.get(self.vary_update_key) if not counts: return [] - return list(counts) # noqa + return list(counts) # noqa def alter(self, key: str, amount: Union[int, float], limit: int = None): # cannot perform atomic limitation, only lua script in redis can do that @@ -364,7 +392,9 @@ def alter(self, key: str, amount: Union[int, float], limit: int = None): return None if not amount: return self.get(key, single=True) - self.prepare(key) # still need to prepare, since the command can generate new keys + self.prepare( + key + ) # still need to prepare, since the command can generate new keys if limit is not None: value = self.get(key, single=True) @@ -407,16 +437,16 @@ def alter(self, key: str, amount: Union[int, float], limit: int = None): return res def lock(self, *keys: str, block: bool = False): - raise NotImplementedError(f'{self.__class__} not support lock acquire') + raise NotImplementedError(f"{self.__class__} not support lock acquire") def lpush(self, key: str, *values): - raise NotImplementedError(f'{self.__class__} not support lpush') + raise NotImplementedError(f"{self.__class__} not support lpush") def rpush(self, key: str, *values): - raise NotImplementedError(f'{self.__class__} not support rpush') + raise NotImplementedError(f"{self.__class__} not support rpush") def lpop(self, key: str): 
- raise NotImplementedError(f'{self.__class__} not support lpop') + raise NotImplementedError(f"{self.__class__} not support lpop") def rpop(self, key: str): - raise NotImplementedError(f'{self.__class__} not support rpop') + raise NotImplementedError(f"{self.__class__} not support rpop") diff --git a/utilmeta/core/cache/plugins/sdk.py b/utilmeta/core/cache/plugins/sdk.py index 79c3c16..8ec55b2 100644 --- a/utilmeta/core/cache/plugins/sdk.py +++ b/utilmeta/core/cache/plugins/sdk.py @@ -8,25 +8,23 @@ from functools import cached_property -__all__ = ['ClientCache'] +__all__ = ["ClientCache"] -NO_CACHE = 'no-cache' -NO_STORE = 'no-store' -NO_TRANSFORM = 'no-transform' -PUBLIC = 'public' -IMMUTABLE = 'immutable' -MAX_AGE = 'max-age' -MAX_STALE = 'max-stale' -MAX_FRESH = 'max-fresh' -PRIVATE = 'private' -MUST_REVALIDATE = 'must-revalidate' +NO_CACHE = "no-cache" +NO_STORE = "no-store" +NO_TRANSFORM = "no-transform" +PUBLIC = "public" +IMMUTABLE = "immutable" +MAX_AGE = "max-age" +MAX_STALE = "max-stale" +MAX_FRESH = "max-fresh" +PRIVATE = "private" +MUST_REVALIDATE = "must-revalidate" class CacheHeaderSchema(Schema): __options__ = Options( - case_insensitive=True, - ignore_required=True, - force_default=None + case_insensitive=True, ignore_required=True, force_default=None ) cache_control: str = Field(alias_from=[Header.CACHE_CONTROL, Header.PRAGMA]) @@ -48,7 +46,7 @@ class CacheHeaderSchema(Schema): def cache_control_derivatives(self): if not self.cache_control: return [] - return [v.strip() for v in self.cache_control.split(',')] + return [v.strip() for v in self.cache_control.split(",")] @property def immutable(self): @@ -65,14 +63,22 @@ def private(self): @property def no_cache(self): # no-cache and no-store all means do not attempt to read from cached response - return NO_CACHE in self.cache_control_derivatives or NO_STORE in self.cache_control_derivatives + return ( + NO_CACHE in self.cache_control_derivatives + or NO_STORE in self.cache_control_derivatives + ) @property def no_store(self): - if self.vary == '*': + if self.vary == "*": # vary for all, this response is not cache-able return True - if not self.cache_control and not self.expires and not self.etag and not self.last_modified: + if ( + not self.cache_control + and not self.expires + and not self.etag + and not self.last_modified + ): # no cache headers is presenting return True return NO_STORE in self.cache_control_derivatives @@ -85,32 +91,34 @@ def must_revalidate(self): def max_age(self) -> Optional[int]: for d in self.cache_control_derivatives: if d.startswith(MAX_AGE): - return int(d.split('=')[1].strip()) + return int(d.split("=")[1].strip()) if self.expires: # fallback to expires - return max(0, int((self.expires - (self.date or time_now())).total_seconds())) + return max( + 0, int((self.expires - (self.date or time_now())).total_seconds()) + ) return None @property def vary_headers(self): if not self.vary: return [] - return [v.strip() for v in self.vary.split(',')] + return [v.strip() for v in self.vary.split(",")] @property def max_stale(self) -> Optional[int]: for d in self.cache_control_derivatives: if d.startswith(MAX_STALE): - if '=' not in d: + if "=" not in d: return -1 - return int(d.split('=')[1].strip()) + return int(d.split("=")[1].strip()) return None @property def max_fresh(self) -> Optional[int]: for d in self.cache_control_derivatives: if d.startswith(MAX_FRESH): - return int(d.split('=')[1].strip()) + return int(d.split("=")[1].strip()) return None @@ -124,9 +132,10 @@ class 
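# Sketch: max_age above prefers Cache-Control: max-age=N and falls back to
# the Expires/Date pair (freshness = Expires - Date, clamped at zero), which
# follows RFC 9111's freshness model. The fallback in isolation:

from datetime import datetime, timezone
from typing import Optional

def freshness_lifetime(max_age: Optional[int] = None,
                       expires: Optional[datetime] = None,
                       date: Optional[datetime] = None) -> Optional[int]:
    if max_age is not None:
        return max_age
    if expires is not None:
        reference = date or datetime.now(timezone.utc)
        return max(0, int((expires - reference).total_seconds()))
    return None  # no explicit freshness information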
ClientCache(BaseCacheInterface): * Last-Modified ~ If-Modified-Since * Etag ~ If-None-Match """ + def __init__( self, - cache_alias: str = 'default', + cache_alias: str = "default", scope_prefix: str = None, services_sharing: bool = False, # enable this param will make cache key without service_prefix @@ -145,13 +154,17 @@ def __init__( default_timeout: int = None, excluded_statuses: List[int] = None, included_statuses: List[int] = None, - included_methods: List[str] = ('GET', 'HEAD'), - excluded_hosts: Union[str, List[str]] = None, # do not cache responses from these hosts - included_hosts: Union[str, List[str]] = None, # only cache responses from these hosts + included_methods: List[str] = ("GET", "HEAD"), + excluded_hosts: Union[ + str, List[str] + ] = None, # do not cache responses from these hosts + included_hosts: Union[ + str, List[str] + ] = None, # only cache responses from these hosts ): # use max hosts as max_variants to be part of locals() _locals = dict(locals()) - pop(_locals, 'self') + pop(_locals, "self") super().__init__(**_locals) self.disable_304 = disable_304 @@ -194,7 +207,7 @@ def bypass_request(self, request: Request): return True return False - def bypass_response(self, response: 'Response'): + def bypass_response(self, response: "Response"): if self.bypass_request(response.request): # url not set, cannot cache return True @@ -258,11 +271,13 @@ def process_request(self, request: Request): # modify the If-None-Match header if cached_headers.last_modified: if not headers.if_modified_since: - request.headers.setdefault(Header.IF_MODIFIED_SINCE, cached_headers.last_modified) + request.headers.setdefault( + Header.IF_MODIFIED_SINCE, cached_headers.last_modified + ) # modify the If-Modified-Since header return request - def get_cached_response(self, request: Request) -> Optional['Response']: + def get_cached_response(self, request: Request) -> Optional["Response"]: resp_key = self.get_response_key(request) if not resp_key: return None @@ -282,11 +297,11 @@ def get_response_key(self, request: Request): # not using vary values as cache key # us vary as a validation (to validate whether the cache is still fresh) return self.encode( - key=f'{request.method.lower()}:{request.encoded_path}', - _variant=self.vary_function(request) + key=f"{request.method.lower()}:{request.encoded_path}", + _variant=self.vary_function(request), ) - def process_response(self, response: 'Response'): # hook + def process_response(self, response: "Response"): # hook if self.bypass_response(response): # url not set, cannot cache return response @@ -349,11 +364,7 @@ def process_response(self, response: 'Response'): # hook response.request = None # --- - entity.set( - key=resp_key, - val=response, - timeout=timeout - ) + entity.set(key=resp_key, val=response, timeout=timeout) response.data = data response.raw_response = raw response.request = request diff --git a/utilmeta/core/cli/backends/aiohttp.py b/utilmeta/core/cli/backends/aiohttp.py index f4ef433..5508997 100644 --- a/utilmeta/core/cli/backends/aiohttp.py +++ b/utilmeta/core/cli/backends/aiohttp.py @@ -7,10 +7,14 @@ class AiohttpClientRequestAdaptor(ClientRequestAdaptor): # request: ClientRequest backend = aiohttp - async def __call__(self, timeout: float = None, allow_redirects: bool = None, **kwargs): + async def __call__( + self, timeout: float = None, allow_redirects: bool = None, **kwargs + ): from utilmeta.core.response.backends.aiohttp import AiohttpClientResponseAdaptor - async with aiohttp.ClientSession(timeout=ClientTimeout( - 
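# Sketch: process_request() above upgrades a request that has a cached
# response into a conditional one -- the stored ETag seeds If-None-Match and
# the stored Last-Modified seeds If-Modified-Since, letting the origin answer
# 304 instead of a full body. The header wiring in isolation:

def add_validators(request_headers: dict, cached_headers: dict) -> dict:
    if cached_headers.get("ETag"):
        request_headers.setdefault("If-None-Match", cached_headers["ETag"])
    if cached_headers.get("Last-Modified"):
        request_headers.setdefault(
            "If-Modified-Since", cached_headers["Last-Modified"]
        )
    return request_headers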
total=float(timeout) if timeout is not None else None)) as session: + + async with aiohttp.ClientSession( + timeout=ClientTimeout(total=float(timeout) if timeout is not None else None) + ) as session: resp = await session.request( method=self.request.method, url=self.request.url, diff --git a/utilmeta/core/cli/backends/base.py b/utilmeta/core/cli/backends/base.py index 18859fc..7168665 100644 --- a/utilmeta/core/cli/backends/base.py +++ b/utilmeta/core/cli/backends/base.py @@ -4,7 +4,7 @@ class ClientRequestAdaptor(BaseAdaptor): @classmethod - def get_module_name(cls, obj: 'Request'): + def get_module_name(cls, obj: "Request"): if isinstance(obj, Request): return super().get_module_name(obj.backend) return super().get_module_name(obj) @@ -13,10 +13,15 @@ def get_module_name(cls, obj: 'Request'): def qualify(cls, obj: Request): if not cls.backend or not obj.backend: return False - return cls.get_module_name(obj.backend).lower() == cls.get_module_name(cls.backend).lower() + return ( + cls.get_module_name(obj.backend).lower() + == cls.get_module_name(cls.backend).lower() + ) def __init__(self, request: Request): self.request = request def __call__(self, **kwargs): - raise NotImplementedError('This request backend does not support calling outbound requests') + raise NotImplementedError( + "This request backend does not support calling outbound requests" + ) diff --git a/utilmeta/core/cli/backends/httpx.py b/utilmeta/core/cli/backends/httpx.py index d94b75e..05980b8 100644 --- a/utilmeta/core/cli/backends/httpx.py +++ b/utilmeta/core/cli/backends/httpx.py @@ -24,13 +24,19 @@ def request_kwargs(self): def __call__(self, timeout: float = None, **kwargs): from utilmeta.core.response.backends.httpx import HttpxClientResponseAdaptor - with httpx.Client(timeout=float(timeout) if timeout is not None else None) as client: + + with httpx.Client( + timeout=float(timeout) if timeout is not None else None + ) as client: resp = client.request(**self.request_kwargs) return HttpxClientResponseAdaptor(resp) @awaitable(__call__) async def __call__(self, timeout: float = None, **kwargs): from utilmeta.core.response.backends.httpx import HttpxClientResponseAdaptor - async with httpx.AsyncClient(timeout=float(timeout) if timeout is not None else None) as client: + + async with httpx.AsyncClient( + timeout=float(timeout) if timeout is not None else None + ) as client: resp = await client.request(**self.request_kwargs) return HttpxClientResponseAdaptor(resp) diff --git a/utilmeta/core/cli/backends/requests.py b/utilmeta/core/cli/backends/requests.py index 3bc36d4..38cbbf7 100644 --- a/utilmeta/core/cli/backends/requests.py +++ b/utilmeta/core/cli/backends/requests.py @@ -5,8 +5,15 @@ class RequestsRequestAdaptor(ClientRequestAdaptor): backend = requests - def __call__(self, timeout: float = None, allow_redirects: bool = None, proxies: dict = None, **kwargs): + def __call__( + self, + timeout: float = None, + allow_redirects: bool = None, + proxies: dict = None, + **kwargs + ): from utilmeta.core.response.backends.requests import RequestsResponseAdaptor + resp = requests.request( method=self.request.method, url=self.request.url, diff --git a/utilmeta/core/cli/backends/urllib.py b/utilmeta/core/cli/backends/urllib.py index f23dbca..fce7c65 100644 --- a/utilmeta/core/cli/backends/urllib.py +++ b/utilmeta/core/cli/backends/urllib.py @@ -17,13 +17,17 @@ class UrllibRequestAdaptor(ClientRequestAdaptor): def __call__(self, timeout: float = None, **kwargs): from utilmeta.core.response.backends.urllib import 
UrllibResponseAdaptor + try: - resp = urlopen(Request( - url=self.request.url, - method=str(self.request.method).upper(), - data=self.request.body, - headers=self.request.headers, - ), timeout=float(timeout) if timeout is not None else None) + resp = urlopen( + Request( + url=self.request.url, + method=str(self.request.method).upper(), + data=self.request.body, + headers=self.request.headers, + ), + timeout=float(timeout) if timeout is not None else None, + ) except HTTPError as e: resp = e return UrllibResponseAdaptor(resp) diff --git a/utilmeta/core/cli/base.py b/utilmeta/core/cli/base.py index 8f20d0d..627d3f7 100644 --- a/utilmeta/core/cli/base.py +++ b/utilmeta/core/cli/base.py @@ -1,9 +1,18 @@ import inspect -from utilmeta.utils import PluginEvent, PluginTarget, \ - Error, url_join, classonlymethod, json_dumps, \ - COMMON_METHODS, EndpointAttr, valid_url, \ - parse_query_string, parse_query_dict +from utilmeta.utils import ( + PluginEvent, + PluginTarget, + Error, + url_join, + classonlymethod, + json_dumps, + COMMON_METHODS, + EndpointAttr, + valid_url, + parse_query_string, + parse_query_dict, +) from utype.types import * from http.cookies import SimpleCookie @@ -25,25 +34,28 @@ from typing import TypeVar -T = TypeVar('T') +T = TypeVar("T") -setup_class = PluginEvent('setup_class', synchronous_only=True) -process_request = PluginEvent('process_request', streamline_result=True) -handle_error = PluginEvent('handle_error') -process_response = PluginEvent('process_response', streamline_result=True) +setup_class = PluginEvent("setup_class", synchronous_only=True) +process_request = PluginEvent("process_request", streamline_result=True) +handle_error = PluginEvent("handle_error") +process_response = PluginEvent("process_response", streamline_result=True) -def parse_proxies(proxies: Union[str, List[str], Dict[str, str]], scheme=None) -> Dict[str, List[str]]: +def parse_proxies( + proxies: Union[str, List[str], Dict[str, str]], scheme=None +) -> Dict[str, List[str]]: if isinstance(proxies, str): from urllib.parse import urlparse + parsed = urlparse(proxies) if parsed.scheme: return {parsed.scheme: [proxies]} if scheme: - return {scheme: scheme + '://' + proxies} + return {scheme: scheme + "://" + proxies} return { - 'http': ['http://' + proxies], - 'https': ['https://' + proxies], + "http": ["http://" + proxies], + "https": ["https://" + proxies], } elif isinstance(proxies, list): values = {} @@ -93,9 +105,13 @@ class Client(PluginTarget): def __init_subclass__(cls, **kwargs): if not issubclass(cls._request_cls, Request): - raise TypeError(f'Invalid request class: {cls._request_cls}, must be subclass of Request') + raise TypeError( + f"Invalid request class: {cls._request_cls}, must be subclass of Request" + ) if not issubclass(cls._endpoint_cls, ClientEndpoint): - raise TypeError(f'Invalid request class: {cls._endpoint_cls}, must be subclass of ClientEndpoint') + raise TypeError( + f"Invalid request class: {cls._endpoint_cls}, must be subclass of ClientEndpoint" + ) cls._generate_endpoints() setup_class(cls, **kwargs) @@ -108,22 +124,24 @@ def _generate_endpoints(cls): clients = {} for key, api in cls.__annotations__.items(): - if key.startswith('_'): + if key.startswith("_"): continue val = cls.__dict__.get(key) if is_annotated(api): # param: Annotated[str, request.QueryParam()] - api = getattr(api, '__origin__', None) + api = getattr(api, "__origin__", None) if inspect.isclass(api) and issubclass(api, Client): kwargs = dict(route=key, name=key, parent=cls) if not val: - val = 
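# Note on the urllib adaptor above: urlopen() raises HTTPError for 4xx/5xx
# statuses, and HTTPError is itself a response-like object, so catching it
# and wrapping it unchanged keeps error responses uniform with success ones.
# The pattern in isolation:

from urllib.request import urlopen, Request
from urllib.error import HTTPError

def fetch_status(url: str) -> int:
    try:
        resp = urlopen(Request(url=url, method="GET"), timeout=5.0)
    except HTTPError as e:
        resp = e  # still carries status, headers and a readable body
    return resp.getcode()  # available on both HTTPResponse and HTTPError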
getattr(api, '_generator', None) + val = getattr(api, "_generator", None) if isinstance(val, decorator.APIGenerator): kwargs.update(val.kwargs) elif inspect.isfunction(val): - raise TypeError(f'{cls.__name__}: generate route [{repr(key)}] failed: conflict api and endpoint') + raise TypeError( + f"{cls.__name__}: generate route [{repr(key)}] failed: conflict api and endpoint" + ) route = cls._route_cls(api, **kwargs) clients[key] = route @@ -138,11 +156,15 @@ def _generate_endpoints(cls): if method: if hasattr(Client, key): if key.lower() in COMMON_METHODS: - raise TypeError(f'{cls.__name__}: generate route for {repr(key)} failed: HTTP method ' - f'name is reserved for Client class, please use @api.{key.lower()}("/")') + raise TypeError( + f"{cls.__name__}: generate route for {repr(key)} failed: HTTP method " + f'name is reserved for Client class, please use @api.{key.lower()}("/")' + ) else: - raise TypeError(f'{cls.__name__}: generate route for {repr(key)} failed: ' - f'name conflicted with Client method') + raise TypeError( + f"{cls.__name__}: generate route for {repr(key)} failed: " + f"name conflicted with Client method" + ) # a sign to wrap it in Unit # 1. @api.get (method='get') @@ -151,7 +173,9 @@ def _generate_endpoints(cls): # 4. @api(method='CUSTOM') (method='custom') val = cls._endpoint_cls.apply_for(val, cls) elif hook_type: - val = cls._hook_cls.dispatch_for(val, hook_type, target_type='client') + val = cls._hook_cls.dispatch_for( + val, hook_type, target_type="client" + ) else: continue setattr(cls, key, val) # reset value @@ -169,27 +193,26 @@ def _generate_endpoints(cls): cls._clients = clients - def __init__(self, - base_url: Union[str, List[str]] = None, - backend=None, # urllib / requests / aiohttp / httpx - service: Optional[UtilMeta] = None, - mock: bool = False, - internal: bool = False, - plugins: list = (), - - # session=None, # used to pass along the sdk classes - # prepend_route: str = None, - append_slash: bool = None, - - default_timeout: Union[float, int, timedelta] = None, - base_headers: Dict[str, str] = None, - base_cookies: Union[str, Dict[str, str], SimpleCookie] = None, - base_query: Dict[str, Any] = None, - proxies: dict = None, - allow_redirects: bool = None, - charset: str = 'utf-8', - fail_silently: bool = False, - ): + def __init__( + self, + base_url: Union[str, List[str]] = None, + backend=None, # urllib / requests / aiohttp / httpx + service: Optional[UtilMeta] = None, + mock: bool = False, + internal: bool = False, + plugins: list = (), + # session=None, # used to pass along the sdk classes + # prepend_route: str = None, + append_slash: bool = None, + default_timeout: Union[float, int, timedelta] = None, + base_headers: Dict[str, str] = None, + base_cookies: Union[str, Dict[str, str], SimpleCookie] = None, + base_query: Dict[str, Any] = None, + proxies: dict = None, + allow_redirects: bool = None, + charset: str = "utf-8", + fail_silently: bool = False, + ): super().__init__(plugins=plugins) @@ -206,7 +229,9 @@ def __init__(self, elif inspect.ismodule(backend): backend_name = backend.__name__ else: - raise TypeError(f'Invalid backend: {repr(backend)}, must be a module or str') + raise TypeError( + f"Invalid backend: {repr(backend)}, must be a module or str" + ) self._backend_name = backend_name self._backend = backend @@ -221,17 +246,15 @@ def __init__(self, res = urlsplit(base_url) if not res.scheme: # allow ws / wss in the future - raise ValueError(f'utilmeta.core.cli.Client: Invalid base_url: {repr(base_url)}, ' - f'must be a valid url') + 
raise ValueError( + f"utilmeta.core.cli.Client: Invalid base_url: {repr(base_url)}, " + f"must be a valid url" + ) if res.query: self._base_query.update(parse_query_string(res.query)) - base_url = urlunsplit(( - res.scheme, - res.netloc, - res.path, - '', # query - '' # fragment - )) + base_url = urlunsplit( + (res.scheme, res.netloc, res.path, "", "") # query # fragment + ) self._base_url = base_url self._proxies = proxies @@ -249,12 +272,12 @@ def __init__(self, # includes BaseCookie cookies cookies = SimpleCookie(cookies) elif cookies: - raise TypeError(f'Invalid cookies: {cookies}, must be str or dict') + raise TypeError(f"Invalid cookies: {cookies}, must be str or dict") else: cookies = SimpleCookie() for _key, _val in self._base_headers.items(): - if _key.lower() == 'cookie': + if _key.lower() == "cookie": cookies.update(SimpleCookie(_val)) break @@ -266,7 +289,7 @@ def __init__(self, self._original_headers = dict(self._base_headers) self._original_query = dict(self._base_query) - self._client_route: Optional['ClientRoute'] = None + self._client_route: Optional["ClientRoute"] = None for key, val in self.__class__.__dict__.items(): if isinstance(val, ClientEndpoint): @@ -279,8 +302,7 @@ def __init__(self, client_cls = client_route.handler params = dict(params) params.update( - base_url=client_base_url, - plugins=self._plugins # inject plugins + base_url=client_base_url, plugins=self._plugins # inject plugins ) client = client_cls(**params) client._client_route = client_route.merge_hooks(self._client_route) @@ -309,12 +331,12 @@ def request_cls(self): return self._request_cls @property - def client_route(self) -> 'ClientRoute': + def client_route(self) -> "ClientRoute": return self._client_route @classonlymethod def __reproduce_with__(cls, generator: decorator.APIGenerator): - plugins = generator.kwargs.get('plugins') + plugins = generator.kwargs.get("plugins") if plugins: cls._add_plugins(*plugins) cls._generator = generator @@ -327,7 +349,7 @@ def get_client_params(self): service=self._service, base_headers=self._base_headers, base_query=self._base_query, - base_cookies=self._cookies, # use cookies as base_cookies to pass session to sub client + base_cookies=self._cookies, # use cookies as base_cookies to pass session to sub client append_slash=self._append_slash, allow_redirects=self._allow_redirects, proxies=self._proxies, @@ -368,24 +390,20 @@ def _build_url(self, path: str, query: dict = None): # base_url: null # path: https://origin.com/path?key=value - base_url = self._base_url or (self._service.base_url if self._service else '') + base_url = self._base_url or (self._service.base_url if self._service else "") if parsed.scheme: # ignore base url - url = urlunsplit(( - parsed.scheme, - parsed.netloc, - parsed.path, - '', # query - '' # fragment - )) + url = urlunsplit( + (parsed.scheme, parsed.netloc, parsed.path, "", "") # query # fragment + ) else: url = url_join(base_url, parsed.path) if self._append_slash: - url = url.rstrip('/') + '/' + url = url.rstrip("/") + "/" - return url + (('?' + urlencode(query_params)) if query_params else '') + return url + (("?" 
+        return url + (("?" + urlencode(query_params)) if query_params else "")
 
     def _build_headers(self, headers, cookies=None):
         if cookies:
@@ -406,34 +424,28 @@ def _build_headers(self, headers, cookies=None):
                 _headers[key] = value
 
         if isinstance(_cookies, SimpleCookie) and _cookies:
-            _headers['cookie'] = ';'.join([f'{key}={val.value}' for key, val in _cookies.items() if val.value])
+            _headers["cookie"] = ";".join(
+                [f"{key}={val.value}" for key, val in _cookies.items() if val.value]
+            )
         return _headers
 
-    def _build_request(self,
-                       method: str,
-                       path: str = None,
-                       query: dict = None,
-                       data=None,
-                       # form: dict = None,
-                       headers: dict = None,
-                       cookies=None):
-        url = self._build_url(
-            path=path,
-            query=query
-        )
-        headers = self._build_headers(
-            headers=headers,
-            cookies=cookies
-        )
+    def _build_request(
+        self,
+        method: str,
+        path: str = None,
+        query: dict = None,
+        data=None,
+        # form: dict = None,
+        headers: dict = None,
+        cookies=None,
+    ):
+        url = self._build_url(path=path, query=query)
+        headers = self._build_headers(headers=headers, cookies=cookies)
         # if content_type:
         #     headers.setdefault('content-type', content_type)
         return self._request_cls(
-            method=method,
-            url=url,
-            data=data,
-            headers=headers,
-            backend=self._backend
+            method=method, url=url, data=data, headers=headers, backend=self._backend
        )
 
     def __request__(self, endpoint: ClientEndpoint, request: Request):
@@ -445,11 +457,12 @@ def __request__(self, endpoint: ClientEndpoint, request: Request):
 
         def make_request(req: Request = request):
             return endpoint.parse_response(
-                self._make_request(req),
-                fail_silently=self._fail_silently
+                self._make_request(req), fail_silently=self._fail_silently
             )
 
-        handler = self._chain_cls(self, endpoint).build_client_handler(make_request, asynchronous=False)
+        handler = self._chain_cls(self, endpoint).build_client_handler(
+            make_request, asynchronous=False
+        )
         return handler(request)
 
     async def __async_request__(self, endpoint: ClientEndpoint, request: Request):
@@ -461,11 +474,12 @@ async def __async_request__(self, endpoint: ClientEndpoint, request: Request):
 
         async def make_request(req: Request = request):
             return endpoint.parse_response(
-                await self._make_async_request(req),
-                fail_silently=self._fail_silently
+                await self._make_async_request(req), fail_silently=self._fail_silently
             )
 
-        handler = self._chain_cls(self, endpoint).build_client_handler(make_request, asynchronous=True)
+        handler = self._chain_cls(self, endpoint).build_client_handler(
+            make_request, asynchronous=True
+        )
         return await handler(request)
 
     def _make_request(self, request: Request, timeout: int = None) -> Response:
@@ -475,35 +489,31 @@ def _make_request(self, request: Request, timeout: int = None) -> Response:
                 from utilmeta import service
 
             root_api = service.resolve()
-            request.adaptor.route = request.path.strip('/')
+            request.adaptor.route = request.path.strip("/")
             try:
                 response = root_api(request)()
             except Exception as e:
-                response = getattr(root_api, 'response', Response)(error=e, request=request)
+                response = getattr(root_api, "response", Response)(
+                    error=e, request=request
+                )
         else:
             adaptor: ClientRequestAdaptor = ClientRequestAdaptor.dispatch(request)
             if timeout is None:
-                timeout = request.adaptor.get_context('timeout')  # slot
+                timeout = request.adaptor.get_context("timeout")  # slot
             if timeout is None:
                 timeout = self._default_timeout
             if timeout is not None:
                 timeout = float(timeout)
             try:
-                resp = adaptor(
-                    timeout=timeout,
-                    allow_redirects=self._allow_redirects
-                )
+                resp = adaptor(timeout=timeout, allow_redirects=self._allow_redirects)
             except Exception as e:
                 if not self._fail_silently:
                     raise e from e
-                timeout = 'timeout' in str(e).lower()
+                timeout = "timeout" in str(e).lower()
                 response = Response(
-                    timeout=timeout,
-                    error=e,
-                    request=request,
-                    aborted=True
+                    timeout=timeout, error=e, request=request, aborted=True
                 )
             else:
                 response = Response(response=resp, request=request)
@@ -514,50 +524,58 @@ def _make_request(self, request: Request, timeout: int = None) -> Response:
 
         return response
 
-    async def _make_async_request(self, request: Request, timeout: int = None) -> Response:
+    async def _make_async_request(
+        self, request: Request, timeout: int = None
+    ) -> Response:
        if self._internal:
             service = self._service
             if not service:
                 from utilmeta import service
 
             root_api = service.resolve()
-            request.adaptor.route = request.path.strip('/')
+            request.adaptor.route = request.path.strip("/")
             try:
                 response = root_api(request)()
                 if inspect.isawaitable(response):
                     response = await response
             except Exception as e:
-                response = getattr(root_api, 'response', Response)(error=e, request=request)
+                response = getattr(root_api, "response", Response)(
+                    error=e, request=request
+                )
         else:
             adaptor: ClientRequestAdaptor = ClientRequestAdaptor.dispatch(request)
             if timeout is None:
-                timeout = request.adaptor.get_context('timeout')  # slot
+                timeout = request.adaptor.get_context("timeout")  # slot
             try:
                 resp = adaptor(
                     timeout=timeout or self._default_timeout,
-                    allow_redirects=self._allow_redirects
+                    allow_redirects=self._allow_redirects,
                 )
                 if inspect.isawaitable(resp):
                     resp = await resp
             except Exception as e:
                 if not self._fail_silently:
                     raise e from e
-                timeout = 'timeout' in str(e).lower()
+                timeout = "timeout" in str(e).lower()
                 response = Response(
-                    error=e,
-                    request=request,
-                    timeout=timeout,
-                    aborted=True
+                    error=e, request=request, timeout=timeout, aborted=True
                 )
             else:
                 response = Response(response=resp, request=request)
         return response
 
-    def request(self, method: str, path: str = None, query: dict = None,
-                data=None,
-                headers: dict = None, cookies=None, timeout: int = None) -> Response:
+    def request(
+        self,
+        method: str,
+        path: str = None,
+        query: dict = None,
+        data=None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ) -> Response:
 
         request = self._build_request(
             method=method,
@@ -566,250 +584,284 @@ def request(self, method: str, path: str = None, query: dict = None,
             path=path,
             query=query,
             data=data,
             # form=form,
             headers=headers,
-            cookies=cookies
+            cookies=cookies,
         )
         return self._make_request(request, timeout=timeout)
 
-    async def async_request(self, method: str, path: str = None, query: dict = None,
-                            data=None,
-                            headers: dict = None, cookies=None,
-                            timeout: int = None) -> Response:
+    async def async_request(
+        self,
+        method: str,
+        path: str = None,
+        query: dict = None,
+        data=None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ) -> Response:
         request = self._build_request(
             method=method,
             path=path,
             query=query,
             data=data,
             headers=headers,
-            cookies=cookies
+            cookies=cookies,
         )
         return await self._make_async_request(request, timeout=timeout)
 
-    def get(self,
-            path: str = None,
-            query: dict = None,
-            headers: dict = None,
-            cookies=None,
-            timeout: int = None):
+    def get(
+        self,
+        path: str = None,
+        query: dict = None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return self.request(
-            method='GET',
+            method="GET",
             path=path,
             query=query,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    async def async_get(self,
-                        path: str = None,
-                        query: dict = None,
-                        headers: dict = None,
-                        cookies=None,
-                        timeout: int = None):
+    async def async_get(
+        self,
+        path: str = None,
+        query: dict = None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return await self.async_request(
-            method='GET',
+            method="GET",
             path=path,
             query=query,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    def post(self,
-             path: str = None,
-             query: dict = None,
-             data=None,
-             headers: dict = None,
-             cookies=None,
-             timeout: int = None):
+    def post(
+        self,
+        path: str = None,
+        query: dict = None,
+        data=None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return self.request(
-            method='POST',
+            method="POST",
             path=path,
             query=query,
             data=data,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
        )
 
-    async def async_post(self,
-                         path: str = None,
-                         query: dict = None,
-                         data=None,
-                         headers: dict = None,
-                         cookies=None,
-                         timeout: int = None):
+    async def async_post(
+        self,
+        path: str = None,
+        query: dict = None,
+        data=None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return await self.async_request(
-            method='POST',
+            method="POST",
             path=path,
             query=query,
             data=data,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    def put(self,
-            path: str = None,
-            query: dict = None,
-            data=None,
-            headers: dict = None,
-            cookies=None,
-            timeout: int = None):
+    def put(
+        self,
+        path: str = None,
+        query: dict = None,
+        data=None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return self.request(
-            method='PUT',
+            method="PUT",
             path=path,
             query=query,
             data=data,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    async def async_put(self,
-                        path: str = None,
-                        query: dict = None,
-                        data=None,
-                        headers: dict = None,
-                        cookies=None,
-                        timeout: int = None):
+    async def async_put(
+        self,
+        path: str = None,
+        query: dict = None,
+        data=None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return await self.async_request(
-            method='PUT',
+            method="PUT",
             path=path,
             query=query,
             data=data,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    def patch(self,
-              path: str = None,
-              query: dict = None,
-              data=None,
-              headers: dict = None,
-              cookies=None,
-              timeout: int = None):
+    def patch(
+        self,
+        path: str = None,
+        query: dict = None,
+        data=None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return self.request(
-            method='PATCH',
+            method="PATCH",
             path=path,
             query=query,
             data=data,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    async def async_patch(self,
-                          path: str = None,
-                          query: dict = None,
-                          data=None,
-                          headers: dict = None,
-                          cookies=None,
-                          timeout: int = None):
+    async def async_patch(
+        self,
+        path: str = None,
+        query: dict = None,
+        data=None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return await self.async_request(
-            method='PATCH',
+            method="PATCH",
             path=path,
             query=query,
             data=data,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    def delete(self,
-               path: str = None,
-               query: dict = None,
-               headers: dict = None,
-               cookies=None,
-               timeout: int = None):
+    def delete(
+        self,
+        path: str = None,
+        query: dict = None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return self.request(
-            method='DELETE',
+            method="DELETE",
             path=path,
             query=query,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    async def async_delete(self,
-                           path: str = None,
-                           query: dict = None,
-                           headers: dict = None,
-                           cookies=None,
-                           timeout: int = None):
+    async def async_delete(
+        self,
+        path: str = None,
+        query: dict = None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return await self.async_request(
-            method='DELETE',
+            method="DELETE",
             path=path,
             query=query,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    def options(self,
-                path: str = None,
-                query: dict = None,
-                headers: dict = None,
-                cookies=None,
-                timeout: int = None):
+    def options(
+        self,
+        path: str = None,
+        query: dict = None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return self.request(
-            method='OPTIONS',
+            method="OPTIONS",
             path=path,
             query=query,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    async def async_options(self,
-                            path: str = None,
-                            query: dict = None,
-                            headers: dict = None,
-                            cookies=None,
-                            timeout: int = None):
+    async def async_options(
+        self,
+        path: str = None,
+        query: dict = None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return await self.async_request(
-            method='OPTIONS',
+            method="OPTIONS",
             path=path,
             query=query,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    def head(self,
-             path: str = None,
-             query: dict = None,
-             headers: dict = None,
-             cookies=None,
-             timeout: int = None):
+    def head(
+        self,
+        path: str = None,
+        query: dict = None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
         return self.request(
-            method='HEAD',
+            method="HEAD",
             path=path,
             query=query,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
-    async def async_head(self,
-                         path: str = None,
-                         query: dict = None,
-                         headers: dict = None,
-                         cookies=None,
-                         timeout: int = None):
+    async def async_head(
+        self,
+        path: str = None,
+        query: dict = None,
+        headers: dict = None,
+        cookies=None,
+        timeout: int = None,
+    ):
-        return self.request(
-            method='HEAD',
+        return await self.async_request(
+            method="HEAD",
             path=path,
             query=query,
             headers=headers,
             cookies=cookies,
-            timeout=timeout
+            timeout=timeout,
         )
 
     def process_request(self, request: Request):
         pass
 
-    def process_response(self, response: Response): # noqa : meant to be inherited
+    def process_response(self, response: Response):  # noqa : meant to be inherited
         """
         Process response can also be treated as a hook (callback)
         to handle non-blocking requests
diff --git a/utilmeta/core/cli/chain.py b/utilmeta/core/cli/chain.py
index 851331c..151f16e 100644
--- a/utilmeta/core/cli/chain.py
+++ b/utilmeta/core/cli/chain.py
@@ -3,7 +3,11 @@ from typing import Callable
 from utilmeta.core.request import Request
 from utilmeta.core.response import Response
-from utilmeta.core.api.plugins.base import process_request, process_response, handle_error
+from utilmeta.core.api.plugins.base import (
+    process_request,
+    process_response,
+    handle_error,
+)
 import inspect
 from functools import wraps
 from utilmeta.utils import Error, exceptions
@@ -12,10 +16,11 @@ class ClientChainBuilder(BaseChainBuilder):
     def __init__(self, client, endpoint: ClientEndpoint):
         from .base import Client
+
         if not isinstance(client, Client):
-            raise TypeError(f'Invalid Client: {client}')
+            raise TypeError(f"Invalid Client: {client}")
         if not isinstance(endpoint, ClientEndpoint):
-            raise TypeError(f'Invalid client endpoint: {endpoint}')
+            raise TypeError(f"Invalid client endpoint: {endpoint}")
         super().__init__(endpoint, client)
         self.client = client
 
@@ -29,8 +34,7 @@ def idempotent(self):
 
     def parse_response(self, resp):
         resp = self.endpoint.parse_response(
-            resp,
-            fail_silently=self.client.fail_silently
+            resp, fail_silently=self.client.fail_silently
         )
         if resp.cookies:
             # update response cookies
@@ -49,8 +53,7 @@ async def async_client_handler(
         while True:
             try:
                 request.adaptor.update_context(
-                    retry_index=retry_index,
-                    idempotent=self.idempotent
+                    retry_index=retry_index, idempotent=self.idempotent
                 )
                 req = request
                 if request_handler:
@@ -64,10 +67,7 @@ async def async_client_handler(
 
                 res = response
                 if response_handler:
-                    res = await self.async_process(
-                        response,
-                        response_handler
-                    )
+                    res = await self.async_process(response, response_handler)
                 if isinstance(res, Request):
                     request = res
                 else:
@@ -92,7 +92,9 @@ async def async_client_handler(
 
             retry_index += 1
             if retry_index >= self.pref.client_max_retry_loops:
-                raise exceptions.MaxRetriesExceed(max_retries=self.pref.client_max_retry_loops)
+                raise exceptions.MaxRetriesExceed(
+                    max_retries=self.pref.client_max_retry_loops
+                )
         return res
 
     def client_handler(
@@ -107,8 +109,7 @@ def client_handler(
         while True:
             try:
                 request.adaptor.update_context(
-                    retry_index=retry_index,
-                    idempotent=self.idempotent
+                    retry_index=retry_index, idempotent=self.idempotent
                 )
                 req = request
                 if request_handler:
@@ -120,10 +121,7 @@ def client_handler(
 
                 res = response
                 if response_handler:
-                    res = self.process(
-                        response,
-                        response_handler
-                    )
+                    res = self.process(response, response_handler)
                 if isinstance(res, Request):
                     request = res
                 else:
@@ -148,7 +146,9 @@ def client_handler(
 
             retry_index += 1
             if retry_index >= self.pref.client_max_retry_loops:
-                raise exceptions.MaxRetriesExceed(max_retries=self.pref.client_max_retry_loops)
+                raise exceptions.MaxRetriesExceed(
+                    max_retries=self.pref.client_max_retry_loops
+                )
         return res
 
     def chain_client_handler(
@@ -157,37 +157,47 @@ def chain_client_handler(
         request_handler=None,
         response_handler=None,
         error_handler=None,
-        asynchronous: bool = None
+        asynchronous: bool = None,
     ):
         if not any([request_handler, response_handler, error_handler]):
             return handler
 
         if asynchronous:
+
             @wraps(handler)
             async def wrapper(request: Request):
                 return await self.async_client_handler(
-                    request, handler,
+                    request,
+                    handler,
                     request_handler=request_handler,
                     response_handler=response_handler,
-                    error_handler=error_handler
+                    error_handler=error_handler,
                 )
+
         else:
+
             @wraps(handler)
             def wrapper(request: Request):
                 return self.client_handler(
-                    request, handler,
+                    request,
+                    handler,
                    request_handler=request_handler,
                     response_handler=response_handler,
-                    error_handler=error_handler
+                    error_handler=error_handler,
                 )
+
         return wrapper
 
     def build_client_handler(self, handler, asynchronous: bool = None):
         # ---
         if asynchronous is None:
-            asynchronous = inspect.iscoroutinefunction(handler) or inspect.isasyncgenfunction(handler)
+            asynchronous = inspect.iscoroutinefunction(
+                handler
+            ) or inspect.isasyncgenfunction(handler)
         for request_handler, response_handler, error_handler in self.chain_plugins(
-            process_request, process_response, handle_error,
+            process_request,
+            process_response,
+            handle_error,
             required=False,
             asynchronous=asynchronous,
         ):
@@ -196,7 +206,7 @@ def build_client_handler(self, handler, asynchronous: bool = None):
                 request_handler=request_handler,
                 response_handler=response_handler,
                 error_handler=error_handler,
-                asynchronous=asynchronous
+                asynchronous=asynchronous,
             )
         if self.endpoint.client_wrap:
             # most outer
@@ -205,6 +215,6 @@ def build_client_handler(self, handler, asynchronous: bool = None):
                 request_handler=self.client.process_request,
                 response_handler=self.client.process_response,
                 error_handler=self.client.handle_error,
-                asynchronous=asynchronous
+                asynchronous=asynchronous,
             )
         return handler
diff --git a/utilmeta/core/cli/endpoint.py b/utilmeta/core/cli/endpoint.py
index e5f4498..5087db3 100644
--- a/utilmeta/core/cli/endpoint.py
+++ b/utilmeta/core/cli/endpoint.py
@@ -2,6 +2,7 @@ from utilmeta.utils import exceptions as exc
 import inspect
 from utilmeta.core.api.endpoint import BaseEndpoint
+
 # from utilmeta.core.response import Response
 # from utype.parser.rule import LogicalType
 
@@ -24,7 +25,7 @@ def prop_is(prop: properties.Property, ident):
 def prop_in(prop: properties.Property, ident):
     if not prop.__in__:
         return False
-    in_ident = getattr(prop.__in__, '__ident__', None)
+    in_ident = getattr(prop.__in__, "__ident__", None)
     if in_ident:
         return in_ident == ident
     return prop.__in__ == ident
@@ -33,13 +34,13 @@ def prop_in(prop: properties.Property, ident):
 class ClientRoute(BaseRoute):
     def __init__(
         self,
-        handler: Union[Type['Client'], 'ClientEndpoint'],
+        handler: Union[Type["Client"], "ClientEndpoint"],
         route: str,
         name: str,
         parent=None,
         before_hooks: List[ClientBeforeHook] = (),
         after_hooks: List[ClientAfterHook] = (),
-        error_hooks: Dict[Type[Exception], ClientErrorHook] = None
+        error_hooks: Dict[Type[Exception], ClientErrorHook] = None,
     ):
         super().__init__(
             handler,
@@ -48,7 +49,7 @@ def __init__(
             parent=parent,
             before_hooks=before_hooks,
             after_hooks=after_hooks,
-            error_hooks=error_hooks
+            error_hooks=error_hooks,
         )
 
 
@@ -59,8 +60,8 @@ class ClientEndpoint(BaseEndpoint):
     error_cls = Error
 
     @classmethod
-    def apply_for(cls, func: Callable, client: Type['Client'] = None):
-        _cls = getattr(func, 'cls', None)
+    def apply_for(cls, func: Callable, client: Type["Client"] = None):
+        _cls = getattr(func, "cls", None)
         _async = inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
         if not _cls or not issubclass(_cls, ClientEndpoint):
             # override current class
@@ -82,20 +83,19 @@ def apply_for(cls, func: Callable, client: Type['Client'] = None):
             kwargs.update(client=client)
         return _cls(func, **kwargs)
 
-    def __init__(self, f: Callable, *,
-                 client: Type['Client'] = None,
-                 method: str,
-                 plugins: list = None,
-                 idempotent: bool = None,
-                 eager: bool = False
-                 ):
+    def __init__(
+        self,
+        f: Callable,
+        *,
+        client: Type["Client"] = None,
+        method: str,
+        plugins: list = None,
+        idempotent: bool = None,
+        eager: bool = False,
+    ):
         super().__init__(
-            f,
-            plugins=plugins,
-            method=method,
-            idempotent=idempotent,
-            eager=eager
+            f, plugins=plugins, method=method, idempotent=idempotent, eager=eager
         )
         # self.is_async = self.parser.is_asynchronous
         self.client = client
@@ -106,11 +106,13 @@ def __init__(self, f: Callable, *,
         )
         self.client_wrap = False
         if self.client:
-            self.client_wrap = not all([
-                function_pass(self.client.process_request),
-                function_pass(self.client.process_response),
-                function_pass(self.client.handle_error),
-            ])
+            self.client_wrap = not all(
+                [
+                    function_pass(self.client.process_request),
+                    function_pass(self.client.process_response),
+                    function_pass(self.client.handle_error),
+                ]
+            )
 
         self.path_args = self.PATH_REGEX.findall(self.route)
 
         # if self.parser.is_asynchronous:
@@ -121,12 +123,12 @@ def __init__(self, f: Callable, *,
     @property
     def ref(self) -> str:
         if self.client:
-            return f'{self.client.__ref__}.{self.f.__name__}'
+            return f"{self.client.__ref__}.{self.f.__name__}"
         if self.module_name:
-            return f'{self.module_name}.{self.f.__name__}'
+            return f"{self.module_name}.{self.f.__name__}"
         return self.f.__name__
 
-    def __call__(self, client: 'Client', /, *args, **kwargs):
+    def __call__(self, client: "Client", /, *args, **kwargs):
         if not self.is_passed:
             return self.executor(client, *args, **kwargs)
         if self.parser.is_asynchronous:
@@ -134,20 +136,26 @@ def __call__(self, client: 'Client', /, *args, **kwargs):
         else:
             return client.__request__(self, *args, **kwargs)
 
-    def build_request(self, client: 'Client', /, *args, **kwargs) -> Request:
+    def build_request(self, client: "Client", /, *args, **kwargs) -> Request:
         # get Call object from kwargs
-        args, kwargs = self.parser.parse_params(args, kwargs, context=self.parser.options.make_context())
+        args, kwargs = self.parser.parse_params(
+            args, kwargs, context=self.parser.options.make_context()
+        )
         for i, arg in enumerate(args):
             kwargs[self.parser.pos_key_map[i]] = arg
 
         client_params = client.get_client_params()
         try:
             url = utils.url_join(
-                client_params.base_url or '', self.route, append_slash=client_params.append_slash
+                client_params.base_url or "",
+                self.route,
+                append_slash=client_params.append_slash,
             )
         except Exception as e:
-            raise e.__class__(f'utilmeta.core.cli.Client: build request url with base_url:'
-                              f' {repr(client_params.base_url)} and route: {repr(self.route)} failed: {e}') from e
+            raise e.__class__(
+                f"utilmeta.core.cli.Client: build request url with base_url:"
+                f" {repr(client_params.base_url)} and route: {repr(self.route)} failed: {e}"
+            ) from e
 
         query = dict(client_params.base_query or {})
         headers = dict(client_params.base_headers or {})
@@ -172,23 +180,23 @@ def build_request(self, client: 'Client', /, *args, **kwargs) -> Request:
             if not prop:
                 continue
 
-            if prop_in(prop, 'path'):
+            if prop_in(prop, "path"):
                 # PathParam
                 path_params[key] = value
-            elif prop_in(prop, 'query'):
+            elif prop_in(prop, "query"):
                 # QueryParam
                 query[key] = value
-            elif prop_is(prop, 'query'):
+            elif prop_is(prop, "query"):
                 # Query
                 if isinstance(value, Mapping):
                     query.update(value)
-            elif prop_in(prop, 'body'):
+            elif prop_in(prop, "body"):
                 # BodyParam
                 if isinstance(body, dict):
                     body[key] = value
                 else:
                     body = {key: value}
-            elif prop_is(prop, 'body'):
+            elif prop_is(prop, "body"):
                 # Body
                 if isinstance(body, dict) and isinstance(value, Mapping):
                     body.update(value)
@@ -196,30 +204,38 @@ def build_request(self, client: 'Client', /, *args, **kwargs) -> Request:
                     body = value
                 if isinstance(prop, properties.Body):
                     if prop.content_type:
-                        headers.update({'content-type': prop.content_type})
-            elif prop_in(prop, 'header'):
+                        headers.update({"content-type": prop.content_type})
+            elif prop_in(prop, "header"):
                 # HeaderParam
                 headers[key] = value
-            elif prop_is(prop, 'header'):
+            elif prop_is(prop, "header"):
                 # Headers
                 if isinstance(value, Mapping):
                     headers.update(value)
-            elif prop_in(prop, 'cookie'):
+            elif prop_in(prop, "cookie"):
                 # CookieParam
                 cookies[key] = value
-            elif prop_is(prop, 'cookie'):
+            elif prop_is(prop, "cookie"):
                 # Cookies
                 if isinstance(value, Mapping):
                     cookies.update(value)
 
         for key, val in path_params.items():
-            unit = '{%s}' % key
+            unit = "{%s}" % key
             url = url.replace(unit, str(val))
 
         if isinstance(cookies, SimpleCookie) and cookies:
-            headers.update({
-                'cookie': ';'.join([f'{key}={val.value}' for key, val in cookies.items() if val.value])
-            })
+            headers.update(
+                {
+                    "cookie": ";".join(
+                        [
+                            f"{key}={val.value}"
+                            for key, val in cookies.items()
+                            if val.value
+                        ]
+                    )
+                }
+            )
 
         return client.request_cls(
             method=self.method,
@@ -227,10 +243,12 @@ def build_request(self, client: 'Client', /, *args, **kwargs) -> Request:
             query=query,
             data=body,
             headers=headers,
-            backend=client_params.backend
+            backend=client_params.backend,
         )
 
-    def parse_response(self, response: Response, fail_silently: bool = False) -> Response:
+    def parse_response(
+        self, response: Response, fail_silently: bool = False
+    ) -> Response:
         if not isinstance(response, Response):
             response = Response(response)
 
@@ -254,7 +272,7 @@ def parse_response(self, response: Response, fail_silently: bool = False) -> Res
                 continue
             try:
                 return response_cls(response=response, strict=True)
-            except Exception as e:   # noqa
+            except Exception as e:  # noqa
                 if i == len(self.response_types) - 1 and not fail_silently:
                     raise e
                 continue
@@ -265,7 +283,7 @@ def parse_response(self, response: Response, fail_silently: bool = False) -> Res
 class SyncClientEndpoint(ClientEndpoint):
     ASYNCHRONOUS = False
 
-    def __call__(self, client: 'Client', *args, **kwargs):
+    def __call__(self, client: "Client", *args, **kwargs):
         with self.client_route.merge_hooks(client.client_route) as route:
             r = None
             request = None
@@ -278,7 +296,7 @@ def __call__(self, client: 'Client', *args, **kwargs):
                 if not self.is_passed:
                     r = self.executor(client, *args, **kwargs)
                     if inspect.isawaitable(r):
-                        raise exc.ServerError('awaitable detected in sync function')
+                        raise exc.ServerError("awaitable detected in sync function")
 
                 if r is None:
                     r = client.__request__(self, request)
@@ -288,7 +306,9 @@ def __call__(self, client: 'Client', *args, **kwargs):
 
             except Exception as e:
                 error = self.error_cls(e, request=request)
-                hook = error.get_hook(route.error_hooks, exact=isinstance(error.exception, exc.Redirect))
+                hook = error.get_hook(
+                    route.error_hooks, exact=isinstance(error.exception, exc.Redirect)
+                )
                 # hook applied before handel_error plugin event
                 if hook:
                     r = hook(self, error)
@@ -301,7 +321,7 @@ def __call__(self, client: 'Client', *args, **kwargs):
 class AsyncClientEndpoint(ClientEndpoint):
     ASYNCHRONOUS = True
 
-    async def __call__(self, client: 'Client', *args, **kwargs):
+    async def __call__(self, client: "Client", *args, **kwargs):
         # async with self:
         with self.client_route.merge_hooks(client.client_route) as route:
             r = None
@@ -329,7 +349,9 @@ async def __call__(self, client: 'Client', *args, **kwargs):
 
             except Exception as e:
                 error = self.error_cls(e, request=request)
-                hook = error.get_hook(route.error_hooks, exact=isinstance(error.exception, exc.Redirect))
+                hook = error.get_hook(
+                    route.error_hooks, exact=isinstance(error.exception, exc.Redirect)
+                )
                 # hook applied before handel_error plugin event
                 if hook:
                     r = hook(self, error)
diff --git a/utilmeta/core/cli/hook.py b/utilmeta/core/cli/hook.py
index bfbdb6b..a209b96 100644
--- a/utilmeta/core/cli/hook.py
+++ b/utilmeta/core/cli/hook.py
@@ -4,16 +4,16 @@
 
 
 class ClientBeforeHook(BeforeHook):
-    target_type = 'client'
+    target_type = "client"
 
-    def serve(self, client, /, request: 'Request' = None):
+    def serve(self, client, /, request: "Request" = None):
         if not request:
             return
         args, kwargs = self.parse_request(request)
         return self(client, request, *args, **kwargs)
 
     @utils.awaitable(serve)
-    async def serve(self, client, /, request: 'Request' = None):
+    async def serve(self, client, /, request: "Request" = None):
         if not request:
             return
         args, kwargs = await self.parse_request(request)
@@ -21,8 +21,8 @@ async def serve(self, client, /, request: 'Request' = None):
 
 
 class ClientAfterHook(AfterHook):
-    target_type = 'client'
+    target_type = "client"
 
 
 class ClientErrorHook(ErrorHook):
-    target_type = 'client'
+    target_type = "client"
diff --git a/utilmeta/core/cli/specs/base.py b/utilmeta/core/cli/specs/base.py
index 6adda30..c971e10 100644
--- a/utilmeta/core/cli/specs/base.py
+++ b/utilmeta/core/cli/specs/base.py
@@ -25,7 +25,7 @@ def __call__(self, file=None, console: bool = False):
             file_path = file
             if not os.path.isabs(file):
                 file_path = os.path.join(os.getcwd(), file)
-            with open(file, 'w', encoding='utf-8') as f:
+            with open(file, "w", encoding="utf-8") as f:
                 f.write(content)
             return file_path
         elif console:
diff --git a/utilmeta/core/cli/specs/openapi.py b/utilmeta/core/cli/specs/openapi.py
index 6f23ef1..10afcac 100644
--- a/utilmeta/core/cli/specs/openapi.py
+++ b/utilmeta/core/cli/specs/openapi.py
@@ -11,7 +11,13 @@
 from utype.specs.json_schema import JsonSchemaParser, JsonSchemaGroupParser
 from utype.specs.python import PythonCodeGenerator
 from utype.parser.rule import LogicalType, Rule
-from utilmeta.utils import valid_url, HTTP_METHODS_LOWER, valid_attr, time_now, json_dumps
+from utilmeta.utils import (
+    valid_url,
+    HTTP_METHODS_LOWER,
+    valid_attr,
+    time_now,
+    json_dumps,
+)
 import json
 from typing import Tuple, List, Union, Optional, Type
 from utilmeta.core import request
@@ -19,20 +25,20 @@
 
 
 def tab_for(content: str, tabs: int = 1) -> str:
-    return '\n'.join([f'%s{line}' % ('\t' * tabs) for line in content.splitlines()])
+    return "\n".join([f"%s{line}" % ("\t" * tabs) for line in content.splitlines()])
 
 
 class OpenAPIClientGenerator(BaseClientGenerator):
-    __spec__ = 'openapi'
-    __version__ = '3.1.0'
-    FORMATS = ['json', 'yaml']
-    PARAMS_IN = ['path', 'query', 'header', 'cookie']
+    __spec__ = "openapi"
+    __version__ = "3.1.0"
+    FORMATS = ["json", "yaml"]
+    PARAMS_IN = ["path", "query", "header", "cookie"]
     PARAMS_MAP = {
-        'path': request.PathParam,
-        'query': request.QueryParam,
-        'header': request.HeaderParam,
-        'cookie': request.CookieParam,
-        'body': request.BodyParam,
+        "path": request.PathParam,
+        "query": request.QueryParam,
+        "header": request.HeaderParam,
+        "cookie": request.CookieParam,
+        "body": request.BodyParam,
     }
 
     # None -> dict
@@ -45,50 +51,59 @@ class OpenAPIClientGenerator(BaseClientGenerator):
     schema_group_parser_cls = JsonSchemaGroupParser
     python_generator_cls = PythonCodeGenerator
 
-    NON_NAME_REG = '[^A-Za-z0-9]+'
-    JSON = 'application/json'
+    NON_NAME_REG = "[^A-Za-z0-9]+"
+    JSON = "application/json"
 
-    ref_prefix = '#/components'
-    schema_ref_prefix = '#/components/schemas'
-    response_ref_prefix = '#/components/responses'
-    schema_def_prefix = 'schemas'
-    response_def_prefix = 'responses'
+    ref_prefix = "#/components"
+    schema_ref_prefix = "#/components/schemas"
+    response_ref_prefix = "#/components/responses"
+    schema_def_prefix = "schemas"
+    response_def_prefix = "responses"
 
-    client_class_name = 'APIClient'
+    client_class_name = "APIClient"
 
     IMPORTS = """from utilmeta.core import api, cli, response, request
 import utype
 from utype.types import *
 """
 
-    def __init__(self, document: dict,
-                 space_ident: bool = False,
-                 black_format: bool = False,
-                 split_body_params: bool = False,
-                 ):
+    def __init__(
+        self,
+        document: dict,
+        space_ident: bool = False,
+        black_format: bool = False,
+        split_body_params: bool = False,
+    ):
         if not isinstance(document, dict) or not document.get(self.__spec__):
-            raise ValueError(f'Invalid openapi document: {document}')
+            raise ValueError(f"Invalid openapi document: {document}")
         super().__init__(document)
         try:
             self.openapi = OpenAPISchema(document)
         except utype.exc.ParseError as e:
-            raise e.__class__(f'Invalid openapi document: {e}') from e
-        self.ref_prefix = self.ref_prefix.rstrip('/') + '/'
-        self.schema_ref_prefix = self.schema_ref_prefix.rstrip('/') + '/'
-        self.response_ref_prefix = self.response_ref_prefix.rstrip('/') + '/'
-        self.schema_def_prefix = (self.schema_def_prefix.rstrip('.') + '.') if self.schema_def_prefix else ''
-        self.response_def_prefix = (self.response_def_prefix.rstrip('.') + '.') if self.response_def_prefix else ''
+            raise e.__class__(f"Invalid openapi document: {e}") from e
+        self.ref_prefix = self.ref_prefix.rstrip("/") + "/"
+        self.schema_ref_prefix = self.schema_ref_prefix.rstrip("/") + "/"
+        self.response_ref_prefix = self.response_ref_prefix.rstrip("/") + "/"
+        self.schema_def_prefix = (
+            (self.schema_def_prefix.rstrip(".") + ".") if self.schema_def_prefix else ""
+        )
+        self.response_def_prefix = (
+            (self.response_def_prefix.rstrip(".") + ".")
+            if self.response_def_prefix
+            else ""
+        )
         self.schema_refs = dict(self.openapi.components.schemas)
         self.responses_refs = dict(self.openapi.components.responses)
         self.space_ident = space_ident
         self.black_format = black_format
         self.split_body_params = split_body_params
 
-    def get_schema_parser(self,
-                          json_schema: dict,
-                          name: str = None,
-                          description: str = None,
-                          force_forward_ref: bool = True,
-                          ):
+    def get_schema_parser(
+        self,
+        json_schema: dict,
+        name: str = None,
+        description: str = None,
+        force_forward_ref: bool = True,
+    ):
         return self.schema_parser_cls(
             json_schema,
             name=name,
@@ -96,7 +111,7 @@ def get_schema_parser(self,
             refs=self.schema_refs,
             ref_prefix=self.schema_ref_prefix,
             def_prefix=self.schema_def_prefix,
-            force_forward_ref=force_forward_ref
+            force_forward_ref=force_forward_ref,
         )
 
     def get_code_parser(self, t):
@@ -104,13 +119,13 @@ def get_code_parser(self, t):
 
     def get_def_name(self, ref: str) -> str:
         if ref.startswith(self.schema_ref_prefix):
-            ref = ref[len(self.schema_ref_prefix):]
+            ref = ref[len(self.schema_ref_prefix) :]
         ref_name = self.get_param_name(ref)
         return self.schema_def_prefix + ref_name
 
     def get_response_def_name(self, ref: str) -> str:
         if ref.startswith(self.response_ref_prefix):
-            ref = ref[len(self.response_ref_prefix):]
+            ref = ref[len(self.response_ref_prefix) :]
         ref_name = self.get_param_name(ref)
         return self.response_def_prefix + ref_name
 
@@ -118,16 +133,16 @@ def register_response_ref(self, name: str, schema: dict) -> str:
         i = 1
         cls_name = name
         while name in self.responses_refs:
-            name = f'{cls_name}_{i}'
+            name = f"{cls_name}_{i}"
             i += 1
         self.responses_refs[name] = schema
         return self.get_response_def_name(name)
 
     def get_ref_object(self, ref: str) -> Optional[dict]:
         if ref.startswith(self.ref_prefix):
-            ref = ref[len(self.ref_prefix):]
+            ref = ref[len(self.ref_prefix) :]
 
-        ref_routes = ref.strip('/').split('/')
+        ref_routes = ref.strip("/").split("/")
         obj = self.openapi.components
         for route in ref_routes:
             if not obj:
@@ -142,18 +157,22 @@ def generate_from(cls, url_or_file: str):
                 pass
         else:
             file_path = url_or_file
-        content = open(file_path, 'r').read()
-        if file_path.endswith(',yml') or file_path.endswith('.yaml'):
+        content = open(file_path, "r").read()
+        if file_path.endswith(".yml") or file_path.endswith(".yaml"):
             from utilmeta.utils import requires
-            requires(yaml='pyyaml', install_when_require=True)
+
+            requires(yaml="pyyaml", install_when_require=True)
             import yaml
+
             document = yaml.safe_load(content)
         else:
             # try to load with json
             try:
                 document = json.loads(content)
             except json.decoder.JSONDecodeError as e:
-                raise ValueError(f'Invalid openapi document at {repr(file_path)}: {e}') from e
+                raise ValueError(
+                    f"Invalid openapi document at {repr(file_path)}: {e}"
+                ) from e
         return cls(document)
 
     def generate(self):
@@ -162,6 +181,7 @@ def generate(self):
         responses_content = self.generate_responses()
         schemas_content = self.generate_schemas()
         from utilmeta import __version__
+
         content = f"""# Generated by UtilMeta {__version__} on {str(time_now().strftime("%Y-%m-%d %H:%M"))}
 # generator spec: {self.__spec__} {self.__version__}
 # generator class: utilmeta.core.cli.specs.openapi.OpenAPIClientGenerator
@@ -177,7 +197,7 @@ def generate(self):
 )
 """
         if self.space_ident:
-            content.replace('\t', ' ' * 4)
+            content = content.replace("\t", " " * 4)
         if self.black_format:
             try:
                 import black
@@ -189,7 +209,7 @@ def generate(self):
 
     @classmethod
     def represent_data(cls, data):
-        return json_dumps(data, sort_keys=True, indent='\t')
+        return json_dumps(data, sort_keys=True, indent="\t")
 
     def generate_paths(self):
         operations = []
@@ -197,98 +217,113 @@ def generate_paths(self):
             methods = self.get_schema(methods)
             if not methods:
                 continue
-            summary = methods.get('summary')
-            description = methods.get('description')
-            path_parameters = methods.get('parameters')
+            summary = methods.get("summary")
+            description = methods.get("description")
+            path_parameters = methods.get("parameters")
             for method, operation in methods.items():
                 if str(method).lower() in HTTP_METHODS_LOWER:
                     operations.append(
-                        tab_for(self.generate_path_item(
-                            method=method,
-                            path=path,
-                            operation=operation,
-                            summary=summary,
-                            description=description,
-                            parameters=path_parameters
-                        ), tabs=1) + '\n'
+                        tab_for(
+                            self.generate_path_item(
+                                method=method,
+                                path=path,
+                                operation=operation,
+                                summary=summary,
+                                description=description,
+                                parameters=path_parameters,
+                            ),
+                            tabs=1,
+                        )
+                        + "\n"
                     )
-        client_lines = ['class APIClient(cli.Client):']
+        client_lines = ["class APIClient(cli.Client):"]
         if self.openapi.info:
-            client_lines.append(tab_for(f'__info__ = {self.represent_data(self.openapi.info)}', tabs=1))
+            client_lines.append(
+                tab_for(f"__info__ = {self.represent_data(self.openapi.info)}", tabs=1)
+            )
         if self.openapi.servers:
-            client_lines.append(tab_for(f'__servers__ = {self.represent_data(self.openapi.servers)}', tabs=1))
+            client_lines.append(
+                tab_for(
+                    f"__servers__ = {self.represent_data(self.openapi.servers)}", tabs=1
+                )
+            )
         client_lines.extend(operations)
-        return '\n'.join(client_lines)
+        return "\n".join(client_lines)
 
     def generate_schemas(self):
         schemas = self.schema_refs
         if not schemas:
-            return ''
+            return ""
         schemas_parser = self.schema_group_parser_cls(
             schemas=schemas,
             ref_prefix=self.schema_ref_prefix,
-            def_prefix=self.schema_def_prefix
+            def_prefix=self.schema_def_prefix,
         )
         schemas_refs = schemas_parser.parse()
-        group_lines = ['class schemas:']
+        group_lines = ["class schemas:"]
         for ref, schema_cls in schemas_refs.items():
             schema_content = self.get_code_parser(schema_cls)()
-            group_lines.append(tab_for(schema_content, tabs=1) + '\n')
-        return '\n'.join(group_lines)
+            group_lines.append(tab_for(schema_content, tabs=1) + "\n")
+        return "\n".join(group_lines)
 
     def generate_responses(self):
         responses = self.responses_refs
         if not responses:
-            return ''
-        group_lines = ['class responses:\n']
+            return ""
+        group_lines = ["class responses:\n"]
         for name, response in responses.items():
             response_content = self.generate_response(response, name=name)
             group_lines.append(tab_for(response_content, tabs=1))
             group_lines.append("\n\n")
-        return ''.join(group_lines)
+        return "".join(group_lines)
 
     def get_headers_schema(self, headers: dict, name: str):
-        return self.get_schema_parser({
-            'type': 'object',
-            'properties': {key: val.get('schema') for key, val in headers.items()},
-            'required': [key for key in headers.keys() if headers[key].get('required')]
-        },
+        return self.get_schema_parser(
+            {
+                "type": "object",
+                "properties": {key: val.get("schema") for key, val in headers.items()},
+                "required": [
+                    key for key in headers.keys() if headers[key].get("required")
+                ],
+            },
             name=name,
-            force_forward_ref=False
+            force_forward_ref=False,
         )()
 
     def get_schema(self, schema: dict):
         if not schema:
             return {}
-        ref = schema.get('$ref')
+        ref = schema.get("$ref")
         if ref:
             return self.get_ref_object(ref)
         return schema
 
     def generate_response(self, response: dict, name: str):
-        content = response.get('content') or {}
-        headers = response.get('headers') or {}
-        description = response.get('description') or ''
-        response_name = response.get('x-response-name') or ''
+        content = response.get("content") or {}
+        headers = response.get("headers") or {}
+        description = response.get("description") or ""
+        response_name = response.get("x-response-name") or ""
         headers_content = None
         headers_annotation = None
-        resp_name = re.sub(self.NON_NAME_REG, '_', name).strip('_')
+        resp_name = re.sub(self.NON_NAME_REG, "_", name).strip("_")
         if headers:
-            headers_schema = self.get_headers_schema(headers, name=resp_name + 'Headers')
+            headers_schema = self.get_headers_schema(
+                headers, name=resp_name + "Headers"
+            )
             headers_content = self.get_code_parser(headers_schema)()
             headers_annotation = headers_schema.__name__
 
         # single_content = len(content) == 1
         resp_lines = [
-            f'class {resp_name}(response.Response):',
+            f"class {resp_name}(response.Response):",
         ]
         if description:
             resp_lines.append(f'\t"""{description}"""')
         if response_name:
-            resp_lines.append(f'\tname = {repr(str(response_name))}')
+            resp_lines.append(f"\tname = {repr(str(response_name))}")
 
         result_annotations = []
         result_contents = []
@@ -297,32 +332,32 @@ def generate_response(self, response: dict, name: str):
         for content_type, content_obj in content.items():
             if not content_obj or not isinstance(content_obj, dict):
                 continue
-            content_schema = content_obj.get('schema') or {}
-            content_description = content_obj.get('description')
+            content_schema = content_obj.get("schema") or {}
+            content_description = content_obj.get("description")
             result_schema = content_schema
             if self.JSON in content_type:
                 # HANDLE RESPONSE KEYS
                 schema = dict(self.get_schema(content_schema))
-                if schema.get('type') == 'object':
-                    props = schema.get('properties') or {}
-                    result_key = schema.get('x-response-result-key') or ''
-                    message_key = schema.get('x-response-message-key') or ''
-                    state_key = schema.get('x-response-state-key') or ''
-                    count_key = schema.get('x-response-count-key') or ''
+                if schema.get("type") == "object":
+                    props = schema.get("properties") or {}
+                    result_key = schema.get("x-response-result-key") or ""
+                    message_key = schema.get("x-response-message-key") or ""
+                    state_key = schema.get("x-response-state-key") or ""
+                    count_key = schema.get("x-response-count-key") or ""
                     if result_key:
                         result_schema = props.get(result_key)
-                        resp_lines.append(f'\tresult_key = {repr(str(result_key))}')
+                        resp_lines.append(f"\tresult_key = {repr(str(result_key))}")
                     if message_key:
-                        resp_lines.append(f'\tmessage_key = {repr(str(message_key))}')
+                        resp_lines.append(f"\tmessage_key = {repr(str(message_key))}")
                     if state_key:
-                        resp_lines.append(f'\tstate_key = {repr(str(state_key))}')
+                        resp_lines.append(f"\tstate_key = {repr(str(state_key))}")
                     if count_key:
-                        resp_lines.append(f'\tcount_key = {repr(str(count_key))}')
+                        resp_lines.append(f"\tcount_key = {repr(str(count_key))}")
 
             result_annotation, schema_contents = self.get_schema_annotations(
                 result_schema,
-                name=resp_name + 'Result',
-                description=content_description
+                name=resp_name + "Result",
+                description=content_description,
             )
             if result_annotation:
                 result_annotations.append(result_annotation)
@@ -331,43 +366,46 @@ def generate_response(self, response: dict, name: str):
                 content_types.append(content_type)
 
         if headers_content:
-            resp_lines.append(tab_for(headers_content, tabs=1) + '\n')
+            resp_lines.append(tab_for(headers_content, tabs=1) + "\n")
         if result_contents:
             for result_content in result_contents:
-                resp_lines.append(tab_for(result_content, tabs=1) + '\n')
+                resp_lines.append(tab_for(result_content, tabs=1) + "\n")
         if len(content_types) == 1:
-            resp_lines.append(f'\tcontent_type = {repr(content_types[0])}')
+            resp_lines.append(f"\tcontent_type = {repr(content_types[0])}")
         if headers_annotation:
-            resp_lines.append(f'\theaders: {headers_annotation}')
+            resp_lines.append(f"\theaders: {headers_annotation}")
         if result_annotations:
-            result_annotation = result_annotations[0] if len(result_annotations) == 1 \
-                else f'Union[%s]' % (', '.join(result_annotations))
-            resp_lines.append(f'\tresult: {result_annotation}')
+            result_annotation = (
+                result_annotations[0]
+                if len(result_annotations) == 1
+                else f"Union[%s]" % (", ".join(result_annotations))
+            )
+            resp_lines.append(f"\tresult: {result_annotation}")
         if len(resp_lines) == 1:
-            resp_lines.append('\tpass')
-        return '\n'.join(resp_lines)
+            resp_lines.append("\tpass")
+        return "\n".join(resp_lines)
 
-    def get_schema_annotations(self, json_schema: dict,
-                               name: str = None,
-                               description: str = None) -> Tuple[str, List[str]]:
+    def get_schema_annotations(
+        self, json_schema: dict, name: str = None, description: str = None
+    ) -> Tuple[str, List[str]]:
         if not json_schema:
-            return '', []
+            return "", []
         # Union[X1, X2]
         # Optional[X1]
         # List[...]
         # Tuple[...]
         # Dict[...]
         # ClassName
-        ref = json_schema.get('$ref')
+        ref = json_schema.get("$ref")
         if ref:
             return repr(self.get_def_name(ref)), []
         schema = self.get_schema_parser(
             json_schema=json_schema,
             name=self.get_param_name(name),
             description=description,
-            force_forward_ref=False
+            force_forward_ref=False,
         )()
         parser = self.get_code_parser(schema)
@@ -380,18 +418,22 @@ def get_schema_annotations(self, json_schema: dict,
         if not args:
             args = [schema]
 
-        annotation = parser.generate_for_type(schema, with_constraints=False, annotation=True)
+        annotation = parser.generate_for_type(
+            schema, with_constraints=False, annotation=True
+        )
         schema_contents = [
             parser.generate_for_type(arg, with_constraints=True, annotation=False)
-            for arg in args if self.required_generate(arg)
+            for arg in args
+            if self.required_generate(arg)
         ]
         return annotation, schema_contents
 
     @classmethod
     def required_generate(cls, t):
-        parser = getattr(t, '__parser__', None)
+        parser = getattr(t, "__parser__", None)
         if parser:
             from utype.parser.cls import ClassParser
+
             if isinstance(parser, ClassParser):
                 return True
         elif isinstance(t, LogicalType) and issubclass(t, Rule):
@@ -405,36 +447,38 @@ def required_generate(cls, t):
 
     @classmethod
     def get_param_name(cls, name: str, excludes: list = None):
-        name = re.sub(cls.NON_NAME_REG, '_', name).strip('_')
+        name = re.sub(cls.NON_NAME_REG, "_", name).strip("_")
         if keyword.iskeyword(name):
-            name += '_value'
+            name += "_value"
         if excludes:
             i = 1
             origin = name
             while name in excludes:
-                name = f'{origin}_{i}'
+                name = f"{origin}_{i}"
                 i += 1
         return name
 
-    def get_parameter(self, param: dict, excludes: list = None) -> Optional[inspect.Parameter]:
+    def get_parameter(
+        self, param: dict, excludes: list = None
+    ) -> Optional[inspect.Parameter]:
         param: dict = self.get_schema(param)
         if not param:
             return None
-        name = param.get('name')
-        _in = param.get('in')
-        description = param.get('description')
-        required = param.get('required')
-        schema = self.get_schema(param.get('schema'))
+        name = param.get("name")
+        _in = param.get("in")
+        description = param.get("description")
+        required = param.get("required")
+        schema = self.get_schema(param.get("schema"))
         param_cls = self.PARAMS_MAP.get(_in)
         parser = self.get_schema_parser(schema)
-        attname = schema.get('x-var-name') or name
+        attname = schema.get("x-var-name") or name
         alias = None
         if not valid_attr(attname) or excludes and attname in excludes:
             attname = self.get_param_name(attname, excludes)
             if attname != name:
                 alias = name
-        elif _in == 'header':
+        elif _in == "header":
             alias = name
 
         field_type, field = parser.parse_field(
@@ -453,46 +497,46 @@ def get_parameter(self, param: dict, excludes: list = None) -> Optional[inspect.
             name=attname,
             kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
             annotation=field_type,
-            default=field
+            default=field,
         )
 
     @classmethod
     def get_body_param_name(cls, content_type: str, excludes: List[str] = None):
-        name = 'body'
-        if 'json' in content_type:
-            name = 'data'
-        elif 'form' in content_type:
-            name = 'form'
-        elif 'html' in content_type:
-            name = 'html'
-        elif 'xml' in content_type:
-            name = 'xml'
-        elif 'text' in content_type:
-            name = 'text'
-        elif 'stream' in content_type:
-            name = 'file'
-        elif 'image' in content_type:
-            name = 'image'
-        elif 'audio' in content_type:
-            name = 'audio'
-        elif 'video' in content_type:
-            name = 'video'
+        name = "body"
+        if "json" in content_type:
+            name = "data"
+        elif "form" in content_type:
+            name = "form"
+        elif "html" in content_type:
+            name = "html"
+        elif "xml" in content_type:
+            name = "xml"
+        elif "text" in content_type:
+            name = "text"
+        elif "stream" in content_type:
+            name = "file"
+        elif "image" in content_type:
+            name = "image"
+        elif "audio" in content_type:
+            name = "audio"
+        elif "video" in content_type:
+            name = "video"
         while excludes and name in excludes:
-            if name == 'data':
-                name = 'json'
+            if name == "data":
+                name = "json"
             else:
-                name = name + '_data'
+                name = name + "_data"
         return name
 
-    def get_body_parameters(self, body: dict,
-                            endpoint_name: str = None,
-                            excludes: List[str] = None) -> List[inspect.Parameter]:
+    def get_body_parameters(
+        self, body: dict, endpoint_name: str = None, excludes: List[str] = None
+    ) -> List[inspect.Parameter]:
         body = self.get_schema(body)
         if not body:
             return []
-        body_required = body.get('required')
-        body_content = body.get('content') or {}
-        body_description = body.get('description')
+        body_required = body.get("required")
+        body_content = body.get("content") or {}
+        body_description = body.get("description")
         excludes = list(excludes or [])
 
         body_params = []
@@ -502,17 +546,23 @@ def get_body_parameters(self, body: dict,
                 continue
 
             param_name = self.get_body_param_name(content_type, excludes=excludes)
-            schema_name = ''.join([v.capitalize() for v in endpoint_name.split('_')] + [param_name.capitalize()]) \
-                if endpoint_name else param_name.capitalize()
-            content_example = content.get('example')
-            content_schema = content.get('schema') or {}
+            schema_name = (
+                "".join(
+                    [v.capitalize() for v in endpoint_name.split("_")]
+                    + [param_name.capitalize()]
+                )
+                if endpoint_name
+                else param_name.capitalize()
+            )
+            content_example = content.get("example")
+            content_schema = content.get("schema") or {}
             if self.split_body_params and self.JSON in content_type:
                 # if content_schema is $ref, we will directly use Body instead of split
-                if content_schema.get('type') == 'object':
-                    schema_props = content_schema.get('properties')
+                if content_schema.get("type") == "object":
+                    schema_props = content_schema.get("properties")
                     if schema_props:
-                        schema_required = content_schema.get('required') or []
+                        schema_required = content_schema.get("required") or []
                         for key, prop in schema_props.items():
                             schema_key_name = schema_name + str(key).capitalize()
                             prop_parser = self.get_schema_parser(
@@ -525,16 +575,18 @@ def get_body_parameters(self, body: dict,
                                 description=body_description,
                                 required=field_required,
                                 field_cls=request.BodyParam,
-                                name=schema_key_name
+                                name=schema_key_name,
                             )
                             attname = self.get_param_name(key, excludes=excludes)
                             excludes.append(attname)
-                            body_params.append(inspect.Parameter(
-                                name=attname,
-                                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
-                                annotation=field_type,
-                                default=field
-                            ))
+                            body_params.append(
+                                inspect.Parameter(
+                                    name=attname,
+                                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                                    annotation=field_type,
+                                    default=field,
+                                )
+                            )
                         continue
             excludes.append(param_name)
             parser = self.get_schema_parser(
@@ -548,17 +600,19 @@ def get_body_parameters(self, body: dict,
                 example=content_example,
                 field_cls=request.Body,
                 content_type=content_type,
-                name=schema_name
+                name=schema_name,
             )
             if not field.__spec_kwargs__:
                 field = request.Body
-            body_params.append(inspect.Parameter(
-                name=param_name,
-                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
-                annotation=field_type,
-                default=field
-            ))
+            body_params.append(
+                inspect.Parameter(
+                    name=param_name,
+                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    annotation=field_type,
+                    default=field,
+                )
+            )
 
         return body_params
 
@@ -618,18 +672,22 @@ def get_responses_annotation(self, responses: dict, endpoint_name: str = None):
             # response_args.append(response_cls)
             if not resp:
                 continue
-            ref = resp.get('$ref')
+            ref = resp.get("$ref")
             if ref:
                 resp_def_name = self.get_response_def_name(ref)
             else:
-                resp_name = ''.join([v.capitalize() for v in endpoint_name.split('_')] + ['Response']) \
-                    if endpoint_name else 'Response'
-                resp_def_name = self.register_response_ref(
-                    name=resp_name, schema=resp
+                resp_name = (
+                    "".join(
+                        [v.capitalize() for v in endpoint_name.split("_")]
+                        + ["Response"]
+                    )
+                    if endpoint_name
+                    else "Response"
                 )
-            suffix = f'[{name}]' if name != 'default' else ''
+                resp_def_name = self.register_response_ref(name=resp_name, schema=resp)
+            suffix = f"[{name}]" if name != "default" else ""
             response_args.append(ForwardRef(resp_def_name + suffix))
-        response_args.append(ForwardRef('response.Response'))
+        response_args.append(ForwardRef("response.Response"))
         # if not response_args:
         #     return None
         if len(response_args) == 1:
@@ -637,20 +695,22 @@ def get_responses_annotation(self, responses: dict, endpoint_name: str = None):
         return Union[tuple(response_args)]
 
     def get_operation_function(self, operation: dict, path_parameters: list = None):
-        func_name = operation.get('operationId')
-        parameters = operation.get('parameters') or []
-        body = operation.get('requestBody')
-        responses = operation.get('responses')
-
-        func_parameters = [inspect.Parameter(
-            name='self',
-            kind=inspect.Parameter.POSITIONAL_ONLY,
-        )]
+        func_name = operation.get("operationId")
+        parameters = operation.get("parameters") or []
+        body = operation.get("requestBody")
+        responses = operation.get("responses")
+
+        func_parameters = [
+            inspect.Parameter(
+                name="self",
+                kind=inspect.Parameter.POSITIONAL_ONLY,
+            )
+        ]
 
         if path_parameters:
             parameters.extend(path_parameters)
 
-        func_args = ['self']
+        func_args = ["self"]
         if parameters:
             for param in parameters:
                 func_param = self.get_parameter(param, excludes=func_args)
@@ -663,7 +723,7 @@ def get_operation_function(self, operation: dict, path_parameters: list = None):
             body_parameters = self.get_body_parameters(
                 body,
                 endpoint_name=func_name,
-                excludes=[param.name for param in func_parameters]
+                excludes=[param.name for param in func_parameters],
             )
             func_parameters.extend(body_parameters)
 
@@ -673,28 +733,31 @@ def f() -> return_annotation:
             pass
 
         f.__name__ = func_name
-        f.__qualname__ = f'{self.client_class_name}.{func_name}'
+        f.__qualname__ = f"{self.client_class_name}.{func_name}"
         f.__signature__ = inspect.signature(f).replace(
             parameters=func_parameters,
         )
         return f
 
-    def generate_path_item(self,
-                           method: str,
-                           path: str,
-                           operation: dict,
-                           summary: str = None,
-                           description: str = None,
-                           parameters: list = None,
-                           ):
+    def generate_path_item(
+        self,
+        method: str,
+        path: str,
+        operation: dict,
+        summary: str = None,
+        description: str = None,
+        parameters: list = None,
+    ):
         api_kwargs = []
-        for key in ['tags', 'summary', 'description', 'deprecated', 'security']:
+        for key in ["tags", "summary", "description", "deprecated", "security"]:
             val = operation.get(key) or locals().get(key)
             if val is not None:
-                api_kwargs.append(f'{key}={repr(operation[key])}')
-        api_kwargs_str = ', '.join(api_kwargs)
-        decorator = f'@api.{method}({repr(path)}%s)' % ((', ' + api_kwargs_str) if api_kwargs else '')
+                api_kwargs.append(f"{key}={repr(val)}")
+        api_kwargs_str = ", ".join(api_kwargs)
+        decorator = f"@api.{method}({repr(path)}%s)" % (
+            (", " + api_kwargs_str) if api_kwargs else ""
+        )
         func = self.get_operation_function(operation, path_parameters=parameters)
         func_content = self.get_code_parser(func)()
         # todo: get schemas
-        return decorator + '\n' + func_content + '\n'
+        return decorator + "\n" + func_content + "\n"
diff --git a/utilmeta/core/file/backends/base.py b/utilmeta/core/file/backends/base.py
index b62ad82..50cb657 100644
--- a/utilmeta/core/file/backends/base.py
+++ b/utilmeta/core/file/backends/base.py
@@ -12,12 +12,15 @@ def __init__(self, file):
     def get_module_name(cls, obj):
         from io import BytesIO, TextIOWrapper, BufferedRandom, BufferedReader
         from utilmeta.core.response.base import Response, ResponseAdaptor
+
         if isinstance(obj, BytesIO):
-            return 'bytesio'
+            return "bytesio"
         elif isinstance(obj, (BufferedReader, BufferedRandom, TextIOWrapper)):
-            return 'fileio'
-        elif isinstance(obj, (Response, ResponseAdaptor)) or Response.response_like(obj):
-            return 'response'
+            return "fileio"
+        elif isinstance(obj, (Response, ResponseAdaptor)) or Response.response_like(
+            obj
+        ):
+            return "response"
         return super().get_module_name(obj)
 
     def get_object(self):
@@ -61,8 +64,10 @@ def save(self, path: str, name: str = None):
                 file_path = os.path.join(file_path, name)
             else:
                 if os.path.isdir(file_path):
-                    raise PermissionError(f'Attempt to write file to directory: {file_path}')
-            with open(file_path, 'wb') as fp:
+                    raise PermissionError(
+                        f"Attempt to write file to directory: {file_path}"
+                    )
+            with open(file_path, "wb") as fp:
                 if self.seekable:
                     self.object.seek(0)
                 content = self.object.read()
@@ -81,7 +86,7 @@ async def asave(self, path: str, name: str = None):
             if os.path.isdir(file_path):
                 file_path = os.path.join(file_path, name)
 
-        with open(file_path, 'wb') as fp:
+        with open(file_path, "wb") as fp:
             if self.seekable:
                 r = self.object.seek(0)
                 if inspect.isawaitable(r):
@@ -96,8 +101,8 @@ async def asave(self, path: str, name: str = None):
         return file_path
 
     def close(self):
-        if hasattr(self.object, 'close'):
+        if hasattr(self.object, "close"):
             try:
                 self.object.close()
-            except Exception:   # noqa
+            except Exception:  # noqa
                 pass
diff --git a/utilmeta/core/file/backends/response.py b/utilmeta/core/file/backends/response.py
index dfbaf62..1bb727a 100644
--- a/utilmeta/core/file/backends/response.py
+++ b/utilmeta/core/file/backends/response.py
@@ -28,6 +28,7 @@ def get_object(self):
         if file:
             return file.file
         from io import BytesIO
+
         return BytesIO(self.file.body)
 
     @property
diff --git a/utilmeta/core/file/base.py b/utilmeta/core/file/base.py
index 6b65f1b..60c0f6c 100644
--- a/utilmeta/core/file/base.py
+++ b/utilmeta/core/file/base.py
@@ -6,7 +6,7 @@
 from utilmeta.utils import file_like
 from pathlib import Path
 
-__all__ = ['File', 'Image', 'Audio', 'Video', 'FileType']
"Audio", "Video", "FileType"] class InvalidFileType(UnprocessableEntity): @@ -15,8 +15,8 @@ class InvalidFileType(UnprocessableEntity): class File: file: BytesIO - format = 'binary' - accept = '*/*' + format = "binary" + accept = "*/*" # FOR JSON SCHEMA encoding = property(lambda self: self.file.encoding) @@ -58,7 +58,7 @@ def _make_file_like(self, value): return value if isinstance(value, (bytes, memoryview, bytearray)): return BytesIO(value) - charset = self.charset or 'utf-8' + charset = self.charset or "utf-8" if isinstance(value, str): return BytesIO(value.encode(charset)) # Handle non-string types. @@ -70,7 +70,7 @@ def validate(self): @property def closed(self): - return not self.file or getattr(self.file, 'closed', None) + return not self.file or getattr(self.file, "closed", None) def close(self): self.adaptor.close() @@ -124,71 +124,78 @@ def size(self) -> int: @property def suffix(self) -> str: - if '.' in self.filename: - return str(self.filename.split('.')[-1]).lower() + if "." in self.filename: + return str(self.filename.split(".")[-1]).lower() type = self.content_type if not type: - return '' - if '/' in type: - return str(type.split('/')[1]).lower() + return "" + if "/" in type: + return str(type.split("/")[1]).lower() return type.lower() @property def is_image(self): - return self.content_type.startswith('image') + return self.content_type.startswith("image") @property def is_audio(self): - return self.content_type.startswith('audio') + return self.content_type.startswith("audio") @property def is_video(self): - return self.content_type.startswith('video') + return self.content_type.startswith("video") class Image(File): - accept = 'image/*' + accept = "image/*" def validate(self): if not self.content_type or not self.is_image: - raise InvalidFileType(f'Invalid file type: {repr(self.content_type)}, image expected') + raise InvalidFileType( + f"Invalid file type: {repr(self.content_type)}, image expected" + ) def get_image(self): from PIL import Image, ImageOps + return ImageOps.exif_transpose(Image.open(self.file)) class Audio(File): - accept = 'audio/*' + accept = "audio/*" def validate(self): if not self.content_type or not self.is_audio: - raise InvalidFileType(f'Invalid file type: {repr(self.content_type)}, audio expected') + raise InvalidFileType( + f"Invalid file type: {repr(self.content_type)}, audio expected" + ) class Video(File): - accept = 'video/*' + accept = "video/*" def validate(self): if not self.content_type or not self.is_video: - raise InvalidFileType(f'Invalid file type: {repr(self.content_type)}, video expected') + raise InvalidFileType( + f"Invalid file type: {repr(self.content_type)}, video expected" + ) def FileType(content_type: str): - if '/' in content_type: - content_class, suffix = content_type.split('/') + if "/" in content_type: + content_class, suffix = content_type.split("/") else: content_class, suffix = content_type, None class FileCls(File): def validate_type(self): if self.content_type: - if '/' in self.content_type: - cc, suf = self.content_type.split('/') - if content_class != '*': + if "/" in self.content_type: + cc, suf = self.content_type.split("/") + if content_class != "*": if content_class != cc: return False - if suffix not in (None, '*'): + if suffix not in (None, "*"): if suf != suffix: return False return True @@ -196,7 +203,9 @@ def validate_type(self): def validate(self): if not self.validate_type(): - raise InvalidFileType(f'Invalid file type: {repr(self.content_type)}, video expected') + raise InvalidFileType( + f"Invalid 
file type: {repr(self.content_type)}, video expected" + ) return FileCls @@ -206,7 +215,7 @@ def transform_file(transformer, file, cls: Type[File]): if isinstance(file, (list, tuple)) and file: file = file[0] if file is None: - raise TypeError('Invalid file: None') + raise TypeError("Invalid file: None") if isinstance(file, cls): return cls(file.adaptor) return cls(file) diff --git a/utilmeta/core/orm/__init__.py b/utilmeta/core/orm/__init__.py index 52e3049..06f1a12 100644 --- a/utilmeta/core/orm/__init__.py +++ b/utilmeta/core/orm/__init__.py @@ -1,4 +1,5 @@ from .plugins.atomic import AtomicPlugin as Atomic + # from .plugins.relate import Relate from .fields import * from .schema import Schema, Query @@ -12,7 +13,8 @@ from utype import Options -W = Options(mode='w', override=True) -WP = Options(mode='w', ignore_required=True, override=True) -A = Options(mode='a', override=True) -R = Options(mode='r', override=True) + +W = Options(mode="w", override=True) +WP = Options(mode="w", ignore_required=True, override=True) +A = Options(mode="a", override=True) +R = Options(mode="r", override=True) diff --git a/utilmeta/core/orm/backends/base.py b/utilmeta/core/orm/backends/base.py index 04b72cc..c524446 100644 --- a/utilmeta/core/orm/backends/base.py +++ b/utilmeta/core/orm/backends/base.py @@ -8,14 +8,20 @@ class ModelFieldAdaptor(BaseAdaptor): @classmethod - def reconstruct(cls, adaptor: 'BaseAdaptor'): + def reconstruct(cls, adaptor: "BaseAdaptor"): pass - __backends_route__ = 'backends' + __backends_route__ = "backends" model_adaptor_cls = None # hold a model field or expression - def __init__(self, field, addon: str = None, model: 'ModelAdaptor' = None, lookup_name: str = None): + def __init__( + self, + field, + addon: str = None, + model: "ModelAdaptor" = None, + lookup_name: str = None, + ): self.field = field self.addon = addon self.lookup_name = lookup_name @@ -30,11 +36,11 @@ def description(self) -> Optional[str]: return None @property - def related_model(self) -> Optional['ModelAdaptor']: + def related_model(self) -> Optional["ModelAdaptor"]: raise NotImplementedError @property - def remote_field(self) -> Optional['ModelFieldAdaptor']: + def remote_field(self) -> Optional["ModelFieldAdaptor"]: raise NotImplementedError # @property @@ -46,15 +52,17 @@ def reverse_lookup(self) -> Tuple[str, str]: return self.model.get_reverse_lookup(self.lookup_name) @property - def target_field(self) -> Optional['ModelFieldAdaptor']: + def target_field(self) -> Optional["ModelFieldAdaptor"]: raise NotImplementedError @property - def through_model(self) -> Optional['ModelAdaptor']: + def through_model(self) -> Optional["ModelAdaptor"]: raise NotImplementedError @property - def through_fields(self) -> Tuple[Optional['ModelFieldAdaptor'], Optional['ModelFieldAdaptor']]: + def through_fields( + self, + ) -> Tuple[Optional["ModelFieldAdaptor"], Optional["ModelFieldAdaptor"]]: raise NotImplementedError @property @@ -187,24 +195,24 @@ class ModelAdaptor(BaseAdaptor): model_cls = None queryset_cls = None - __backends_names__ = ['django', 'peewee', 'sqlalchemy'] + __backends_names__ = ["django", "peewee", "sqlalchemy"] @classmethod - def reconstruct(cls, adaptor: 'BaseAdaptor'): + def reconstruct(cls, adaptor: "BaseAdaptor"): pass def __init__(self, model): if not self.qualify(model): - raise TypeError(f'{self.__class__}: Invalid model: {model}') + raise TypeError(f"{self.__class__}: Invalid model: {model}") self.model = model @property def ident(self): - return 
f'{self.model.__module__}.{self.model.__name__}' + return f"{self.model.__module__}.{self.model.__name__}" @property def field_errors(self) -> Tuple[Type[Exception], ...]: - return (Exception, ) + return (Exception,) @property def pk_field(self) -> field_adaptor_cls: @@ -286,7 +294,7 @@ def check_subquery(self, qs): def check_queryset(self, qs, check_model: bool = False): raise NotImplementedError - def get_model(self, qs) -> 'ModelAdaptor': + def get_model(self, qs) -> "ModelAdaptor": raise NotImplementedError @property @@ -302,7 +310,7 @@ def table_name(self) -> str: @property def default_db_alias(self) -> str: - return 'default' + return "default" def get_parents(self) -> list: raise NotImplementedError @@ -310,9 +318,13 @@ def get_parents(self) -> list: def cross_models(self, field): raise NotImplementedError - def get_field(self, name: str, validator: Callable = None, - silently: bool = False, - allow_addon: bool = False) -> Optional[field_adaptor_cls]: + def get_field( + self, + name: str, + validator: Callable = None, + silently: bool = False, + allow_addon: bool = False, + ) -> Optional[field_adaptor_cls]: """ Get name from a field references """ @@ -333,7 +345,9 @@ def get_fields(self, many=False, no_inherit=False) -> List[ModelFieldAdaptor]: def get_related_adaptor(self, field): return self.__class__(field.related_model) if field.related_model else None - def gen_lookup_keys(self, field: str, keys, strict: bool = True, excludes: List[str] = None) -> list: + def gen_lookup_keys( + self, field: str, keys, strict: bool = True, excludes: List[str] = None + ) -> list: raise NotImplementedError def gen_lookup_filter(self, field, q, excludes: List[str] = None): diff --git a/utilmeta/core/orm/backends/django/compiler.py b/utilmeta/core/orm/backends/django/compiler.py index 7b6aa0d..be62e47 100644 --- a/utilmeta/core/orm/backends/django/compiler.py +++ b/utilmeta/core/orm/backends/django/compiler.py @@ -15,7 +15,9 @@ from enum import Enum -def get_ignored_errors(errors: Union[bool, Type[Exception], List[Exception]]) -> Tuple[Type[Exception], ...]: +def get_ignored_errors( + errors: Union[bool, Type[Exception], List[Exception]] +) -> Tuple[Type[Exception], ...]: if not errors: return () if errors is True: @@ -43,11 +45,12 @@ def __init__(self, *args, **kwargs): def _get_pk(self, value, robust: bool = False): if robust: if isinstance(value, models.Model): - return getattr(value, 'pk', None) + return getattr(value, "pk", None) else: if isinstance(value, self.model.model): - return getattr(value, 'pk', None) + return getattr(value, "pk", None) from utilmeta.core.orm.schema import Schema + if isinstance(value, Schema): return value.pk if isinstance(value, dict): @@ -179,17 +182,23 @@ async def query_isolated(f): return self.values def handle_isolated_field(self, field: ParserQueryField, e: Exception): - prepend = f'{self.parser.name}[{self.parser.model.model}] ' \ - f'serialize isolated field: [{repr(field.name)}] failed with error: ' + prepend = ( + f"{self.parser.name}[{self.parser.model.model}] " + f"serialize isolated field: [{repr(field.name)}] failed with error: " + ) if not field.fail_silently or self.context.force_raise_error: raise Error(e).throw(prepend=prepend) - warnings.warn(f'{prepend}{e}') + warnings.warn(f"{prepend}{e}") def process_expression(self, expression): if isinstance(expression, exp.Sum) and self.queryset.query.is_sliced: # use subquery to avoid wrong value when sum multiple aggregates - expression = exp.Subquery(self.base_queryset().filter( - 
pk=exp.OuterRef('pk')).annotate(v=expression).values('v')) + expression = exp.Subquery( + self.base_queryset() + .filter(pk=exp.OuterRef("pk")) + .annotate(v=expression) + .values("v") + ) # once a queryset is sliced, query it's many-related data may return wrong values # for example, qs[:2] should return [{"id": 1, "many": [1, 2, 3]}, {...}], but the slice of main queryset # is affected on the join queries, so it only return [{"id": 1, "many": [1, 2]}, {...}] @@ -205,7 +214,7 @@ def get_query_name(cls, field: ParserQueryField): name = field.field_name if not isinstance(name, str): return None - return name.replace('.', '__') + return name.replace(".", "__") def process_query_field(self, field: ParserQueryField): if field.primary_key: @@ -220,7 +229,9 @@ def process_query_field(self, field: ParserQueryField): if field.related_schema: self.recursively = True elif field.expression: - self.expressions.setdefault(field.name, self.process_expression(field.expression)) + self.expressions.setdefault( + field.name, self.process_expression(field.expression) + ) return if field.included: @@ -251,7 +262,7 @@ def query_isolated_field(self, field: ParserQueryField): pk_map = {} key = field.name - query_key = '__' + key + query_key = "__" + key # avoid "conflicts with a field on the model." current_qs: models.QuerySet = self.model.get_queryset(pk__in=pk_list) @@ -262,13 +273,16 @@ def query_isolated_field(self, field: ParserQueryField): # - related_schema.serialize(related_qs) [related_schema provided] if field.expression: - pk_map = {val[PK]: val[query_key] for val in current_qs.values(PK, **{query_key: field.expression})} + pk_map = { + val[PK]: val[query_key] + for val in current_qs.values(PK, **{query_key: field.expression}) + } elif isinstance(related_queryset, models.QuerySet): # add reverse lookup if field.reverse_lookup: related_queryset = related_queryset.filter( - **{field.reverse_lookup + '__in': pk_list} + **{field.reverse_lookup + "__in": pk_list} ) for val in related_queryset.values(PK, field.reverse_lookup): rel = val[PK] @@ -285,7 +299,9 @@ def query_isolated_field(self, field: ParserQueryField): # so the final values might not be the exact 'pk' # we do not override if user has already selected - for val in current_qs.values(PK, **{query_key: exp.Subquery(related_subquery)}): + for val in current_qs.values( + PK, **{query_key: exp.Subquery(related_subquery)} + ): rel = val[query_key] if rel is not None: pk_map.setdefault(val[PK], []).append(rel) @@ -325,9 +341,10 @@ def query_isolated_field(self, field: ParserQueryField): # also prevent redundant "None" over the non-exist fk if m and f: - for val in m.get_queryset( - **{f + '__in': pk_list}).values(c or PK, __target=exp.F(f)): - rel = val['__target'] + for val in m.get_queryset(**{f + "__in": pk_list}).values( + c or PK, __target=exp.F(f) + ): + rel = val["__target"] if rel is not None: pk_map.setdefault(rel, []).append(val[c or PK]) else: @@ -369,8 +386,8 @@ def query_isolated_field(self, field: ParserQueryField): # if field.related_model else list(related_pks), # for func without related model context=self.get_related_context( - field, force_expressions={SEG + PK: exp.F('pk')} - ) + field, force_expressions={SEG + PK: exp.F("pk")} + ), ): pk = pop(inst, SEG + PK) or inst.get(PK) or inst.get(ID) # try to get pk value @@ -416,7 +433,7 @@ def query_isolated_field(self, field: ParserQueryField): def normalize_pk_list(self, value): if isinstance(value, models.QuerySet): - value = list(value.values_list('pk', flat=True)) + value = 
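# --- illustrative aside (not part of the patch): the sliced-aggregate
# workaround described in the comments above, as standalone Django. A joined
# Sum over a sliced queryset can be skewed by the join rows, so the sum is
# computed per row inside a correlated subquery instead. Order/items/price
# are hypothetical models; assumes a configured Django project.
from django.db.models import OuterRef, Subquery, Sum

per_row_total = Subquery(
    Order.objects.filter(pk=OuterRef("pk"))
    .annotate(v=Sum("items__price"))
    .values("v")
)
orders = Order.objects.annotate(total=per_row_total)[:2]  # slicing stays safe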
list(value.values_list("pk", flat=True)) if not multi(value): value = [value] lst = [] @@ -430,7 +447,7 @@ def normalize_pk_list(self, value): @awaitable(normalize_pk_list) async def normalize_pk_list(self, value): if isinstance(value, models.QuerySet): - value = [pk async for pk in value.values_list('pk', flat=True)] + value = [pk async for pk in value.values_list("pk", flat=True)] if not multi(value): value = [value] lst = [] @@ -443,7 +460,7 @@ async def normalize_pk_list(self, value): def normalize_pk_map(self, pk_map: dict): if not isinstance(pk_map, dict): - raise TypeError(f'Invalid pk map: {pk_map}, must be a dict') + raise TypeError(f"Invalid pk map: {pk_map}, must be a dict") result = {} for k, value in pk_map.items(): lst = self.normalize_pk_list(value) @@ -454,7 +471,7 @@ def normalize_pk_map(self, pk_map: dict): @awaitable(normalize_pk_map) async def normalize_pk_map(self, pk_map: dict): if not isinstance(pk_map, dict): - raise TypeError(f'Invalid pk map: {pk_map}, must be a dict') + raise TypeError(f"Invalid pk map: {pk_map}, must be a dict") result = {} for k, value in pk_map.items(): lst = await self.normalize_pk_list(value) @@ -481,7 +498,7 @@ async def query_isolated_field(self, field: ParserQueryField): return pk_map = {} key = field.name - query_key = '__' + key + query_key = "__" + key # avoid "conflicts with a field on the model." current_qs: models.QuerySet = self.model.get_queryset(pk__in=pk_list) @@ -489,13 +506,16 @@ async def query_isolated_field(self, field: ParserQueryField): related_queryset: models.QuerySet = field.queryset if field.expression: - pk_map = {val[PK]: val[query_key] async for val in current_qs.values(PK, **{query_key: field.expression})} + pk_map = { + val[PK]: val[query_key] + async for val in current_qs.values(PK, **{query_key: field.expression}) + } elif isinstance(related_queryset, models.QuerySet): # add reverse lookup if field.reverse_lookup: related_queryset = related_queryset.filter( - **{field.reverse_lookup + '__in': pk_list} + **{field.reverse_lookup + "__in": pk_list} ) async for val in related_queryset.values(PK, field.reverse_lookup): rel = val[PK] @@ -509,7 +529,9 @@ async def query_isolated_field(self, field: ParserQueryField): # 2. 
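# --- illustrative aside (not part of the patch): the pk_map idiom used by
# query_isolated_field, isolated as plain Python. Related pks are grouped
# under the parent pk so each serialized row is filled in a single pass.
# The row dicts are hypothetical stand-ins for .values() results.
from collections import defaultdict

rows = [  # e.g. related_queryset.values("pk", "author")
    {"pk": 10, "author": 1},
    {"pk": 11, "author": 1},
    {"pk": 12, "author": 2},
]
pk_map = defaultdict(list)
for val in rows:
    if val["pk"] is not None:  # skip rows with no related value
        pk_map[val["author"]].append(val["pk"])
print(dict(pk_map))  # {1: [10, 11], 2: [12]}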
this is a related schema query, we should override the values to PK related_subquery = related_subquery.values(PK) - async for val in current_qs.values(PK, **{query_key: exp.Subquery(related_subquery)}): + async for val in current_qs.values( + PK, **{query_key: exp.Subquery(related_subquery)} + ): rel = val[query_key] if rel is not None: pk_map.setdefault(val[PK], []).append(rel) @@ -547,9 +569,10 @@ async def query_isolated_field(self, field: ParserQueryField): # use reverse query due to the unfixed issue on the async backend # also prevent redundant "None" over the non-exist fk if m and f: - async for val in m.get_queryset( - **{f + '__in': pk_list}).values(c or PK, __target=exp.F(f)): - rel = val['__target'] + async for val in m.get_queryset(**{f + "__in": pk_list}).values( + c or PK, __target=exp.F(f) + ): + rel = val["__target"] if rel is not None: pk_map.setdefault(rel, []).append(val[c or PK]) else: @@ -591,8 +614,8 @@ async def query_isolated_field(self, field: ParserQueryField): # for func without related model, # or the related schema model is not exactly the related model (maybe sub model) context=self.get_related_context( - field, force_expressions={SEG + PK: exp.F('pk')} - ) + field, force_expressions={SEG + PK: exp.F("pk")} + ), ): pk = pop(inst, SEG + PK) or inst.get(PK) or inst.get(ID) # try to get pk value @@ -637,7 +660,9 @@ async def query_isolated_field(self, field: ParserQueryField): def process_value(self, field: ParserQueryField, value): if not field.model_field: return value - if isinstance(field.model_field, models.DurationField) and isinstance(value, (int, float)): + if isinstance(field.model_field, models.DurationField) and isinstance( + value, (int, float) + ): return timedelta(seconds=value) elif multi(value): # convert tuple/set to list @@ -676,18 +701,21 @@ async def commit_data(self, data: dict): # def get_instance(self, pk): # self.model.get_instance_recursively(pk=pk) - def save_data(self, - data, - must_create: bool = False, - must_update: bool = False, - ignore_bulk_errors: bool = False, - ignore_relation_errors: bool = False, - with_relations: bool = None, - transaction: bool = False, - ): + def save_data( + self, + data, + must_create: bool = False, + must_update: bool = False, + ignore_bulk_errors: bool = False, + ignore_relation_errors: bool = False, + with_relations: bool = None, + transaction: bool = False, + ): if with_relations is None: with_relations = self.pref.orm_default_save_with_relations - with TransactionWrapper(self.model, transaction, errors_map=self.get_errors_map(False)): + with TransactionWrapper( + self.model, transaction, errors_map=self.get_errors_map(False) + ): if multi(data): # TODO: implement bulk create/update error_classes = get_ignored_errors(ignore_bulk_errors) @@ -704,11 +732,14 @@ def save_data(self, except error_classes as e: pk = None # leave it to None to keep the result pk_list the same length as values - warnings.warn(f'orm.Schema[{self.model.model}]: ignoring bulk_save errors: {e}') + warnings.warn( + f"orm.Schema[{self.model.model}]: ignoring bulk_save errors: {e}" + ) pk_list.append(pk) return pk_list else: from utilmeta.core.orm.schema import Schema + pk = None if isinstance(data, Schema): pk = data.pk @@ -717,7 +748,9 @@ def save_data(self, pk = data.get(p) if pk is not None: break - data, rel_keys, rel_objs = self.process_data(data, with_relations=with_relations) + data, rel_keys, rel_objs = self.process_data( + data, with_relations=with_relations + ) if pk is None: # create if must_update: @@ -762,18 +795,21 
@@ def save_data(self, return pk @awaitable(save_data, bind_service=True, close_conn=True) - async def save_data(self, - data, - must_create: bool = False, - must_update: bool = False, - ignore_bulk_errors: bool = False, - ignore_relation_errors: bool = False, - with_relations: bool = None, - transaction: bool = False, - ): + async def save_data( + self, + data, + must_create: bool = False, + must_update: bool = False, + ignore_bulk_errors: bool = False, + ignore_relation_errors: bool = False, + with_relations: bool = None, + transaction: bool = False, + ): if with_relations is None: with_relations = self.pref.orm_default_save_with_relations - async with TransactionWrapper(self.model, transaction, errors_map=self.get_errors_map(True)): + async with TransactionWrapper( + self.model, transaction, errors_map=self.get_errors_map(True) + ): if multi(data): # TODO: implement bulk create/update error_classes = get_ignored_errors(ignore_bulk_errors) @@ -785,15 +821,18 @@ async def save_data(self, must_create=must_create, must_update=must_update, ignore_relation_errors=ignore_relation_errors, - with_relations=with_relations + with_relations=with_relations, ) except error_classes as e: pk = None - warnings.warn(f'orm.Schema[{self.model.model}]: ignoring bulk_save errors: {e}') + warnings.warn( + f"orm.Schema[{self.model.model}]: ignoring bulk_save errors: {e}" + ) pk_list.append(pk) return pk_list else: from utilmeta.core.orm.schema import Schema + pk = None if isinstance(data, Schema): pk = data.pk @@ -803,7 +842,9 @@ async def save_data(self, if pk is not None: break - data, rel_keys, rel_objs = self.process_data(data, with_relations=with_relations) + data, rel_keys, rel_objs = self.process_data( + data, with_relations=with_relations + ) if pk is None: if must_update: @@ -850,14 +891,16 @@ async def save_data(self, ) return pk - def save_relations(self, - pk, - relation_keys: dict, - relation_objects: dict, - must_create: bool = False, - ignore_errors: bool = False, - ): + def save_relations( + self, + pk, + relation_keys: dict, + relation_objects: dict, + must_create: bool = False, + ignore_errors: bool = False, + ): from utilmeta.core.orm.schema import Schema + error_classes = get_ignored_errors(ignore_errors) # todo: update single object (fk + unique=True) @@ -880,7 +923,9 @@ def save_relations(self, try: related_inst.save(update_fields=[relation_field]) except error_classes as e: - warnings.warn(f'orm.Schema(pk={repr(pk)}): ignoring relational errors for {repr(name)}: {e}') + warnings.warn( + f"orm.Schema(pk={repr(pk)}): ignoring relational errors for {repr(name)}: {e}" + ) else: rel_field = getattr(inst, name, None) if not rel_field: @@ -892,7 +937,9 @@ def save_relations(self, else: rel_field.set(keys) except error_classes as e: - warnings.warn(f'orm.Schema(pk={repr(pk)}): ignoring relational errors for {repr(name)}: {e}') + warnings.warn( + f"orm.Schema(pk={repr(pk)}): ignoring relational errors for {repr(name)}: {e}" + ) for key, (field, objects) in relation_objects.items(): field: ParserQueryField @@ -900,7 +947,7 @@ def save_relations(self, if not related_schema or not issubclass(related_schema, Schema): continue # SET PK - relation_fields = getattr(related_schema, '__relational_fields__', []) or [] + relation_fields = getattr(related_schema, "__relational_fields__", []) or [] if isinstance(objects, Schema): objects = [objects] elif not multi(objects): @@ -912,7 +959,7 @@ def save_relations(self, objects, must_create=must_create and not field.model_field.remote_field.is_pk, 
ignore_errors=ignore_errors, - with_relations=True + with_relations=True, ) if not must_create: # delete the unrelated-relation @@ -920,19 +967,18 @@ def save_relations(self, field_name = field.model_field.remote_field.name if not field_name: continue - field.related_model.get_queryset( - **{field_name: pk} - ).exclude(pk__in=[val.pk for val in result if val.pk]).delete() + field.related_model.get_queryset(**{field_name: pk}).exclude( + pk__in=[val.pk for val in result if val.pk] + ).delete() except error_classes as e: - warnings.warn(f'orm.Schema(pk={repr(pk)}): ignoring relational ' - f'deletion errors for {repr(key)}: {e}') - - async def asave_relation_keys(self, - obj, - keys: list, - field: ParserQueryField, - add_only: bool = False - ): + warnings.warn( + f"orm.Schema(pk={repr(pk)}): ignoring relational " + f"deletion errors for {repr(key)}: {e}" + ) + + async def asave_relation_keys( + self, obj, keys: list, field: ParserQueryField, add_only: bool = False + ): if not isinstance(obj, self.model.model): obj = self.model.init_instance(pk=obj) @@ -940,8 +986,10 @@ async def asave_relation_keys(self, related_model = field.model_field.related_model from_field, to_field = field.model_field.through_fields if not through_model or not related_model or not from_field or not to_field: - raise exceptions.InvalidRelationalUpdate(f'Invalid relational keys update field: ' - f'{repr(field.model_field.name)}, must be a many-to-may field/rel') + raise exceptions.InvalidRelationalUpdate( + f"Invalid relational keys update field: " + f"{repr(field.model_field.name)}, must be a many-to-many field/rel" + ) create_objs = [] all_keys = [] for key in keys: @@ -949,10 +997,7 @@ async def asave_relation_keys(self, rel_obj = key else: rel_obj = related_model.init_instance(pk=key) - thr_data = { - from_field.name: obj, - to_field.name: rel_obj - } + thr_data = {from_field.name: obj, to_field.name: rel_obj} if not add_only: thr_obj = await through_model.aget_instance(**thr_data) if thr_obj: @@ -960,7 +1005,9 @@ async def asave_relation_keys(self, continue create_objs.append(thr_data) - through_qs = AwaitableQuerySet(model=through_model.model).filter(**{from_field.name: obj}) + through_qs = AwaitableQuerySet(model=through_model.model).filter( + **{from_field.name: obj} + ) db = through_qs.connections_cls.get(through_qs.db) async with db.async_transaction(savepoint=False): @@ -968,18 +1015,18 @@ async def asave_relation_keys(self, obj = await AwaitableQuerySet(model=through_model.model).acreate(**val) all_keys.append(obj.pk) if not add_only: - await through_qs.exclude( - pk__in=all_keys - ).adelete() - - async def asave_relations(self, - pk, - relation_keys: dict, - relation_objects: dict, - must_create: bool = False, - ignore_errors: bool = False, - ): + await through_qs.exclude(pk__in=all_keys).adelete() + + async def asave_relations( + self, + pk, + relation_keys: dict, + relation_objects: dict, + must_create: bool = False, + ignore_errors: bool = False, + ): from utilmeta.core.orm.schema import Schema + error_classes = get_ignored_errors(ignore_errors) for name, (field, keys) in relation_keys.items(): @@ -998,17 +1045,18 @@ async def asave_relations(self, try: await related_inst.asave(update_fields=[relation_field]) except error_classes as e: - warnings.warn(f'orm.Schema(pk={repr(pk)}): ignoring relational errors for {repr(name)}: {e}') + warnings.warn( + f"orm.Schema(pk={repr(pk)}): ignoring relational errors for {repr(name)}: {e}" + ) else: try: await self.asave_relation_keys( - inst, - keys=keys, - 
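# --- illustrative aside (not part of the patch): what asave_relation_keys
# amounts to, in synchronous Django terms. Ensure one through-row per
# submitted key, then (unless add_only) drop through-rows whose keys were
# not re-submitted. Membership(user, group) is a hypothetical through model;
# assumes a configured Django project.
def sync_member_keys(user, group_pks, add_only=False):
    kept = []
    for pk in group_pks:
        obj, _ = Membership.objects.get_or_create(user=user, group_id=pk)
        kept.append(obj.pk)
    if not add_only:
        # through-rows not re-asserted above are stale relations: delete them
        Membership.objects.filter(user=user).exclude(pk__in=kept).delete()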
field=field, - add_only=must_create + inst, keys=keys, field=field, add_only=must_create ) except error_classes as e: - warnings.warn(f'orm.Schema(pk={repr(pk)}): ignoring relational errors for {repr(name)}: {e}') + warnings.warn( + f"orm.Schema(pk={repr(pk)}): ignoring relational errors for {repr(name)}: {e}" + ) # async tasks may cause update problem? don't know, to be tested for key, (field, objects) in relation_objects.items(): @@ -1017,7 +1065,7 @@ async def asave_relations(self, if not related_schema or not issubclass(related_schema, Schema): continue # SET PK - relation_fields = getattr(related_schema, '__relational_fields__', []) or [] + relation_fields = getattr(related_schema, "__relational_fields__", []) or [] if isinstance(objects, Schema): objects = [objects] elif not multi(objects): @@ -1031,7 +1079,7 @@ async def asave_relations(self, objects, must_create=must_create and not field.model_field.remote_field.is_pk, ignore_errors=ignore_errors, - with_relations=True + with_relations=True, ) if not must_create: # delete the unrelated-relation @@ -1039,12 +1087,14 @@ async def asave_relations(self, field_name = field.model_field.remote_field.name if not field_name: continue - await field.related_model.get_queryset( - **{field_name: pk} - ).exclude(pk__in=[val.pk for val in result if val.pk]).adelete() + await field.related_model.get_queryset(**{field_name: pk}).exclude( + pk__in=[val.pk for val in result if val.pk] + ).adelete() except error_classes as e: - warnings.warn(f'orm.Schema(pk={repr(pk)}): ignoring relational ' - f'deletion errors for {repr(key)}: {e}') + warnings.warn( + f"orm.Schema(pk={repr(pk)}): ignoring relational " + f"deletion errors for {repr(key)}: {e}" + ) def get_errors_map(self, asynchronous: bool = False) -> dict: if self.context.integrity_error_cls: @@ -1059,12 +1109,17 @@ def get_integrity_errors(self, asynchronous: bool = False): # we should not return any return () from .queryset import AwaitableQuerySet + qs = self.model.get_queryset() from django.db.utils import IntegrityError + if isinstance(qs, AwaitableQuerySet) or asynchronous: from utilmeta.core.orm import DatabaseConnections + db = DatabaseConnections.get(qs.db) - errors = list(db.get_adaptor(asynchronous=asynchronous).get_integrity_errors()) + errors = list( + db.get_adaptor(asynchronous=asynchronous).get_integrity_errors() + ) else: errors = [] if IntegrityError not in errors: diff --git a/utilmeta/core/orm/backends/django/constant.py b/utilmeta/core/orm/backends/django/constant.py index 1c043a6..101bdc1 100644 --- a/utilmeta/core/orm/backends/django/constant.py +++ b/utilmeta/core/orm/backends/django/constant.py @@ -6,134 +6,153 @@ SM = 32767 MD = 2147483647 LG = 9223372036854775807 -PK = 'pk' -ID = 'id' -SEG = '__' +PK = "pk" +ID = "id" +SEG = "__" FIELDS_TYPE = { - ('CharField', 'ImageField', 'ChoiceField', 'PasswordField', - 'EmailField', 'FilePathField', 'FileField', 'URLField', 'SlugField', - 'GenericIPAddressField', 'IPAddressField', 'TextField', - 'RichTextField',): str, - ('UUIDField',): UUID, - ('TimeField',): time, - ('DateField',): date, - ('DurationField',): timedelta, - ('DateTimeField',): datetime, - ('AutoField', 'BigAutoField', 'SmallAutoField', 'BigIntegerField', - 'IntegerField', 'PositiveIntegerField', 'PositiveBigIntegerField', - 'SmallIntegerField', 'PositiveSmallIntegerField', 'SmallIntegerField',): int, - ('FloatField',): float, - ('DecimalField',): Decimal, - ('BooleanField', 'NullBooleanField',): bool, - ('CommaSeparatedIntegerField', 'ArrayField', 'ManyToManyField', 
- 'ManyToOneRel', 'ManyToManyRel'): list, - ('HStoreField',): dict, - ('JSONField',): Any, - ('BinaryField',): bytes + ( + "CharField", + "ImageField", + "ChoiceField", + "PasswordField", + "EmailField", + "FilePathField", + "FileField", + "URLField", + "SlugField", + "GenericIPAddressField", + "IPAddressField", + "TextField", + "RichTextField", + ): str, + ("UUIDField",): UUID, + ("TimeField",): time, + ("DateField",): date, + ("DurationField",): timedelta, + ("DateTimeField",): datetime, + ( + "AutoField", + "BigAutoField", + "SmallAutoField", + "BigIntegerField", + "IntegerField", + "PositiveIntegerField", + "PositiveBigIntegerField", + "SmallIntegerField", + "PositiveSmallIntegerField", + "SmallIntegerField", + ): int, + ("FloatField",): float, + ("DecimalField",): Decimal, + ( + "BooleanField", + "NullBooleanField", + ): bool, + ( + "CommaSeparatedIntegerField", + "ArrayField", + "ManyToManyField", + "ManyToOneRel", + "ManyToManyRel", + ): list, + ("HStoreField",): dict, + ("JSONField",): Any, + ("BinaryField",): bytes, } -datetime_lookups = ['date', 'time'] -date_lookups = ['year', 'iso_year', 'month', 'day', 'week', 'week_day', 'quarter'] -time_lookups = ['hour', 'minute', 'second'] -option_allowed_lookups = [*datetime_lookups, *date_lookups, *time_lookups, 'len'] +datetime_lookups = ["date", "time"] +date_lookups = ["year", "iso_year", "month", "day", "week", "week_day", "quarter"] +time_lookups = ["hour", "minute", "second"] +option_allowed_lookups = [*datetime_lookups, *date_lookups, *time_lookups, "len"] ADDON_FIELD_LOOKUPS = { - 'DateField': date_lookups, - 'TimeField': time_lookups, - 'DateTimeField': [*date_lookups, *time_lookups, *datetime_lookups], - 'JSONField': ['contains', 'contained_by', 'has_key', 'has_any_keys', 'has_keys'], - 'ArrayField': ['contains', 'contained_by', 'overlap', 'len'], - 'HStoreField': ['contains', 'contained_by', 'has_key', 'has_any_keys', 'has_keys', 'keys', 'values'], - 'RangeField': ['contains', 'contained_by', 'overlap', 'fully_lt', 'fully_gt', 'not_lt', - 'not_gt', 'adjacent_to', 'isempty', 'lower_inc', 'lower_inf', 'upper_inc', 'upper_inf'] + "DateField": date_lookups, + "TimeField": time_lookups, + "DateTimeField": [*date_lookups, *time_lookups, *datetime_lookups], + "JSONField": ["contains", "contained_by", "has_key", "has_any_keys", "has_keys"], + "ArrayField": ["contains", "contained_by", "overlap", "len"], + "HStoreField": [ + "contains", + "contained_by", + "has_key", + "has_any_keys", + "has_keys", + "keys", + "values", + ], + "RangeField": [ + "contains", + "contained_by", + "overlap", + "fully_lt", + "fully_gt", + "not_lt", + "not_gt", + "adjacent_to", + "isempty", + "lower_inc", + "lower_inf", + "upper_inc", + "upper_inf", + ], } ADDON_LOOKUP_RULES = { - 'date': date, - 'time': time, - 'year': Year, - 'iso_year': Year, - 'month': Month, - 'day': Day, - 'week': Week, - 'week_day': WeekDay, - 'quarter': Quarter, - 'hour': Hour, - 'minute': Minute, - 'second': Second, - 'len': int, - 'has_key': str, - 'has_any_keys': list, - 'has_keys': list, - 'keys': list, - 'values': list, - 'isempty': bool, - 'upper_inc': bool, - 'lower_inc': bool, - 'upper_inf': bool, - 'lower_inf': bool, + "date": date, + "time": time, + "year": Year, + "iso_year": Year, + "month": Month, + "day": Day, + "week": Week, + "week_day": WeekDay, + "quarter": Quarter, + "hour": Hour, + "minute": Minute, + "second": Second, + "len": int, + "has_key": str, + "has_any_keys": list, + "has_keys": list, + "keys": list, + "values": list, + "isempty": bool, + "upper_inc": 
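# --- illustrative aside (not part of the patch): FIELDS_TYPE keys are
# tuples of internal type names, so resolving a python type is a membership
# scan. Hypothetical helper, assuming it sits next to the FIELDS_TYPE map:
def resolve_python_type(internal_type: str):
    for names, py_type in FIELDS_TYPE.items():
        if internal_type in names:
            return py_type
    return None

# resolve_python_type("DecimalField")            -> decimal.Decimal
# resolve_python_type(field.get_internal_type()) for a field instance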
bool, + "lower_inc": bool, + "upper_inf": bool, + "lower_inf": bool, } ADDON_FIELDS = { - 'date': models.DateField, - 'time': models.TimeField, - 'year': models.PositiveIntegerField, - 'iso_year': models.PositiveIntegerField, - 'month': models.PositiveSmallIntegerField, - 'day': models.PositiveSmallIntegerField, - 'week': models.PositiveSmallIntegerField, - 'week_day': models.PositiveSmallIntegerField, - 'quarter': models.PositiveSmallIntegerField, - 'hour': models.PositiveSmallIntegerField, - 'minute': models.PositiveSmallIntegerField, - 'second': models.PositiveSmallIntegerField, - 'len': models.PositiveIntegerField, + "date": models.DateField, + "time": models.TimeField, + "year": models.PositiveIntegerField, + "iso_year": models.PositiveIntegerField, + "month": models.PositiveSmallIntegerField, + "day": models.PositiveSmallIntegerField, + "week": models.PositiveSmallIntegerField, + "week_day": models.PositiveSmallIntegerField, + "quarter": models.PositiveSmallIntegerField, + "hour": models.PositiveSmallIntegerField, + "minute": models.PositiveSmallIntegerField, + "second": models.PositiveSmallIntegerField, + "len": models.PositiveIntegerField, } -OPERATOR_FIELDS = [dict( - cls=models.FloatField, - type=float, - operators=['+', '-', '/', '*', '%', '^'] -), dict( - cls=models.IntegerField, - type=int, - operators=['+', '-', '%'] -), dict( - cls=models.IntegerField, - type=float, - operators=['/', '*', '^'] -), dict( - cls=models.DecimalField, - type=Decimal, - operators=['+', '-', '/', '*', '%', '^'] -), dict( - cls=models.CharField, - type=str, - operators=['+'] # 'abc' + 'd' = 'abcd' -), dict( - cls=models.TextField, - type=str, - operators=['+'] -), dict( - cls=models.CharField, - type=int, - operators=['*'] -), dict( - cls=models.TextField, - type=int, - operators=['*'] -), dict( - cls=models.DurationField, - type=timedelta, - operators=['+', '-'] -), dict( - cls=models.DurationField, - type=float, - operators=['*', '/'] -), dict( - cls=models.DateTimeField, - type=timedelta, - operators=['+', '-'] -)] +OPERATOR_FIELDS = [ + dict(cls=models.FloatField, type=float, operators=["+", "-", "/", "*", "%", "^"]), + dict(cls=models.IntegerField, type=int, operators=["+", "-", "%"]), + dict(cls=models.IntegerField, type=float, operators=["/", "*", "^"]), + dict( + cls=models.DecimalField, type=Decimal, operators=["+", "-", "/", "*", "%", "^"] + ), + dict(cls=models.CharField, type=str, operators=["+"]), # 'abc' + 'd' = 'abcd' + dict(cls=models.TextField, type=str, operators=["+"]), + dict(cls=models.CharField, type=int, operators=["*"]), + dict(cls=models.TextField, type=int, operators=["*"]), + dict(cls=models.DurationField, type=timedelta, operators=["+", "-"]), + dict(cls=models.DurationField, type=float, operators=["*", "/"]), + dict(cls=models.DateTimeField, type=timedelta, operators=["+", "-"]), +] # PK_TYPES = (int, str, float, Decimal, UUID) diff --git a/utilmeta/core/orm/backends/django/database.py b/utilmeta/core/orm/backends/django/database.py index ca6f752..65b784d 100644 --- a/utilmeta/core/orm/backends/django/database.py +++ b/utilmeta/core/orm/backends/django/database.py @@ -6,27 +6,28 @@ class DjangoDatabaseAdaptor(BaseDatabaseAdaptor): - SQLITE = 'django.db.backends.sqlite3' - ORACLE = 'django.db.backends.oracle' - MYSQL = 'django.db.backends.mysql' - POSTGRESQL = 'django.db.backends.postgresql' + SQLITE = "django.db.backends.sqlite3" + ORACLE = "django.db.backends.oracle" + MYSQL = "django.db.backends.mysql" + POSTGRESQL = "django.db.backends.postgresql" # -- pooled backends 
- POOLED_POSTGRESQL = 'utilmeta.util.query.pooled_backends.postgresql' - POOLED_GEVENT_POSTGRESQL = 'utilmeta.util.query.pooled_backends.postgresql_gevent' + POOLED_POSTGRESQL = "utilmeta.util.query.pooled_backends.postgresql" + POOLED_GEVENT_POSTGRESQL = "utilmeta.util.query.pooled_backends.postgresql_gevent" # POOLED_MYSQL = 'utilmeta.util.query.pooled_backends.mysql' # POOLED_ORACLE = 'utilmeta.util.query.pooled_backends.oracle' DEFAULT_ENGINES = { - 'sqlite': SQLITE, - 'sqlite3': SQLITE, - 'oracle': ORACLE, - 'mysql': MYSQL, - 'postgresql': POSTGRESQL, - 'postgres': POSTGRESQL + "sqlite": SQLITE, + "sqlite3": SQLITE, + "oracle": ORACLE, + "mysql": MYSQL, + "postgresql": POSTGRESQL, + "postgres": POSTGRESQL, } def get_integrity_errors(self): from django.db.utils import IntegrityError + return (IntegrityError,) @classmethod @@ -67,10 +68,12 @@ def allow_migrate(db, app_label, model_name=None, **hints): def connect(self): from django.db import connections + return connections[self.alias] def disconnect(self): from django.db import connections + connections.close_all() def execute(self, sql, params=None): @@ -92,6 +95,7 @@ def fetchone(self, sql, params=None): def fetchall(self, sql, params=None): from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE + db = self.connect() with db.cursor() as cursor: cursor.execute(sql, params) @@ -99,6 +103,7 @@ def fetchall(self, sql, params=None): def transaction(self, savepoint=None, isolation=None, force_rollback: bool = False): from django.db import transaction + return transaction.atomic(self.alias, savepoint=savepoint) def check(self): @@ -107,14 +112,9 @@ def check(self): # except (ModuleNotFoundError, ImportError) as e: # raise e.__class__(f'{self.__class__} as database adaptor requires to install django') from e if self.config.is_mysql: - requires( - MySQLdb='mysqlclient' - ) + requires(MySQLdb="mysqlclient") elif self.config.is_postgresql: - requires( - psycopg='"psycopg[binary,pool]"', - psycopg2='psycopg2' - ) + requires(psycopg='"psycopg[binary,pool]"', psycopg2="psycopg2") class DjangoDatabase(Database): diff --git a/utilmeta/core/orm/backends/django/deletion.py b/utilmeta/core/orm/backends/django/deletion.py index 7f83a18..f3a7cac 100644 --- a/utilmeta/core/orm/backends/django/deletion.py +++ b/utilmeta/core/orm/backends/django/deletion.py @@ -4,8 +4,11 @@ from django.db.models.deletion import Collector, Counter from collections import defaultdict from itertools import chain -from django.db.models.deletion import get_candidate_relations_to_delete, \ - DO_NOTHING, ProtectedError +from django.db.models.deletion import ( + get_candidate_relations_to_delete, + DO_NOTHING, + ProtectedError, +) from django.db.models import QuerySet, sql, signals from django.db import models from django.core.exceptions import EmptyResultSet @@ -17,6 +20,7 @@ try: from django.db.models.deletion import RestrictedError except ImportError: + class RestrictedError(Exception): pass @@ -43,7 +47,9 @@ async def delete_single(cls, qs: QuerySet, db: DatabaseConnections.database_cls) return 0 @classmethod - async def update_batch(cls, model, pk_list, values, db: DatabaseConnections.database_cls): + async def update_batch( + cls, model, pk_list, values, db: DatabaseConnections.database_cls + ): query = sql.UpdateQuery(model) query.add_update_values(values) for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): @@ -51,9 +57,10 @@ async def update_batch(cls, model, pk_list, values, db: DatabaseConnections.data query.clear_where() else: from 
django.db.models.sql.where import WhereNode + query.where = WhereNode() query.add_filter( - "pk__in", pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE] + "pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE] ) q, params = query.get_compiler(db.alias).as_sql() await db.execute(q, params) @@ -75,10 +82,11 @@ async def delete_batch(cls, model, pk_list, db: DatabaseConnections.database_cls query.clear_where() else: from django.db.models.sql.where import WhereNode + query.where = WhereNode() query.add_filter( f"{field.attname}__in", - pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE], + pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE], ) where = query.where table = query.get_meta().db_table @@ -115,7 +123,9 @@ async def async_delete(self): if self.can_fast_delete(instance): async with db.async_transaction(): # with transaction.mark_for_rollback_on_error(self.using): - count = await self.delete_batch(model, pk_list=[instance.pk], db=db) + count = await self.delete_batch( + model, pk_list=[instance.pk], db=db + ) setattr(instance, model._meta.pk.attname, None) return count, {model._meta.label: count} @@ -152,11 +162,12 @@ async def async_delete(self): if updates: combined_updates = reduce(or_, updates) from .queryset import AwaitableQuerySet + if not isinstance(combined_updates, AwaitableQuerySet): combined_updates = AwaitableQuerySet( model=combined_updates.model, query=combined_updates.query, - using=combined_updates.db + using=combined_updates.db, ) try: await combined_updates.aupdate(**{field.name: value}) @@ -166,15 +177,19 @@ async def async_delete(self): model = objs[0].__class__ # query = sql.UpdateQuery(model) await self.update_batch( - model, pk_list=[obj.pk for obj in objs], - values={field.name: value}, db=db + model, + pk_list=[obj.pk for obj in objs], + values={field.name: value}, + db=db, ) else: for model, instances_for_fieldvalues in self.field_updates.items(): for (field, value), instances in instances_for_fieldvalues.items(): await self.update_batch( - model, pk_list=[obj.pk for obj in instances], - values={field.name: value}, db=db + model, + pk_list=[obj.pk for obj in instances], + values={field.name: value}, + db=db, ) # reverse instance collections @@ -214,12 +229,16 @@ def related_objects(self, related_model, related_fields, objs): Get a QuerySet of the related model to objs via related fields. """ from django.db.models import query_utils + predicate = query_utils.Q.create( [(f"{related_field.name}__in", objs) for related_field in related_fields], connector=query_utils.Q.OR, ) from .queryset import AwaitableQuerySet - return AwaitableQuerySet(model=related_model).using(self.using).filter(predicate) + + return ( + AwaitableQuerySet(model=related_model).using(self.using).filter(predicate) + ) async def aadd(self, objs, source=None, nullable=False, reverse_dependency=False): """ @@ -230,6 +249,7 @@ async def aadd(self, objs, source=None, nullable=False, reverse_dependency=False Return a list of all objects that were not already collected. 
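# --- illustrative aside (not part of the patch): the chunking idiom that
# update_batch/delete_batch apply, standalone. Slicing the pk list into
# GET_ITERATOR_CHUNK_SIZE windows keeps each statement's parameter count
# bounded (100 is the value of django's constant).
GET_ITERATOR_CHUNK_SIZE = 100

def pk_chunks(pk_list: list):
    for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
        yield pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]

assert [len(c) for c in pk_chunks(list(range(250)))] == [100, 100, 50]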
""" from .queryset import AwaitableQuerySet + new_objs = [] if isinstance(objs, AwaitableQuerySet): if not await objs.aexists(): @@ -297,6 +317,7 @@ async def acollect( return from .queryset import AwaitableQuerySet + if isinstance(objs, QuerySet): model = objs.model if not isinstance(objs, AwaitableQuerySet): @@ -373,7 +394,10 @@ async def acollect( ) ) sub_objs = sub_objs.only(*tuple(referenced_fields)) - if getattr(on_delete, "lazy_sub_objs", False) or await sub_objs.aexists(): + if ( + getattr(on_delete, "lazy_sub_objs", False) + or await sub_objs.aexists() + ): try: r = on_delete(self, field, sub_objs, self.using) if inspect.isawaitable(r): diff --git a/utilmeta/core/orm/backends/django/exceptions.py b/utilmeta/core/orm/backends/django/exceptions.py index 2c2c79d..3881669 100644 --- a/utilmeta/core/orm/backends/django/exceptions.py +++ b/utilmeta/core/orm/backends/django/exceptions.py @@ -1,3 +1,11 @@ from django.core.exceptions import * -from django.db.utils import DataError, DatabaseError, NotSupportedError, \ - IntegrityError, InterfaceError, InternalError, OperationalError, ProgrammingError +from django.db.utils import ( + DataError, + DatabaseError, + NotSupportedError, + IntegrityError, + InterfaceError, + InternalError, + OperationalError, + ProgrammingError, +) diff --git a/utilmeta/core/orm/backends/django/expressions.py b/utilmeta/core/orm/backends/django/expressions.py index 5aead16..16a8813 100644 --- a/utilmeta/core/orm/backends/django/expressions.py +++ b/utilmeta/core/orm/backends/django/expressions.py @@ -1,34 +1,59 @@ from django.db.models.aggregates import * from django.db.models.functions import * -from django.db.models.expressions import F, Q, Subquery, OuterRef, ResolvedOuterRef, Value, \ - ValueRange, Case, When, Col, Window, Ref, RawSQL, Func, BaseExpression, \ - RowRange, OrderBy, Exists, WindowFrame, Star, Expression, Combinable, CombinedExpression +from django.db.models.expressions import ( + F, + Q, + Subquery, + OuterRef, + ResolvedOuterRef, + Value, + ValueRange, + Case, + When, + Col, + Window, + Ref, + RawSQL, + Func, + BaseExpression, + RowRange, + OrderBy, + Exists, + WindowFrame, + Star, + Expression, + Combinable, + CombinedExpression, +) from django import VERSION + if VERSION[0] < 3: # compat from django.db.models.expressions import BaseExpression from django.db.models.functions.mixins import NumericOutputFieldMixin - def _get_output_field(self, _original=BaseExpression._resolve_output_field): # noqa + def _get_output_field(self, _original=BaseExpression._resolve_output_field): # noqa try: return _original(self) # noqa except AttributeError: from django.db.models.fields import FloatField + return FloatField() - BaseExpression._resolve_output_field = _get_output_field # patch + BaseExpression._resolve_output_field = _get_output_field # patch NumericOutputFieldMixin._resolve_output_field = _get_output_field class SubqueryCount(Subquery): template = "(SELECT count(*) FROM (%(subquery)s) _count)" from django.db import models + output_field = models.PositiveIntegerField() def __init__(self, queryset: models.QuerySet, output_field=None, **extra): if not queryset.query.select: - queryset = queryset.values('pk') + queryset = queryset.values("pk") super().__init__(queryset, output_field=output_field, **extra) diff --git a/utilmeta/core/orm/backends/django/field.py b/utilmeta/core/orm/backends/django/field.py index 4116d8b..2c6d173 100644 --- a/utilmeta/core/orm/backends/django/field.py +++ b/utilmeta/core/orm/backends/django/field.py @@ -17,33 +17,41 @@ def 
one_to(field): - return isinstance(field, (models.OneToOneField, models.ForeignKey, models.OneToOneRel)) + return isinstance( + field, (models.OneToOneField, models.ForeignKey, models.OneToOneRel) + ) def many_to(field): if isinstance(field, models.OneToOneRel): # OneToOneRel is subclass of ManyToOneRel return False - return isinstance(field, (models.ManyToManyField, models.ManyToManyRel, models.ManyToOneRel)) + return isinstance( + field, (models.ManyToManyField, models.ManyToManyRel, models.ManyToOneRel) + ) def to_many(field): - return isinstance(field, (models.ManyToManyField, models.ManyToManyRel, models.ForeignKey)) + return isinstance( + field, (models.ManyToManyField, models.ManyToManyRel, models.ForeignKey) + ) def to_one(field): - return isinstance(field, (models.OneToOneField, models.OneToOneRel, models.ManyToOneRel)) + return isinstance( + field, (models.OneToOneField, models.OneToOneRel, models.ManyToOneRel) + ) class DjangoModelFieldAdaptor(ModelFieldAdaptor): field: Union[models.Field, ForeignObjectRel, exp.BaseExpression, exp.Combinable] - model: 'DjangoModelAdaptor' + model: "DjangoModelAdaptor" def __init__(self, field, addon: str = None, model=None, lookup_name: str = None): if isinstance(field, DeferredAttribute): field = field.field if not lookup_name: - lookup_name = getattr(field, 'field_name', getattr(field, 'name', None)) + lookup_name = getattr(field, "field_name", getattr(field, "name", None)) # if isinstance(field, str): # from .model import DjangoModelAdaptor @@ -51,30 +59,35 @@ def __init__(self, field, addon: str = None, model=None, lookup_name: str = None # field = model.get_field(field) if not self.qualify(field): - raise TypeError(f'Invalid field: {field}') + raise TypeError(f"Invalid field: {field}") super().__init__(field, addon, model, lookup_name) self.validate_addon() @property def multi_relations(self): - return self.lookup_name and '__' in self.lookup_name + return self.lookup_name and "__" in self.lookup_name def validate_addon(self): if not self.addon: return if not isinstance(self.addon, str): - raise TypeError(f'Invalid addon: {repr(self.addon)}, must be str') + raise TypeError(f"Invalid addon: {repr(self.addon)}, must be str") if self.is_concrete: _t = self.field.get_internal_type() addons = constant.ADDON_FIELD_LOOKUPS.get(_t, []) if self.addon not in addons: from django.db.models import JSONField + if not isinstance(self.field, JSONField): - warnings.warn(f'Invalid addon: {repr(self.addon)} for field: {self.field},' - f' only {addons} are supported') + warnings.warn( + f"Invalid addon: {repr(self.addon)} for field: {self.field}," + f" only {addons} are supported" + ) else: - raise TypeError(f'Not concrete field: {self.field} cannot have addon: {repr(self.addon)}') + raise TypeError( + f"Not concrete field: {self.field} cannot have addon: {repr(self.addon)}" + ) @property def title(self) -> Optional[str]: @@ -85,28 +98,30 @@ def title(self) -> Optional[str]: @property def description(self) -> Optional[str]: - return str(self.field.help_text or '') or None + return str(self.field.help_text or "") or None @classmethod def qualify(cls, obj): - return isinstance(obj, (models.Field, ForeignObjectRel, exp.BaseExpression, exp.Combinable)) + return isinstance( + obj, (models.Field, ForeignObjectRel, exp.BaseExpression, exp.Combinable) + ) @property def field_model(self): if self.is_exp: return None - return getattr(self.field, 'model', None) + return getattr(self.field, "model", None) @property - def target_field(self) -> 
Optional['ModelFieldAdaptor']: - target_field = getattr(self.field, 'target_field', None) + def target_field(self) -> Optional["ModelFieldAdaptor"]: + target_field = getattr(self.field, "target_field", None) if target_field: return self.__class__(target_field, model=self.model) return None @property - def remote_field(self) -> Optional['ModelFieldAdaptor']: - remote_field = getattr(self.field, 'remote_field', None) + def remote_field(self) -> Optional["ModelFieldAdaptor"]: + remote_field = getattr(self.field, "remote_field", None) if remote_field and self.field.related_model: return self.__class__(remote_field, model=self.field.related_model) return None @@ -115,11 +130,12 @@ def remote_field(self) -> Optional['ModelFieldAdaptor']: def related_model(self): if self.is_exp: return None - rel = getattr(self.field, 'related_model') + rel = getattr(self.field, "related_model") if rel: - if rel == 'self': + if rel == "self": return self from .model import DjangoModelAdaptor + return DjangoModelAdaptor(rel) return None @@ -133,11 +149,14 @@ def through_model(self): rel = self.field if rel.through: from .model import DjangoModelAdaptor + return DjangoModelAdaptor(rel.through) return None @property - def through_fields(self) -> Tuple[Optional['ModelFieldAdaptor'], Optional['ModelFieldAdaptor']]: + def through_fields( + self, + ) -> Tuple[Optional["ModelFieldAdaptor"], Optional["ModelFieldAdaptor"]]: if not self.is_m2m: return None, None is_rel = False @@ -168,7 +187,7 @@ def through_fields(self) -> Tuple[Optional['ModelFieldAdaptor'], Optional['Model def is_nullable(self): if not self.is_concrete: return True - return getattr(self.field, 'null', False) + return getattr(self.field, "null", False) @property def is_optional(self): @@ -183,8 +202,8 @@ def is_writable(self): if self.field == self.model.meta.auto_field: return False param = self.params - auto_now_add = param.get('auto_now_add') - auto_created = param.get('auto_created') + auto_now_add = param.get("auto_now_add") + auto_created = param.get("auto_created") if auto_now_add or auto_created: return False return True @@ -206,9 +225,9 @@ def is_auto(self): if not self.is_concrete: return False param = self.params - auto_now_add = param.get('auto_now_add') - auto_now = param.get('auto_now') - auto_created = param.get('auto_created') + auto_now_add = param.get("auto_now_add") + auto_now = param.get("auto_now") + auto_created = param.get("auto_created") if auto_now_add or auto_now or auto_created: return True return self.field == self.model.meta.auto_field @@ -218,7 +237,7 @@ def is_auto_now(self): if not self.is_concrete: return False param = self.params - auto_now = param.get('auto_now') + auto_now = param.get("auto_now") return auto_now @classmethod @@ -261,6 +280,7 @@ def rule(self) -> Type[Rule]: if _type and self.is_nullable: from utype.parser.rule import LogicalType + _type = LogicalType.any_of(_type, type(None)) # return _type @@ -272,6 +292,7 @@ def rule(self) -> Type[Rule]: if _type != Any: if _type and self.is_nullable: from utype.parser.rule import LogicalType + _type = LogicalType.any_of(_type, type(None)) elif self.is_exp: @@ -302,45 +323,45 @@ def rule(self) -> Type[Rule]: params = self._get_params(field) kwargs = {} - if params.get('max_length'): - kwargs['max_length'] = params['max_length'] - if params.get('min_length'): - kwargs['min_length'] = params['min_length'] - if 'max_value' in params: - kwargs['le'] = params['max_value'] - if 'min_value' in params: - kwargs['ge'] = params['min_value'] + if params.get("max_length"): + 
kwargs["max_length"] = params["max_length"] + if params.get("min_length"): + kwargs["min_length"] = params["min_length"] + if "max_value" in params: + kwargs["le"] = params["max_value"] + if "min_value" in params: + kwargs["ge"] = params["min_value"] if isinstance(field, models.DecimalField): - kwargs['max_length'] = field.max_digits - kwargs['decimal_places'] = Lax(field.decimal_places) + kwargs["max_length"] = field.max_digits + kwargs["decimal_places"] = Lax(field.decimal_places) # for the reason that IntegerField is the base class of All integer fields # so the isinstance determine will be the last to include elif isinstance(field, models.IntegerField): if isinstance(field, models.PositiveSmallIntegerField): - kwargs['ge'] = 0 - kwargs['le'] = constant.SM + kwargs["ge"] = 0 + kwargs["le"] = constant.SM elif isinstance(field, models.AutoField): - kwargs['ge'] = 1 - kwargs['le'] = constant.MD + kwargs["ge"] = 1 + kwargs["le"] = constant.MD elif isinstance(field, models.BigAutoField): - kwargs['ge'] = 1 - kwargs['le'] = constant.LG + kwargs["ge"] = 1 + kwargs["le"] = constant.LG elif isinstance(field, models.BigIntegerField): - kwargs['ge'] = -constant.LG - kwargs['le'] = constant.LG + kwargs["ge"] = -constant.LG + kwargs["le"] = constant.LG elif isinstance(field, models.PositiveBigIntegerField): - kwargs['ge'] = 0 - kwargs['le'] = constant.LG + kwargs["ge"] = 0 + kwargs["le"] = constant.LG elif isinstance(field, models.PositiveIntegerField): - kwargs['ge'] = 0 - kwargs['le'] = constant.MD + kwargs["ge"] = 0 + kwargs["le"] = constant.MD elif isinstance(field, models.SmallIntegerField): - kwargs['ge'] = -constant.SM - kwargs['le'] = constant.SM + kwargs["ge"] = -constant.SM + kwargs["le"] = constant.SM else: - kwargs['ge'] = -constant.MD - kwargs['le'] = constant.MD + kwargs["ge"] = -constant.MD + kwargs["le"] = constant.MD if _type is None: # fallback to string field @@ -352,9 +373,9 @@ def rule(self) -> Type[Rule]: def name(self) -> Optional[str]: if self.is_exp: return None - if hasattr(self.field, 'name'): + if hasattr(self.field, "name"): return self.field.name - if hasattr(self.field, 'field_name'): + if hasattr(self.field, "field_name"): # toOneRel return self.field.field_name return None @@ -376,17 +397,19 @@ def check_query(self): if not qn: return try: - if '__' not in qn: - self.model.get_queryset(**{qn + '__isnull': False}) + if "__" not in qn: + self.model.get_queryset(**{qn + "__isnull": False}) else: try: self.model.get_queryset(**{qn: None}) except ValueError: - self.model.get_queryset(**{qn: ''}) + self.model.get_queryset(**{qn: ""}) except exceptions.FieldError as e: - raise exceptions.FieldError(f'Invalid query name: {repr(qn)} for {self.model.model}: {e}') + raise exceptions.FieldError( + f"Invalid query name: {repr(qn)} for {self.model.model}: {e}" + ) except ValueError as e: - print(f'failed to check query field: {repr(qn)} for {self.model.model}', e) + print(f"failed to check query field: {repr(qn)} for {self.model.model}", e) pass @property @@ -407,11 +430,14 @@ def to_field(self) -> Optional[str]: @property def relate_name(self) -> Optional[str]: if self.is_fk: - related_name = getattr(self.field, '_related_name', None) + related_name = getattr(self.field, "_related_name", None) if related_name: return related_name try: - return self.field.remote_field.name or self.field.remote_field.get_cache_name() + return ( + self.field.remote_field.name + or self.field.remote_field.get_cache_name() + ) except (AttributeError, NotImplementedError): return None return None @@ 
-446,7 +472,7 @@ def is_concrete(self): if self.is_m2m: # somehow ManyToManyField is considered "concrete" in django return False - return getattr(self.field, 'concrete', False) + return getattr(self.field, "concrete", False) @property def is_m2m(self): diff --git a/utilmeta/core/orm/backends/django/generator.py b/utilmeta/core/orm/backends/django/generator.py index bd58a25..d9da62d 100644 --- a/utilmeta/core/orm/backends/django/generator.py +++ b/utilmeta/core/orm/backends/django/generator.py @@ -24,17 +24,20 @@ def _get_unsliced_qs(self, base=None): else: if isinstance(base, models.QuerySet): if not issubclass(base.model, self.model.model): - raise TypeError(f'Invalid queryset: {base}') + raise TypeError(f"Invalid queryset: {base}") qs = base else: - raise TypeError(f'Invalid queryset: {base}') + raise TypeError(f"Invalid queryset: {base}") if self.annotates: qs = qs.annotate(**self.annotates) if self.q: qs = qs.filter(self.q) - if self.distinct and not qs.query.distinct and \ - not qs.query.combinator and \ - not qs.query.is_sliced: + if ( + self.distinct + and not qs.query.distinct + and not qs.query.combinator + and not qs.query.is_sliced + ): qs = qs.distinct() return qs @@ -65,12 +68,12 @@ def process_filter(self, field: ParserFilter, value): try: q = field.query(value) except Exception as e: - prepend = f'{self.__class__}: apply filter: [{repr(field.name)}].order failed with error: ' + prepend = f"{self.__class__}: apply filter: [{repr(field.name)}].order failed with error: " if not field.fail_silently: raise Error(e).throw(prepend=prepend) - warnings.warn(f'{prepend}{e}') + warnings.warn(f"{prepend}{e}") if not isinstance(q, exp.Q): - raise TypeError(f'Invalid query expression: {q}') + raise TypeError(f"Invalid query expression: {q}") else: q = Q(**{field.query_name: value}) @@ -82,15 +85,17 @@ def process_filter(self, field: ParserFilter, value): try: order = field.order(value) except Exception as e: - prepend = f'{self.__class__}: apply filter: [{repr(field.name)}].order failed with error: ' + prepend = f"{self.__class__}: apply filter: [{repr(field.name)}].order failed with error: " if not field.fail_silently: raise Error(e).throw(prepend=prepend) - warnings.warn(f'{prepend}{e}') + warnings.warn(f"{prepend}{e}") if not multi(order): order = [order] self.orders.extend(order) - def process_order(self, order: Order, field: ModelFieldAdaptor, name: str, flag: int = 1): + def process_order( + self, order: Order, field: ModelFieldAdaptor, name: str, flag: int = 1 + ): if field.is_exp: self._add_annotate(name, field.field) name = field.query_name or name @@ -107,15 +112,21 @@ def process_order(self, order: Order, field: ModelFieldAdaptor, name: str, flag: f = f(nulls_last=True) order_field = f else: - order_field = ('-' if desc else '') + name + order_field = ("-" if desc else "") + name self.orders.append(order_field) - def _add_annotate(self, key, expression: exp.BaseExpression, distinct_count: bool = True): + def _add_annotate( + self, key, expression: exp.BaseExpression, distinct_count: bool = True + ): if not isinstance(expression, (exp.BaseExpression, exp.Combinable)): - raise TypeError(f'Invalid expression: {expression}') + raise TypeError(f"Invalid expression: {expression}") if distinct_count and isinstance(expression, exp.Count): expression.distinct = True if isinstance(expression, exp.Sum): - expression = exp.Subquery(models.QuerySet(model=self.model.model).filter( - pk=exp.OuterRef('pk')).annotate(v=expression).values('v')) + expression = exp.Subquery( + 
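
# NOTE(review): process_order above emits either a plain "-name" string or an
# expression-based ordering when null placement is requested; the expression
# form is the part that needs the OrderBy compatibility on django 3.0 named
# in the commit subject. The plain-Django equivalent (the field name is
# illustrative, qs is any queryset over a model with a nullable column):

from django.db.models import F

def latest_first(qs):
    # "-published_at" cannot express null placement; the OrderBy expression can
    return qs.order_by(F("published_at").desc(nulls_last=True))
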
models.QuerySet(model=self.model.model) + .filter(pk=exp.OuterRef("pk")) + .annotate(v=expression) + .values("v") + ) self.annotates.setdefault(key, expression) diff --git a/utilmeta/core/orm/backends/django/model.py b/utilmeta/core/orm/backends/django/model.py index f7c109a..ee6bd95 100644 --- a/utilmeta/core/orm/backends/django/model.py +++ b/utilmeta/core/orm/backends/django/model.py @@ -1,6 +1,7 @@ from utilmeta.utils import SEG, awaitable from ..base import ModelFieldAdaptor, ModelAdaptor from typing import Tuple, Optional, List, Callable, Type + # from .queryset import AwaitableQuerySet from django.db import models from django.db.models.base import ModelBase @@ -30,9 +31,9 @@ class DjangoModelAdaptor(ModelAdaptor): def ident(self): meta = self.meta if not meta: - return '' + return "" app_label = meta.app_label - tag = '.'.join((app_label, self.model.__name__)) + tag = ".".join((app_label, self.model.__name__)) return tag.lower() @property @@ -67,6 +68,7 @@ def save_raw(self, pk=None, **data): async def asave_raw(self, pk=None, **data): inst = self.init_instance(pk, **data) from .queryset import AwaitableQuerySet + return await AwaitableQuerySet(model=self.model)._insert_obj(inst, raw=True) def create(self, d=None, **data) -> model_cls: @@ -124,7 +126,7 @@ async def avalues(self, q=None, *fields, **filters) -> List[dict]: def get_queryset(self, q=None, **filters) -> queryset_cls: # for django it's like model.objects.all() if isinstance(q, list): - q = models.Q(pk__in=[getattr(obj, 'pk', obj) for obj in q]) + q = models.Q(pk__in=[getattr(obj, "pk", obj) for obj in q]) args = (q,) if q else () try: @@ -164,10 +166,10 @@ async def aget_instance(self, q=None, **filters) -> model_cls: def init_instance(self, pk=None, **data): if pk: - data.setdefault('pk', pk) + data.setdefault("pk", pk) obj = self.model(**data) - if getattr(obj, 'id', None) is None: - setattr(obj, 'id', obj.pk or pk) + if getattr(obj, "id", None) is None: + setattr(obj, "id", obj.pk or pk) return obj def check_subquery(self, qs): @@ -175,14 +177,14 @@ def check_subquery(self, qs): return False if len(qs.query.select) > 1: # django.core.exceptions.FieldError: Cannot resolve expression type, unknown output_field - raise ValueError(f'Multiple fields selected in related queryset: {qs}') + raise ValueError(f"Multiple fields selected in related queryset: {qs}") if qs.query.is_sliced: hi = qs.query.high_mark lo = qs.query.low_mark if hi is not None and lo is not None: if hi - lo == 1: return True - raise ValueError('subquery result must be limited to 1 result') + raise ValueError("subquery result must be limited to 1 result") def check_queryset(self, qs, check_model: bool = False): if not isinstance(qs, self.queryset_cls): @@ -195,12 +197,12 @@ def check_queryset(self, qs, check_model: bool = False): def get_model(self, qs: models.QuerySet): if not isinstance(qs, self.queryset_cls): - raise TypeError(f'Invalid queryset: {qs}') + raise TypeError(f"Invalid queryset: {qs}") return self.__class__(qs.model) @property def meta(self) -> Options: - return getattr(self.model, '_meta') + return getattr(self.model, "_meta") @property def abstract(self): @@ -215,7 +217,7 @@ def table_name(self): @property def default_db_alias(self) -> str: - return self.get_queryset().db or 'default' + return self.get_queryset().db or "default" def get_parents(self): return self.meta.parents @@ -223,30 +225,34 @@ def get_parents(self): def cross_models(self, field: str): if not isinstance(field, str): return False - return '.' 
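
# NOTE(review): the Sum -> Subquery rewrite above avoids the classic inflated
# aggregate when two multi-valued joins land in one queryset. The same trick
# in plain Django ("books__sales" is an assumed relation, for illustration):

from django.db.models import OuterRef, Subquery, Sum

def annotate_total_sales(author_qs):
    return author_qs.annotate(
        total=Subquery(
            author_qs.model.objects.filter(pk=OuterRef("pk"))
            .annotate(v=Sum("books__sales"))  # correlated: one row per outer pk
            .values("v")
        )
    )
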
in field or '__' in field - - def get_field(self, name: str, validator: Callable = None, - silently: bool = False, - allow_addon: bool = False) -> Optional[field_adaptor_cls]: + return "." in field or "__" in field + + def get_field( + self, + name: str, + validator: Callable = None, + silently: bool = False, + allow_addon: bool = False, + ) -> Optional[field_adaptor_cls]: """ Get name from a field references """ if not name: - raise ValueError(f'{self.model}: empty field') + raise ValueError(f"{self.model}: empty field") if not isinstance(name, str): # field ref / expression return self.field_adaptor_cls(name, model=self) - if name == 'pk': + if name == "pk": return self.field_adaptor_cls(self.meta.pk, model=self, lookup_name=name) model = self.model - lookups = name.replace('.', SEG).split(SEG) + lookups = name.replace(".", SEG).split(SEG) f = None addon = None for i, lk in enumerate(lookups): try: if not model: raise exc.FieldDoesNotExist - meta: Options = getattr(model, '_meta') + meta: Options = getattr(model, "_meta") f = meta.get_field(lk) if callable(validator): validator(f) @@ -257,20 +263,24 @@ def get_field(self, name: str, validator: Callable = None, break if silently: return None - raise exc.FieldDoesNotExist(f'Field: {repr(name)} lookup {repr(lk)}' - f' of model {model} not exists: {e}') - return self.field_adaptor_cls(f, addon=addon, model=self, lookup_name=SEG.join(lookups)) + raise exc.FieldDoesNotExist( + f"Field: {repr(name)} lookup {repr(lk)}" + f" of model {model} not exists: {e}" + ) + return self.field_adaptor_cls( + f, addon=addon, model=self, lookup_name=SEG.join(lookups) + ) def get_backward(self, field: str) -> str: raise NotImplementedError def get_reverse_lookup(self, lookup: str) -> Tuple[str, Optional[str]]: reverse_fields = [] - lookups = lookup.replace('.', SEG).split(SEG) + lookups = lookup.replace(".", SEG).split(SEG) _model = self # relate1__relate2__common1__common2 common_index = None - common_field = '' + common_field = "" for i, name in enumerate(lookups): field = _model.get_field(name) if field.remote_field: @@ -318,7 +328,9 @@ def get_fields(self, many=False, no_inherit=False) -> List[ModelFieldAdaptor]: def get_related_adaptor(self, field): return self.__class__(field.related_model) if field.related_model else None - def gen_lookup_keys(self, field: str, keys, strict: bool = True, excludes: List[str] = None) -> list: + def gen_lookup_keys( + self, field: str, keys, strict: bool = True, excludes: List[str] = None + ) -> list: raise NotImplementedError def gen_lookup_filter(self, field, q, excludes: List[str] = None): @@ -328,10 +340,12 @@ def include_many_relates(self, field: str): if not field: return False if isinstance(field, (exp.BaseExpression, exp.Combinable)): - return self.include_many_relates(self.field_adaptor_cls.get_exp_field(field)) + return self.include_many_relates( + self.field_adaptor_cls.get_exp_field(field) + ) if not isinstance(field, str): return False - lookups = field.replace('.', SEG).split(SEG) + lookups = field.replace(".", SEG).split(SEG) mod = self.model for lkp in lookups: try: @@ -357,16 +371,23 @@ def resolve_output_field(self, expr): return r_field if not r_field: return l_field - if operator in ('+', '*', '/', '^'): - if isinstance(l_field, PositiveIntegerRelDbTypeMixin) \ - and isinstance(r_field, PositiveIntegerRelDbTypeMixin): + if operator in ("+", "*", "/", "^"): + if isinstance( + l_field, PositiveIntegerRelDbTypeMixin + ) and isinstance(r_field, PositiveIntegerRelDbTypeMixin): return l_field - if operator in 
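
# NOTE(review): get_field above walks "__"-separated lookups through
# _meta.get_field, hopping to related_model at each step. The walk in
# isolation (no addon/validator handling; raises FieldDoesNotExist on a bad
# part just like the adaptor):

def resolve_lookup(model, name: str):
    field = None
    for part in name.replace(".", "__").split("__"):
        field = model._meta.get_field(part)
        if field.related_model:
            model = field.related_model
    return field

# e.g. resolve_lookup(Comment, "article__author__name") ends on Author.name
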
('+', '-', '*', '/', '^'): - if isinstance(l_field, models.FloatField) or isinstance(r_field, models.FloatField): + if operator in ("+", "-", "*", "/", "^"): + if isinstance(l_field, models.FloatField) or isinstance( + r_field, models.FloatField + ): return models.FloatField() - if isinstance(l_field, models.DecimalField) or isinstance(r_field, models.DecimalField): + if isinstance(l_field, models.DecimalField) or isinstance( + r_field, models.DecimalField + ): return models.DecimalField() - if isinstance(l_field, models.IntegerField) and isinstance(r_field, models.IntegerField): + if isinstance(l_field, models.IntegerField) and isinstance( + r_field, models.IntegerField + ): return models.IntegerField() return l_field or r_field elif isinstance(expr, exp.BaseExpression): @@ -412,13 +433,13 @@ def check_query(self, q): try: self.get_queryset(q) except exc.FieldError as e: - raise exc.FieldError(f'Invalid query {q}: {e}') + raise exc.FieldError(f"Invalid query {q}: {e}") def check_order(self, f): try: self.get_queryset().order_by(f) except exc.FieldError as e: - raise exc.FieldError(f'Invalid order field {repr(f)}: {e}') + raise exc.FieldError(f"Invalid order field {repr(f)}: {e}") def is_sub_model(self, model): if isinstance(model, DjangoModelAdaptor): diff --git a/utilmeta/core/orm/backends/django/models.py b/utilmeta/core/orm/backends/django/models.py index ebca968..831daa5 100644 --- a/utilmeta/core/orm/backends/django/models.py +++ b/utilmeta/core/orm/backends/django/models.py @@ -9,9 +9,9 @@ class PasswordField(CharField): def guess_encoded(cls, pwd: str): if len(pwd) < 60: return False - if pwd.count('$') < 3: + if pwd.count("$") < 3: return False - if not pwd.endswith('='): + if not pwd.endswith("="): return False return True @@ -20,16 +20,27 @@ def get_prep_value(self, value): return None from django.contrib.auth.hashers import make_password from utilmeta.utils.functional import gen_key + if self.guess_encoded(value): # already encoded, maybe error update using save() but did not specify update_fields return value return make_password(value, gen_key(self.salt_length)) - def __init__(self, max_length: int, min_length: int = 1, salt_length=32, - regex: str = None, *args, **kwargs): - kwargs['max_length'] = 80 + salt_length - assert isinstance(max_length, int) and isinstance(min_length, int) and max_length >= min_length > 0, \ - f'Password field length config must satisfy max_length >= min_length > 0' + def __init__( + self, + max_length: int, + min_length: int = 1, + salt_length=32, + regex: str = None, + *args, + **kwargs, + ): + kwargs["max_length"] = 80 + salt_length + assert ( + isinstance(max_length, int) + and isinstance(min_length, int) + and max_length >= min_length > 0 + ), f"Password field length config must satisfy max_length >= min_length > 0" self.salt_length = salt_length self.regex = regex self._max_length = max_length @@ -61,14 +72,18 @@ def get_prep_value(self, value): if v in self.values: return v if self.none_default: - raise ValueError(f"ChoiceField contains value: {value} out of choices scope {self.values}, " - f"if you don't wan't exception here, set a default value") + raise ValueError( + f"ChoiceField contains value: {value} out of choices scope {self.values}, " + f"if you don't wan't exception here, set a default value" + ) val = self.reverse_choices_map.get(v) if val: return val if self.none_default: - raise ValueError(f"ChoiceField contains value: {value} out of choices scope {self.keys + self.values}, " - f"if you don't wan't exception here, set a default 
value") + raise ValueError( + f"ChoiceField contains value: {value} out of choices scope {self.keys + self.values}, " + f"if you don't wan't exception here, set a default value" + ) return self.default def from_db_value(self, value, expression=None, connection=None): @@ -85,7 +100,7 @@ def to_python(self, value): return self.reverse_choices_map.get(value, self.default) if value in self.values: return value - return self.choices_map.get(value, '') + return self.choices_map.get(value, "") def get_value(self, value): if value is None: @@ -94,9 +109,16 @@ def get_value(self, value): return value return self.choices_map.get(value, value) - def __init__(self, choices: Union[Type[Static], Type[Enum], dict, tuple, List[str]], - retrieve_key: bool = False, store_key: bool = True, max_length: int = None, - default: Optional[str] = NOT_PROVIDED, *args, **kwargs): + def __init__( + self, + choices: Union[Type[Static], Type[Enum], dict, tuple, List[str]], + retrieve_key: bool = False, + store_key: bool = True, + max_length: int = None, + default: Optional[str] = NOT_PROVIDED, + *args, + **kwargs, + ): from utilmeta.utils.functional import repeat, multi import inspect import collections @@ -107,13 +129,17 @@ def __init__(self, choices: Union[Type[Static], Type[Enum], dict, tuple, List[st keys = [] values = [] _choices = [] - if inspect.isclass(choices) and issubclass(choices, Static) or isinstance(choices, Static): + if ( + inspect.isclass(choices) + and issubclass(choices, Static) + or isinstance(choices, Static) + ): choices = choices.dict(reverse=True) if inspect.isclass(choices) and issubclass(choices, Enum): choices = dict(choices.__members__) if not choices: - raise ValueError(f'ChoiceField must specify choices, got {choices}') + raise ValueError(f"ChoiceField must specify choices, got {choices}") if isinstance(choices, collections.Iterator): choices = list(choices) @@ -133,8 +159,9 @@ def __init__(self, choices: Union[Type[Static], Type[Enum], dict, tuple, List[st if isinstance(choices, list): for i, t in enumerate(choices): if isinstance(t, tuple): - assert len(t) == 2, \ - ValueError('Choice field for list choices must be a 2-item tuple') + assert len(t) == 2, ValueError( + "Choice field for list choices must be a 2-item tuple" + ) _k = t[0] if isinstance(t, tuple) else str(i) _v = t[1] if isinstance(t, tuple) else str(t) @@ -162,12 +189,16 @@ def __init__(self, choices: Union[Type[Static], Type[Enum], dict, tuple, List[st _choices.append(items) if not _choices: - raise ValueError(f'ChoiceField choices must be list/tuple/dict got {choices}') + raise ValueError( + f"ChoiceField choices must be list/tuple/dict got {choices}" + ) if not store_key: retrieve_key = False elif repeat(keys + values): - raise ValueError(f"ChoiceField choices's keys {keys} and values {values} should't repeat") + raise ValueError( + f"ChoiceField choices's keys {keys} and values {values} should't repeat" + ) self.keys = tuple(keys) self.values = tuple(values) @@ -183,22 +214,26 @@ def __init__(self, choices: Union[Type[Static], Type[Enum], dict, tuple, List[st self.default = default if default is None: - kwargs['null'] = True + kwargs["null"] = True elif not self.none_default: if self.store_key: if str(default) not in self.keys: try: self.default = self.reverse_choices_map[str(default)] except KeyError: - raise ValueError(f'ChoiceField default value: {default} ' - f'out of scope: {self.keys} and {self.values}') + raise ValueError( + f"ChoiceField default value: {default} " + f"out of scope: {self.keys} and {self.values}" 
+ ) else: if str(default) not in self.values: try: self.default = self.choices_map[str(default)] except KeyError: - raise ValueError(f'ChoiceField default value: {default} ' - f'out of scope: {self.keys} and {self.values}') + raise ValueError( + f"ChoiceField default value: {default} " + f"out of scope: {self.keys} and {self.values}" + ) # self.default = str(default) \ # if str(default) in self.keys else self.reverse_choices_map[default] @@ -209,13 +244,15 @@ def __init__(self, choices: Union[Type[Static], Type[Enum], dict, tuple, List[st _max_length = max([len(c) for c in self.choices_map.values()]) if max_length and max_length < _max_length: - raise ValueError(f'ChoiceField max_length: {max_length} ' - f'if less than the longest choice length: {_max_length}') - - kwargs['max_length'] = max_length or _max_length - kwargs['default'] = self.default - kwargs['choices'] = tuple(_choices) - self._choices = _choices # be list type + raise ValueError( + f"ChoiceField max_length: {max_length} " + f"if less than the longest choice length: {_max_length}" + ) + + kwargs["max_length"] = max_length or _max_length + kwargs["default"] = self.default + kwargs["choices"] = tuple(_choices) + self._choices = _choices # be list type super().__init__(*args, **kwargs) @property @@ -224,9 +261,9 @@ def none_default(self): def deconstruct(self): name, path, args, kwargs = super().deconstruct() - kwargs['retrieve_key'] = self.retrieve_key - kwargs['store_key'] = self.store_key - kwargs['choices'] = self._choices + kwargs["retrieve_key"] = self.retrieve_key + kwargs["store_key"] = self.store_key + kwargs["choices"] = self._choices return name, path, args, kwargs @@ -236,7 +273,9 @@ class AwaitableModel(models.Model): class Meta: abstract = True - async def asave(self, force_insert=False, force_update=False, using=None, update_fields=None): + async def asave( + self, force_insert=False, force_update=False, using=None, update_fields=None + ): qs = AwaitableQuerySet(model=self.__class__).using(using) if force_insert: await qs._insert_obj(self) @@ -248,7 +287,7 @@ async def asave(self, force_insert=False, force_update=False, using=None, update for field in fields: name = field if not isinstance(field, str): - if hasattr(field, 'attname'): + if hasattr(field, "attname"): name = field.attname if not isinstance(name, str): continue @@ -261,7 +300,9 @@ async def asave(self, force_insert=False, force_update=False, using=None, update async def adelete(self): if not self.pk: return - return await AwaitableQuerySet(model=self.__class__).filter(pk=self.pk).adelete() + return ( + await AwaitableQuerySet(model=self.__class__).filter(pk=self.pk).adelete() + ) async def ACASCADE(collector, field, sub_objs, using): @@ -273,6 +314,7 @@ async def ACASCADE(collector, field, sub_objs, using): fail_on_restricted=False, ) from django.db import connections + if field.null and not connections[using].features.can_defer_constraint_checks: collector.add_field_update(field, None, sub_objs) @@ -289,7 +331,7 @@ class AbstractSession(AwaitableModel): created_time = models.DateTimeField(auto_now_add=True) last_activity = models.DateTimeField(default=None, null=True) expiry_time = models.DateTimeField(default=None, null=True) - deleted_time = models.DateTimeField(default=None, null=True) # already expired + deleted_time = models.DateTimeField(default=None, null=True) # already expired class Meta: abstract = True diff --git a/utilmeta/core/orm/backends/django/query.py b/utilmeta/core/orm/backends/django/query.py index ed7e1ff..d7e2c94 100644 --- 
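
# NOTE(review): ACASCADE above is the async counterpart of django's CASCADE
# handler (same collector signature), and AbstractSession rows get awaitable
# saves. Assumed usage, not verified against utilmeta docs:

from django.db import models
from django.utils import timezone

class Session(AbstractSession):
    user = models.ForeignKey("auth.User", on_delete=ACASCADE, null=True)

async def touch(session: Session):
    session.last_activity = timezone.now()
    await session.asave(update_fields=["last_activity"])
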
a/utilmeta/core/orm/backends/django/query.py +++ b/utilmeta/core/orm/backends/django/query.py @@ -76,6 +76,7 @@ async def async_pre_sql_setup(self): self.query.clear_where() else: from django.db.models.sql.where import WhereNode + self.query.where = WhereNode() if self.query.related_updates or must_pre_select: @@ -84,6 +85,7 @@ async def async_pre_sql_setup(self): # selecting from the updating table (e.g. MySQL). idents = [] import collections + related_ids = collections.defaultdict(list) compiler = query.get_compiler(self.using) q, params = compiler.as_sql() @@ -133,11 +135,10 @@ async def async_execute_sql(self, case_update: bool = False): if q: await db.fetchone(q, params) from django.db import connections + for query in self.query.get_related_updates(): compiler = self.__class__( - query, - connection=connections[self.using], - using=self.using + query, connection=connections[self.using], using=self.using ) await compiler.async_execute_sql(case_update) @@ -148,7 +149,7 @@ def _parse_update_param(cls, param): elif isinstance(param, timedelta): return param.total_seconds() elif isinstance(param, (list, tuple, set)): - return '{%s}' % ','.join([str(p) for p in param]) + return "{%s}" % ",".join([str(p) for p in param]) return param @@ -158,11 +159,14 @@ class AwaitableQuery(sql.Query): # adapt backend if django.VERSION < (4, 2): if django.VERSION < (3, 2): + def exists(self, using, limit=True): q = self.clone() if not q.distinct: if q.group_by is True: - q.add_fields((f.attname for f in self.model._meta.concrete_fields), False) + q.add_fields( + (f.attname for f in self.model._meta.concrete_fields), False + ) # Disable GROUP BY aliases to avoid orphaning references to the # SELECT clause which is about to be cleared. q.set_group_by(allow_aliases=False) @@ -170,13 +174,15 @@ def exists(self, using, limit=True): q.clear_ordering(True) q.set_limits(high=1) compiler = q.get_compiler(using=using) - compiler.query.add_extra({'a': 1}, None, None, None, None, None) - compiler.query.set_extra_mask(['a']) + compiler.query.add_extra({"a": 1}, None, None, None, None, None) + compiler.query.set_extra_mask(["a"]) return compiler.query + else: + def exists(self, using, limit=True): q = super().exists(using, limit=limit) - q.add_annotation(Value("1"), "a") # use str instead of int + q.add_annotation(Value("1"), "a") # use str instead of int return q def get_aggregation_query(self, added_aggregate_names, using=None): @@ -209,6 +215,7 @@ def get_aggregation_query(self, added_aggregate_names, using=None): or self.combinator ): from django.db.models.sql.subqueries import AggregateQuery + inner_query = self.clone() inner_query.subquery = True if django.VERSION < (3, 2): @@ -236,7 +243,9 @@ def get_aggregation_query(self, added_aggregate_names, using=None): ) if inner_query.default_cols and has_existing_aggregate_annotations: inner_query.group_by = ( - self.model._meta.pk.get_col(inner_query.get_initial_alias()), + self.model._meta.pk.get_col( + inner_query.get_initial_alias() + ), ) inner_query.default_cols = False @@ -248,7 +257,9 @@ def get_aggregation_query(self, added_aggregate_names, using=None): for alias, expression in list(inner_query.annotation_select.items()): annotation_select_mask = inner_query.annotation_select_mask if expression.is_summary: - expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt) + expression, col_cnt = inner_query.rewrite_cols( + expression, col_cnt + ) outer_query.annotations[alias] = expression.relabeled_clone( relabels ) @@ -272,10 +283,7 @@ def 
get_aggregation_query(self, added_aggregate_names, using=None): try: outer_query.add_subquery(inner_query, using) except EmptyResultSet: - return { - alias: None - for alias in outer_query.annotation_select - } + return {alias: None for alias in outer_query.annotation_select} else: outer_query = self self.select = () @@ -287,10 +295,12 @@ def get_aggregation_query(self, added_aggregate_names, using=None): outer_query.select_for_update = False outer_query.select_related = False return outer_query + else: + def exists(self, limit=True): q = super().exists(limit=limit) - q.add_annotation(Value("1"), "a") # use str instead of int + q.add_annotation(Value("1"), "a") # use str instead of int return q def get_aggregation_query(self, aggregate_exprs, using=None): @@ -357,7 +367,9 @@ def get_aggregation_query(self, aggregate_exprs, using=None): # used. if inner_query.default_cols and has_existing_aggregation: inner_query.group_by = ( - self.model._meta.pk.get_col(inner_query.get_initial_alias()), + self.model._meta.pk.get_col( + inner_query.get_initial_alias() + ), ) inner_query.default_cols = False if not qualify: diff --git a/utilmeta/core/orm/backends/django/queryset.py b/utilmeta/core/orm/backends/django/queryset.py index 535cfe6..8337556 100644 --- a/utilmeta/core/orm/backends/django/queryset.py +++ b/utilmeta/core/orm/backends/django/queryset.py @@ -1,8 +1,12 @@ import inspect from django.db.models import QuerySet, Manager, Model, sql, AutoField from django.db.models.options import Options -from django.db.models.query import ValuesListIterable, NamedValuesListIterable, \ - FlatValuesListIterable, ModelIterable +from django.db.models.query import ( + ValuesListIterable, + NamedValuesListIterable, + FlatValuesListIterable, + ModelIterable, +) from django.core import exceptions from django.utils.functional import partition from utilmeta.utils import awaitable @@ -17,6 +21,7 @@ try: from django.db.models.utils import resolve_callables except ImportError: + def resolve_callables(mapping): for k, v in mapping.items(): yield k, v() if callable(v) else v @@ -43,7 +48,9 @@ class AwaitableQuerySet(QuerySet): collector_cls = AwaitableCollector def __init__(self, model, query=None, using=None, hints=None): - super().__init__(model, query=query or self.query_cls(model), using=using, hints=hints) + super().__init__( + model, query=query or self.query_cls(model), using=using, hints=hints + ) def __aiter__(self): async def generator(): @@ -61,6 +68,7 @@ async def generator(): # yield namedtuple(self.model.__name__, field_names=list(item))(*item.values()) else: yield item + return generator() def as_sql(self) -> Tuple[str, tuple]: @@ -98,11 +106,14 @@ def _convert_raw_values(self, values, query): res[name] = value result.append(res) - if issubclass(self._iterable_class, (ValuesListIterable, FlatValuesListIterable)): + if issubclass( + self._iterable_class, (ValuesListIterable, FlatValuesListIterable) + ): list_result = [] if self._iterable_class == NamedValuesListIterable: from collections import namedtuple - t = namedtuple('Row', names) + + t = namedtuple("Row", names) for item in result: list_result.append(t(**item)) else: @@ -131,12 +142,12 @@ def fill_model_instance(self, values: dict): if val is Ellipsis: continue obj_values[field.column] = val - pk = values.get('id', values.get('pk')) + pk = values.get("id", values.get("pk")) if pk is not None: - obj_values.setdefault('pk', pk) + obj_values.setdefault("pk", pk) obj = self.model(**obj_values) - if getattr(obj, 'id', None) is None: - setattr(obj, 'id', 
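
# NOTE(review): __aiter__ above makes the queryset consumable with
# `async for`, converting values()/values_list() rows on the fly. Assumed
# usage (model/field names are illustrative):

async def active_usernames(model) -> list:
    qs = AwaitableQuerySet(model=model).filter(is_active=True).values("username")
    return [row["username"] async for row in qs]
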
obj.pk) + if getattr(obj, "id", None) is None: + setattr(obj, "id", obj.pk) return obj def instance(self, *args, **kwargs) -> Optional[Model]: @@ -153,10 +164,12 @@ async def instance(self, *args, **kwargs) -> Optional[Model]: return self.fill_model_instance(values) async def afirst(self): - return await (self if self.ordered else self.order_by('pk'))[:1].instance() + return await (self if self.ordered else self.order_by("pk"))[:1].instance() async def alast(self): - return await (self.reverse() if self.ordered else self.order_by('pk'))[:1].instance() + return await (self.reverse() if self.ordered else self.order_by("pk"))[ + :1 + ].instance() def result(self, one: bool = False): result = list(self) @@ -196,7 +209,7 @@ async def result(self, one: bool = False): @property def meta(self) -> Options: - return getattr(self.model, '_meta') + return getattr(self.model, "_meta") async def acreate(self, **kwargs): obj: Model = self.model(**kwargs) @@ -211,13 +224,13 @@ async def acreate(self, **kwargs): async def _insert_obj_parents(self, obj: Model, cls=None): """Save all the parents of cls using values from self.""" cls = cls or obj.__class__ - if getattr(cls, '_meta').proxy: - cls = getattr(cls, '_meta').concrete_model + if getattr(cls, "_meta").proxy: + cls = getattr(cls, "_meta").concrete_model meta = cls._meta for parent, field in meta.parents.items(): # Make sure the link fields are synced between parent and self. - parent_meta: Options = getattr(parent, '_meta') + parent_meta: Options = getattr(parent, "_meta") if ( field and getattr(obj, parent_meta.pk.attname) is None @@ -243,8 +256,8 @@ async def _insert_obj_parents(self, obj: Model, cls=None): async def _insert_obj(self, obj: Model, cls=None, raw: bool = False): cls = cls or obj.__class__ - if getattr(cls, '_meta').proxy: - cls = getattr(cls, '_meta').concrete_model + if getattr(cls, "_meta").proxy: + cls = getattr(cls, "_meta").concrete_model meta: Options = cls._meta pk_val = getattr(obj, meta.pk.attname) @@ -263,10 +276,7 @@ async def _insert_obj(self, obj: Model, cls=None, raw: bool = False): fields = [f for f in fields if f is not meta.auto_field] results = await self._async_insert( - [obj], - fields=fields, cls=cls, - returning_fields=returning_fields, - raw=raw + [obj], fields=fields, cls=cls, returning_fields=returning_fields, raw=raw ) if results: obj_value = results[0] @@ -325,15 +335,15 @@ async def _async_insert( on_conflict=None, update_fields=None, unique_fields=None, - ignore_conflicts=False # compat django 3 + ignore_conflicts=False, # compat django 3 ): self._for_write = True if using is None: using = self.db cls = cls or self.model - if getattr(cls, '_meta').proxy: - cls = getattr(cls, '_meta').concrete_model + if getattr(cls, "_meta").proxy: + cls = getattr(cls, "_meta").concrete_model if django.VERSION > (4, 1): query = sql.InsertQuery( @@ -365,9 +375,7 @@ async def _async_insert( if can_return: rows = await db.fetchall(q, params) else: - rows = [ - {self.meta.pk.column: await db.execute(q, params)} - ] + rows = [{self.meta.pk.column: await db.execute(q, params)}] for val in rows: val: dict tuple_values = [] @@ -394,7 +402,7 @@ async def abulk_create( update_conflicts=False, update_fields=None, unique_fields=None, - no_transaction: bool = False + no_transaction: bool = False, ): """ Internal django implementation of bulk_create is too complicate to split into async code @@ -410,6 +418,7 @@ async def abulk_create( if has_parent: tasks = [] import asyncio + for obj in objs: 
tasks.append(asyncio.create_task(self._insert_obj(obj))) try: @@ -444,7 +453,9 @@ async def abulk_create( self._prepare_for_bulk_create(objs) db = self.connections_cls.get(self.db) - async with (DummyContent() if no_transaction else db.async_transaction(savepoint=False)): + async with ( + DummyContent() if no_transaction else db.async_transaction(savepoint=False) + ): objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) if objs_with_pk: returned_columns = await self._async_batched_insert( @@ -496,11 +507,11 @@ async def acount(self) -> int: # query.clear_select_fields() # query.clear_ordering() if django.VERSION < (4, 2): - query.add_annotation(Count('*'), alias='__count', is_summary=True) - r = await query.aget_aggregation(self.db, ['__count']) or {} + query.add_annotation(Count("*"), alias="__count", is_summary=True) + r = await query.aget_aggregation(self.db, ["__count"]) or {} else: r = await query.aget_aggregation(self.db, {"__count": Count("*")}) or {} - number = r.get('__count') or r.get('count') or r.get('COUNT(*)') + number = r.get("__count") or r.get("count") or r.get("COUNT(*)") if number is None and r: number = list(r.values())[0] # weird, don't know why now @@ -620,6 +631,7 @@ async def aget_or_create(self, defaults=None, **kwargs): @property def conn(self): from django.db import connections + return connections[self.db] async def aget(self, *args, **kwargs): @@ -676,7 +688,9 @@ async def aupdate_or_create(self, defaults=None, **kwargs): await self.__class__(self.model).filter(pk=obj.pk).aupdate(**params) return obj, False - async def abulk_update(self, objs, fields, batch_size=None, no_transaction: bool = False): + async def abulk_update( + self, objs, fields, batch_size=None, no_transaction: bool = False + ): if batch_size is not None and batch_size < 0: raise ValueError("Batch size must be a positive integer.") if not fields: @@ -702,7 +716,7 @@ async def abulk_update(self, objs, fields, batch_size=None, no_transaction: bool max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs) batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size requires_casting = connection.features.requires_casted_case_in_updates - batches = (objs[i: i + batch_size] for i in range(0, len(objs), batch_size)) + batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size)) updates = [] for batch_objs in batches: update_kwargs = {} @@ -721,7 +735,9 @@ async def abulk_update(self, objs, fields, batch_size=None, no_transaction: bool # rows_updated = 0 queryset = self.using(self.db) db = self.connections_cls.get(self.db) - async with (DummyContent() if no_transaction else db.async_transaction(savepoint=False)): + async with ( + DummyContent() if no_transaction else db.async_transaction(savepoint=False) + ): for pks, update_kwargs in updates: await queryset.filter(pk__in=pks).aupdate(**update_kwargs) # return rows_updated @@ -759,7 +775,7 @@ async def adelete(self): if collector.can_fast_delete(del_query): await collector.acollect(del_query) else: - pks = await self.values_list('pk', flat=True).result() + pks = await self.values_list("pk", flat=True).result() if not pks: return True, 0 await collector.acollect([self.model(pk=pk) for pk in pks]) @@ -770,4 +786,5 @@ async def adelete(self): return deleted, _rows_count -class AwaitableManager(Manager.from_queryset(AwaitableQuerySet)): pass +class AwaitableManager(Manager.from_queryset(AwaitableQuerySet)): + pass diff --git a/utilmeta/core/orm/backends/peewee/example.py 
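
# NOTE(review): for models with parents (multi-table inheritance) the bulk
# path above degrades to one _insert_obj task per object, since each row
# spans several tables. The fan-out/await pattern in isolation:

import asyncio

async def insert_all(insert, objs):
    tasks = [asyncio.create_task(insert(obj)) for obj in objs]
    try:
        await asyncio.gather(*tasks)
    except Exception:
        for task in tasks:
            task.cancel()  # don't leak pending inserts on first failure
        raise
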
b/utilmeta/core/orm/backends/peewee/example.py index 2be9705..6a3222d 100644 --- a/utilmeta/core/orm/backends/peewee/example.py +++ b/utilmeta/core/orm/backends/peewee/example.py @@ -1,7 +1,7 @@ from peewee import * from utilmeta.core import module -database = SqliteDatabase('my_db') +database = SqliteDatabase("my_db") # model definitions -- the standard "pattern" is to define a base model class # that specifies which database to use. then, any subclasses will automatically @@ -12,6 +12,7 @@ class BaseModel(Model): class Meta: database = database + # the user model specifies its fields (or columns) declaratively, like django @@ -25,10 +26,12 @@ class User(BaseModel): a = User.username.contains -query = module.Query({ - 'username_contains': User.username.contains(module.P), - 'data>': User.join_date > module.P, - 'data>=': User.join_date >= module.P, - 'data<': User.join_date < module.P, - 'email_isnull': User.email.is_null(module.P) -}) +query = module.Query( + { + "username_contains": User.username.contains(module.P), + "data>": User.join_date > module.P, + "data>=": User.join_date >= module.P, + "data<": User.join_date < module.P, + "email_isnull": User.email.is_null(module.P), + } +) diff --git a/utilmeta/core/orm/backends/peewee/peewee.py b/utilmeta/core/orm/backends/peewee/peewee.py index 413561e..9d951a9 100644 --- a/utilmeta/core/orm/backends/peewee/peewee.py +++ b/utilmeta/core/orm/backends/peewee/peewee.py @@ -1,6 +1,11 @@ from peewee import Model, ModelSelect, Expression, fn from typing import Type -from ..base import ModelAdaptor, QuerysetAdaptor, QueryExpressionAdaptor, ModelFieldAdaptor +from ..base import ( + ModelAdaptor, + QuerysetAdaptor, + QueryExpressionAdaptor, + ModelFieldAdaptor, +) class PeeweeModelFieldAdaptor(ModelFieldAdaptor): @@ -44,4 +49,3 @@ def qualify(cls, impl): def get_queryset(self): return self.model.select() - diff --git a/utilmeta/core/orm/compiler.py b/utilmeta/core/orm/compiler.py index a999e2a..e9be83c 100644 --- a/utilmeta/core/orm/compiler.py +++ b/utilmeta/core/orm/compiler.py @@ -11,7 +11,12 @@ class TransactionWrapper: - def __init__(self, model: 'ModelAdaptor', transaction: Union[str, bool] = False, errors_map: dict = None): + def __init__( + self, + model: "ModelAdaptor", + transaction: Union[str, bool] = False, + errors_map: dict = None, + ): # self.enabled = bool(transaction) db_alias = None if isinstance(transaction, str): @@ -24,6 +29,7 @@ def __init__(self, model: 'ModelAdaptor', transaction: Union[str, bool] = False, self.db_alias = db_alias from .plugins.atomic import AtomicPlugin + self.atomic = AtomicPlugin(db_alias) if db_alias else None self.errors_map = errors_map or {} @@ -82,12 +88,16 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): class BaseQueryCompiler: - def __init__(self, parser: SchemaClassParser, queryset, context: QueryContext = None): + def __init__( + self, parser: SchemaClassParser, queryset, context: QueryContext = None + ): self.parser = parser self.model = parser.model self.queryset = queryset self.context = context or QueryContext() - self.recursive_map: Dict[Any, Dict[Any, dict]] = self.context.recursion_map or {} + self.recursive_map: Dict[Any, Dict[Any, dict]] = ( + self.context.recursion_map or {} + ) self.pk_list = [] self.pk_map = {} # self.recursive_pk_list = [] @@ -104,9 +114,12 @@ def get_integrity_error(self, e: Exception) -> Exception: return self.context.integrity_error_cls(e) return e - def get_related_context(self, field: ParserQueryField, - force_expressions: dict = None, - 
force_raise_error: bool = False): + def get_related_context( + self, + field: ParserQueryField, + force_expressions: dict = None, + force_raise_error: bool = False, + ): includes = excludes = None if self.context.includes: inter = set(self.context.includes).intersection(field.all_aliases) @@ -117,7 +130,7 @@ def get_related_context(self, field: ParserQueryField, return QueryContext( using=self.context.using, # single=field.related_single, - single=False, # not make it single, related context is always about multiple + single=False, # not make it single, related context is always about multiple includes=includes, excludes=excludes, recursive_map=self.recursive_map, @@ -160,7 +173,9 @@ def process_fields(self): continue if not field.readable: continue - if not self.context.in_scope(field.all_aliases, dependants=field.dependants): + if not self.context.in_scope( + field.all_aliases, dependants=field.dependants + ): continue self.process_query_field(field) @@ -171,7 +186,9 @@ def get_values(self) -> List[dict]: async def get_values(self) -> List[dict]: raise NotImplementedError - def process_data(self, data: dict, with_relations: bool = None) -> Tuple[dict, dict, dict]: + def process_data( + self, data: dict, with_relations: bool = None + ) -> Tuple[dict, dict, dict]: if not isinstance(data, dict): return {}, {}, {} if not isinstance(data, self.parser.obj): @@ -226,28 +243,32 @@ def commit_data(self, data): async def commit_data(self, data): raise NotImplementedError - def save_data(self, - data, - must_create: bool = False, - must_update: bool = False, - ignore_bulk_errors: bool = False, - ignore_relation_errors: bool = False, - with_relations: bool = None, - transaction: bool = False, - ): + def save_data( + self, + data, + must_create: bool = False, + must_update: bool = False, + ignore_bulk_errors: bool = False, + ignore_relation_errors: bool = False, + with_relations: bool = None, + transaction: bool = False, + ): raise NotImplementedError @awaitable(save_data) - async def save_data(self, - data, - must_create: bool = False, - must_update: bool = False, - ignore_bulk_errors: bool = False, - ignore_relation_errors: bool = False, - with_relations: bool = None, - transaction: bool = False, - ): + async def save_data( + self, + data, + must_create: bool = False, + must_update: bool = False, + ignore_bulk_errors: bool = False, + ignore_relation_errors: bool = False, + with_relations: bool = None, + transaction: bool = False, + ): raise NotImplementedError - def get_integrity_errors(self, asynchronous: bool = False) -> Tuple[Type[Exception], ...]: + def get_integrity_errors( + self, asynchronous: bool = False + ) -> Tuple[Type[Exception], ...]: return () diff --git a/utilmeta/core/orm/context.py b/utilmeta/core/orm/context.py index c6b1ad4..36c2b18 100644 --- a/utilmeta/core/orm/context.py +++ b/utilmeta/core/orm/context.py @@ -21,7 +21,9 @@ def in_scope(self, aliases: List[str], dependants: List[str] = None): if not aliases: return False if self.includes: - return bool(set(aliases).union(dependants or []).intersection(self.includes)) + return bool( + set(aliases).union(dependants or []).intersection(self.includes) + ) if self.excludes: return not set(aliases).intersection(self.excludes) return True diff --git a/utilmeta/core/orm/databases/base.py b/utilmeta/core/orm/databases/base.py index 2f9722f..0f103b7 100644 --- a/utilmeta/core/orm/databases/base.py +++ b/utilmeta/core/orm/databases/base.py @@ -8,12 +8,12 @@ class BaseDatabaseAdaptor: asynchronous = False DEFAULT_ENGINES = {} - def 
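
# NOTE(review): TransactionWrapper above maps transaction=True / "alias" to
# an atomic block plus integrity-error translation. A plain-Django analogue
# of the same shape (a sketch only; utilmeta's AtomicPlugin differs, and this
# version matches exception types exactly, without subclass handling):

from django.db import transaction

class TxWrapper:
    def __init__(self, using=None, errors_map=None):
        self.atomic = transaction.atomic(using=using)
        self.errors_map = errors_map or {}

    def __enter__(self):
        self.atomic.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.atomic.__exit__(exc_type, exc_val, exc_tb)
        if exc_type is not None and exc_type in self.errors_map:
            raise self.errors_map[exc_type](exc_val) from exc_val
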
__init__(self, config: 'Database', alias: str = None): + def __init__(self, config: "Database", alias: str = None): self.config = config self.alias = alias def get_engine(self): - if '.' in self.config.engine: + if "." in self.config.engine: return self.config.engine if self.config.engine.lower() in self.DEFAULT_ENGINES: return self.DEFAULT_ENGINES[self.config.engine.lower()] diff --git a/utilmeta/core/orm/databases/config.py b/utilmeta/core/orm/databases/config.py index 29309be..c6d8c80 100644 --- a/utilmeta/core/orm/databases/config.py +++ b/utilmeta/core/orm/databases/config.py @@ -13,11 +13,9 @@ class Database(Config): This is just a declaration interface for database the real implementation is database adaptor """ - DEFAULT_HOST = '127.0.0.1' - DEFAULT_PORTS = { - 'postgres': 5432, - 'mysql': 3306 - } + + DEFAULT_HOST = "127.0.0.1" + DEFAULT_PORTS = {"postgres": 5432, "mysql": 3306} sync_adaptor_cls = None async_adaptor_cls = EncodeDatabasesAsyncAdaptor @@ -27,33 +25,35 @@ class Database(Config): # --- name: str - engine: str = 'sqlite' - user: str = '' - password: str = '' - host: str = '' + engine: str = "sqlite" + user: str = "" + password: str = "" + host: str = "" port: Optional[int] = None time_zone: Optional[str] = None ssl: Any = None max_size: Optional[int] = None min_size: Optional[int] = None max_age: Optional[int] = 0 - replica_of: Optional['Database'] = None + replica_of: Optional["Database"] = None options: Optional[dict] = None - def __init__(self, - name: str, - engine: str = 'sqlite', - user: str = '', - password: str = '', - host: str = '', - port: Optional[int] = None, - time_zone: Optional[str] = None, - ssl: Any = None, - max_size: Optional[int] = None, # connection pool - min_size: Optional[int] = None, # connection pool - max_age: Optional[int] = 0, # connection max age - replica_of: Optional['Database'] = None, - options: Optional[dict] = None): + def __init__( + self, + name: str, + engine: str = "sqlite", + user: str = "", + password: str = "", + host: str = "", + port: Optional[int] = None, + time_zone: Optional[str] = None, + ssl: Any = None, + max_size: Optional[int] = None, # connection pool + min_size: Optional[int] = None, # connection pool + max_age: Optional[int] = 0, # connection max age + replica_of: Optional["Database"] = None, + options: Optional[dict] = None, + ): super().__init__(locals()) self.host = self.host or self.DEFAULT_HOST if not self.port: @@ -68,7 +68,7 @@ def __init__(self, def params(self): options = dict(self.options or {}) if self.ssl: - options.update(ssl=self.ssl) # True or other ssl context + options.update(ssl=self.ssl) # True or other ssl context if self.max_size: options.update(max_size=self.max_size) if self.min_size: @@ -85,23 +85,23 @@ def local(self): def location(self): if self.is_sqlite: return self.name - return f'{self.host}:{self.port}' + return f"{self.host}:{self.port}" @property def is_sqlite(self): - return 'sqlite' in self.engine + return "sqlite" in self.engine @property def is_postgresql(self): - return 'postgres' in self.engine + return "postgres" in self.engine @property def is_mysql(self): - return 'mysql' in self.engine + return "mysql" in self.engine @property def is_oracle(self): - return 'oracle' in self.engine + return "oracle" in self.engine @property def alias(self): @@ -120,13 +120,13 @@ def database_name(self): @property def type(self): if self.is_sqlite: - return 'sqlite' + return "sqlite" elif self.is_postgresql: - return 'postgresql' + return "postgresql" elif self.is_mysql: - return 'mysql' + 
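
# NOTE(review): construction sketch for the Database config above, using only
# arguments visible in this hunk; the port fallback through DEFAULT_PORTS is
# inferred from the truncated __init__ context:

from utilmeta.core.orm.databases.config import Database

db = Database(
    name="blog",
    engine="postgres",  # DEFAULT_PORTS suggests 5432 when port is omitted
    user="app",
    password="s3cret",
    max_size=20,        # connection-pool options
    min_size=2,
)
print(db.location)  # expected "127.0.0.1:5432" (host falls back to DEFAULT_HOST)
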
return "mysql" elif self.is_oracle: - return 'oracle' + return "oracle" return self.engine @property @@ -139,30 +139,31 @@ def dsn(self): # https://stackoverflow.com/a/19262231/14026109 # Also, as Windows doesn't have the concept of root # and instead uses drives, you have to specify absolute path with 3 slashes - return '/' + self.name + return "/" + self.name else: user = self.user if self.password: from urllib.parse import quote + # for special chars like @ will disrupt DNS - user += f':{quote(self.password)}' + user += f":{quote(self.password)}" netloc = self.host if self.port: - netloc += f':{self.port}' - return f'{user}@{netloc}/{self.name}' + netloc += f":{self.port}" + return f"{user}@{netloc}/{self.name}" @property def protected_dsn(self): if self.is_sqlite: - return '/' + self.name + return "/" + self.name else: user = self.user if self.password: - user += f':******' + user += f":******" netloc = self.host if self.port: - netloc += f':{self.port}' - return f'{user}@{netloc}/{self.name}' + netloc += f":{self.port}" + return f"{user}@{netloc}/{self.name}" def apply(self, alias: str, asynchronous: bool = None, project_dir: str = None): if asynchronous: @@ -174,10 +175,11 @@ def apply(self, alias: str, asynchronous: bool = None, project_dir: str = None): else: # default from ..backends.django.database import DjangoDatabaseAdaptor + self.adaptor = DjangoDatabaseAdaptor(self, alias) if not self.adaptor: - raise exceptions.NotConfigured('Database adaptor not implemented') + raise exceptions.NotConfigured("Database adaptor not implemented") if self.is_sqlite and project_dir: if not os.path.isabs(self.name): self.name = str(os.path.join(project_dir, self.name)) @@ -226,11 +228,19 @@ def fetchall(self, sql, params=None) -> List[dict]: async def fetchall(self, sql, params=None) -> List[dict]: return await self.get_adaptor(True).fetchall(sql, params) - def transaction(self, savepoint=None, isolation=None, force_rollback: bool = False) -> ContextManager: - return self.get_adaptor(False).transaction(savepoint, isolation, force_rollback=force_rollback) + def transaction( + self, savepoint=None, isolation=None, force_rollback: bool = False + ) -> ContextManager: + return self.get_adaptor(False).transaction( + savepoint, isolation, force_rollback=force_rollback + ) - def async_transaction(self, savepoint=None, isolation=None, force_rollback: bool = False) -> AsyncContextManager: - return self.get_adaptor(True).transaction(savepoint, isolation, force_rollback=force_rollback) + def async_transaction( + self, savepoint=None, isolation=None, force_rollback: bool = False + ) -> AsyncContextManager: + return self.get_adaptor(True).transaction( + savepoint, isolation, force_rollback=force_rollback + ) class DatabaseConnections(Config): @@ -251,12 +261,14 @@ def add_database(self, service: UtilMeta, alias: str, database: Database): if not database.async_adaptor_cls: if service.adaptor and service.adaptor.async_db_adaptor_cls: database.async_adaptor_cls = service.adaptor.async_db_adaptor_cls - database.apply(alias, asynchronous=service.asynchronous, project_dir=service.project_dir) + database.apply( + alias, asynchronous=service.asynchronous, project_dir=service.project_dir + ) if alias not in self.databases: self.databases.setdefault(alias, database) @classmethod - def get(cls, alias: str = 'default') -> Database: + def get(cls, alias: str = "default") -> Database: config = cls.config() if not config: raise exceptions.NotConfigured(cls) diff --git a/utilmeta/core/orm/databases/encode.py 
b/utilmeta/core/orm/databases/encode.py index e5bdca7..485342a 100644 --- a/utilmeta/core/orm/databases/encode.py +++ b/utilmeta/core/orm/databases/encode.py @@ -58,70 +58,76 @@ async def __aexit__(self, exc_type, exc_value, traceback): class EncodeDatabasesAsyncAdaptor(BaseDatabaseAdaptor): asynchronous = True - POSTGRESQL = 'postgresql+asyncpg' - POSTGRESQL_AIOPG = 'postgresql+aiopg' - MYSQL = 'mysql+aiomysql' - MYSQL_ASYNCMY = 'mysql+asyncmy' - SQLITE = 'sqlite+aiosqlite' + POSTGRESQL = "postgresql+asyncpg" + POSTGRESQL_AIOPG = "postgresql+aiopg" + MYSQL = "mysql+aiomysql" + MYSQL_ASYNCMY = "mysql+asyncmy" + SQLITE = "sqlite+aiosqlite" DEFAULT_ENGINES = { - 'sqlite': SQLITE, - 'sqlite3': SQLITE, - 'mysql': MYSQL, - 'postgresql': POSTGRESQL, - 'postgres': POSTGRESQL + "sqlite": SQLITE, + "sqlite3": SQLITE, + "mysql": MYSQL, + "postgresql": POSTGRESQL, + "postgres": POSTGRESQL, } DEFAULT_ASYNC_ENGINES = { - 'sqlite': 'sqlite+aiosqlite', - 'mysql': 'mysql+aiomysql', - 'postgres': 'postgresql+asyncpg' + "sqlite": "sqlite+aiosqlite", + "mysql": "mysql+aiomysql", + "postgres": "postgresql+asyncpg", } - def __init__(self, config: 'Database', alias: str = None): + def __init__(self, config: "Database", alias: str = None): super().__init__(config, alias=alias) self.async_engine = None self.db_backend = None self.engine = None - if '+' in self.config.engine: - self.db_backend, self.async_engine = self.config.engine.split('+') + if "+" in self.config.engine: + self.db_backend, self.async_engine = self.config.engine.split("+") self.engine = self.config.engine else: for name, engine in self.DEFAULT_ASYNC_ENGINES.items(): if name in self.config.engine.lower(): self.engine = engine - self.db_backend, self.async_engine = self.engine.split('+') + self.db_backend, self.async_engine = self.engine.split("+") break if not self.engine: - raise ValueError(f'{self.__class__.__name__}: engine invalid or not implemented: ' - f'{repr(self.config.engine)}') + raise ValueError( + f"{self.__class__.__name__}: engine invalid or not implemented: " + f"{repr(self.config.engine)}" + ) - self._db = None # process local + self._db = None # process local self._processed = False # import threading # self.local = threading.local() # thread local # self._var_db = contextvars.ContextVar('db') # coroutine local def get_integrity_errors(self): - if self.db_backend in ('postgres', 'postgresql'): + if self.db_backend in ("postgres", "postgresql"): errors = [] try: from asyncpg.exceptions import IntegrityConstraintViolationError + errors.append(IntegrityConstraintViolationError) except (ImportError, ModuleNotFoundError): pass try: from psycopg2 import IntegrityError + errors.append(IntegrityError) except (ImportError, ModuleNotFoundError): pass return tuple(errors) - elif self.db_backend in ('sqlite', 'sqlite3'): + elif self.db_backend in ("sqlite", "sqlite3"): from sqlite3 import IntegrityError + return (IntegrityError,) - elif self.db_backend == 'mysql': + elif self.db_backend == "mysql": errors = [] try: from pymysql.err import IntegrityError + errors.append(IntegrityError) except (ImportError, ModuleNotFoundError): pass @@ -134,22 +140,23 @@ def get_db(self): # return getattr(self.local, 'db', None) # return self._var_db.get(None) from databases import Database + engine = self.engine if not engine: - raise ValueError(f'Invalid engine: {engine}') + raise ValueError(f"Invalid engine: {engine}") # sqlite:// # postgresql://[user[:password]@][netloc][:port][/dbname][?param1=value1&...] 
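
# NOTE(review): the adaptor above accepts either an explicit "backend+driver"
# engine or a bare name that it normalizes through DEFAULT_ASYNC_ENGINES.
# The resolution logic in isolation:

DEFAULT_ASYNC_ENGINES = {
    "sqlite": "sqlite+aiosqlite",
    "mysql": "mysql+aiomysql",
    "postgres": "postgresql+asyncpg",
}

def normalize_engine(engine: str) -> str:
    if "+" in engine:
        return engine  # already backend+driver
    for name, full in DEFAULT_ASYNC_ENGINES.items():
        if name in engine.lower():
            return full
    raise ValueError(f"engine invalid or not implemented: {engine!r}")

assert normalize_engine("PostgreSQL") == "postgresql+asyncpg"
assert normalize_engine("mysql+asyncmy") == "mysql+asyncmy"
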
params = dict(self.config.params) factory = self.connection_factory if factory: params.update(factory=factory) - database = Database(f'{engine}://{self.config.dsn}', **params) + database = Database(f"{engine}://{self.config.dsn}", **params) self._db = database return database @property def connection_factory(self): - if self.db_backend in ('sqlite', 'sqlite3'): + if self.db_backend in ("sqlite", "sqlite3"): import sqlite3 from aiosqlite import Connection @@ -157,7 +164,7 @@ class SQLiteConnection(sqlite3.Connection): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # if not self.in_transaction: - self.execute('PRAGMA foreign_keys = ON;') + self.execute("PRAGMA foreign_keys = ON;") # self.execute('PRAGMA legacy_alter_table = OFF;') # print('SQLITE EXECUTE') @@ -181,9 +188,11 @@ async def connect(self): try: await db.connect() except Exception as e: - raise e.__class__(f'Database: encode/databases connect to database: ' - f'{self.config.name}({self.config.alias}) with dns:' - f' {repr(self.config.protected_dsn)} failed: {e}') from e + raise e.__class__( + f"Database: encode/databases connect to database: " + f"{self.config.name}({self.config.alias}) with dns:" + f" {repr(self.config.protected_dsn)} failed: {e}" + ) from e # if not self._processed: # await self.process_db(db) return db @@ -217,32 +226,34 @@ def _parse_sql_params(cls, sql: str, params=None): return sql, {key: str(val) for key, val in params.items()} elif isinstance(params, (list, tuple)): # regex = re.compile('%s::[a-zA-Z0-9()]+\[\]') - sql = re.compile(r'%s::[a-zA-Z0-9()]+\[\]').sub('%s', sql) # match array (only for postgres) - replaces = tuple(f':param{i}' for i in range(0, len(params))) + sql = re.compile(r"%s::[a-zA-Z0-9()]+\[\]").sub( + "%s", sql + ) # match array (only for postgres) + replaces = tuple(f":param{i}" for i in range(0, len(params))) sql = sql % replaces - params = {f'param{i}': params[i] for i in range(0, len(params))} + params = {f"param{i}": params[i] for i in range(0, len(params))} # print('parsed:', sql, params) return sql, params else: - raise ValueError(f'Invalid params: {params}') + raise ValueError(f"Invalid params: {params}") async def execute(self, sql, params=None): - db = await self.connect() # lazy connect + db = await self.connect() # lazy connect sql, params = self._parse_sql_params(sql, params) return await db.execute(sql, params) async def execute_many(self, sql, params: list): - db = await self.connect() # lazy connect + db = await self.connect() # lazy connect return await db.execute_many(sql, params) async def fetchone(self, sql, params=None): - db = await self.connect() # lazy connect + db = await self.connect() # lazy connect sql, params = self._parse_sql_params(sql, params) r = await db.fetch_one(sql, params) return dict(r._mapping) if r else None async def fetchall(self, sql, params=None): - db = await self.connect() # lazy connect + db = await self.connect() # lazy connect # db = self.get_db() sql, params = self._parse_sql_params(sql, params) values = await db.fetch_all(sql, params) @@ -250,18 +261,15 @@ async def fetchall(self, sql, params=None): def transaction(self, savepoint=None, isolation=None, force_rollback: bool = False): db = self.get_db() - return _Transaction(db.connection, force_rollback=force_rollback, isolation=isolation) + return _Transaction( + db.connection, force_rollback=force_rollback, isolation=isolation + ) def check(self): if self.config.is_mysql: - requires( - MySQLdb='mysqlclient' - ) + requires(MySQLdb="mysqlclient") elif 
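
# NOTE(review): _parse_sql_params above rewrites positional DB-API "%s"
# placeholders into the ":paramN" style that encode/databases expects. The
# transformation in isolation:

sql = "SELECT * FROM app_user WHERE age > %s AND name = %s"
params = (18, "alice")

sql = sql % tuple(f":param{i}" for i in range(len(params)))
params = {f"param{i}": v for i, v in enumerate(params)}
assert sql == "SELECT * FROM app_user WHERE age > :param0 AND name = :param1"
assert params == {"param0": 18, "param1": "alice"}
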
self.config.is_postgresql: - requires( - psycopg='"psycopg[binary,pool]"', - psycopg2='psycopg2' - ) + requires(psycopg='"psycopg[binary,pool]"', psycopg2="psycopg2") if self.async_engine: requires(self.async_engine) # try: diff --git a/utilmeta/core/orm/encoder.py b/utilmeta/core/orm/encoder.py index 22fbc34..c2912ce 100644 --- a/utilmeta/core/orm/encoder.py +++ b/utilmeta/core/orm/encoder.py @@ -1,7 +1,7 @@ from utype import register_encoder try: - from psycopg2._json import Json # noqa + from psycopg2._json import Json # noqa @register_encoder(Json) def from_iterable(encoder, data): diff --git a/utilmeta/core/orm/exceptions.py b/utilmeta/core/orm/exceptions.py index b88125b..1cfcee0 100644 --- a/utilmeta/core/orm/exceptions.py +++ b/utilmeta/core/orm/exceptions.py @@ -1,19 +1,18 @@ - class ModelRequired(ValueError): def __init__(self, msg: str = None): - super().__init__(msg or 'orm.Error: model is required for query execution') + super().__init__(msg or "orm.Error: model is required for query execution") class MissingPrimaryKey(ValueError): def __init__(self, msg: str = None, model=None): self.model = model - super().__init__(msg or 'orm.Error: pk is missing for update') + super().__init__(msg or "orm.Error: pk is missing for update") class UpdateFailed(ValueError): def __init__(self, msg: str = None, model=None): self.model = model - super().__init__(msg or 'orm.Error: must_update=True: failed to update') + super().__init__(msg or "orm.Error: must_update=True: failed to update") class InvalidRelationalUpdate(ValueError): @@ -27,7 +26,7 @@ class EmptyQueryset(ValueError): def __init__(self, msg: str = None, model=None): self.model = model - super().__init__(msg or 'orm.Error: result is empty') + super().__init__(msg or "orm.Error: result is empty") class InvalidMode(TypeError): diff --git a/utilmeta/core/orm/fields/field.py b/utilmeta/core/orm/fields/field.py index 2fd78d2..7f49c12 100644 --- a/utilmeta/core/orm/fields/field.py +++ b/utilmeta/core/orm/fields/field.py @@ -15,11 +15,7 @@ class ParserQueryField(ParserField): - def __init__( - self, - model: 'ModelAdaptor' = None, - **kwargs - ): + def __init__(self, model: "ModelAdaptor" = None, **kwargs): super().__init__(**kwargs) self._kwargs = kwargs from ..backends.base import ModelAdaptor, ModelFieldAdaptor @@ -27,14 +23,20 @@ def __init__( self.model = model self.model_field: Optional[ModelFieldAdaptor] = None self.related_model: Optional[ModelAdaptor] = None - self.related_schema: Optional[Type['Schema']] = None + self.related_schema: Optional[Type["Schema"]] = None self.related_single = None self.relation_update_enabled = False - self.isolated = self.field.isolated if isinstance(self.field, QueryField) else False - self.fail_silently = self.field.fail_silently if isinstance(self.field, QueryField) else False + self.isolated = ( + self.field.isolated if isinstance(self.field, QueryField) else False + ) + self.fail_silently = ( + self.field.fail_silently if isinstance(self.field, QueryField) else False + ) self.many_included = False self.subquery = None - self.queryset = self.field.queryset if isinstance(self.field, QueryField) else None + self.queryset = ( + self.field.queryset if isinstance(self.field, QueryField) else None + ) self.reverse_lookup = None self.primary_key = False self.func = None @@ -42,7 +44,7 @@ def __init__( self.type_override = False self.original_type = None - def reconstruct(self, model: 'ModelAdaptor'): + def reconstruct(self, model: "ModelAdaptor"): return self.__class__(model, **self._kwargs) def 
get_query_schema(self): @@ -54,7 +56,10 @@ def get_query_schema(self): if isinstance(self.type, type) and issubclass(self.type, Rule): # try to find List[schema] - if isinstance(self.type.__origin__, LogicalType) and self.type.__origin__.combinator: + if ( + isinstance(self.type.__origin__, LogicalType) + and self.type.__origin__.combinator + ): self.related_single = True for arg in self.type.__origin__.args: @@ -65,7 +70,11 @@ def get_query_schema(self): schema = arg break else: - if self.type.__origin__ and issubclass(self.type.__origin__, list) and self.type.__args__: + if ( + self.type.__origin__ + and issubclass(self.type.__origin__, list) + and self.type.__args__ + ): self.related_single = False # also for List[str] / List[int] # we only accept list, not tuple/set @@ -91,13 +100,16 @@ def get_query_schema(self): if self.model_field and self.related_model: # if parser.model and self.model_field: # check model if not queryset - if self.related_model.is_sub_model(parser.model) \ - or parser.model.is_sub_model(self.related_model): + if self.related_model.is_sub_model( + parser.model + ) or parser.model.is_sub_model(self.related_model): schema = schema or parser.obj else: - raise TypeError(f'orm.Field({repr(self.name)}): ' - f'Invalid related model: {parser.model.model},' - f' sub model of {self.related_model.model} expected') + raise TypeError( + f"orm.Field({repr(self.name)}): " + f"Invalid related model: {parser.model.model}," + f" sub model of {self.related_model.model} expected" + ) else: schema = schema or parser.obj # 1. func field @@ -116,7 +128,9 @@ def get_query_schema(self): # raise TypeError(f'orm.Field({repr(self.name)})) no model ' # f'specified for related schema: {parser.obj}') - class schema(parser.obj, Schema[self.related_model]): pass + class schema(parser.obj, Schema[self.related_model]): + pass + else: if not issubclass(schema, Schema): # common schema, not related schema @@ -126,8 +140,9 @@ class schema(parser.obj, Schema[self.related_model]): pass self.isolated = True @classmethod - def process_annotate_meta(cls, m, model: 'ModelAdaptor' = None, **kwargs): + def process_annotate_meta(cls, m, model: "ModelAdaptor" = None, **kwargs): from ..backends.base import ModelAdaptor + if isinstance(model, ModelAdaptor): if model.field_adaptor_cls.qualify(m): return QueryField(m) @@ -136,8 +151,11 @@ def process_annotate_meta(cls, m, model: 'ModelAdaptor' = None, **kwargs): return super().process_annotate_meta(m, **kwargs) @classmethod - def get_field(cls, annotation: Any, default, model: 'ModelAdaptor' = None, **kwargs): + def get_field( + cls, annotation: Any, default, model: "ModelAdaptor" = None, **kwargs + ): from ..backends.base import ModelAdaptor + if isinstance(model, ModelAdaptor): if model.field_adaptor_cls.qualify(default): return QueryField(default) @@ -150,41 +168,38 @@ def setup(self, options: utype.Options): self.original_type = self.type from ..backends.base import ModelAdaptor + if not isinstance(self.model, ModelAdaptor): return if class_func(self.field_name): from utype.parser.func import FunctionParser + func = FunctionParser.apply_for(self.field_name) # fixme: ugly approach, getting the awaitable async function - async_func = getattr(func.obj, '_asyncfunc', None) - sync_func = getattr(func.obj, '_syncfunc', None) + async_func = getattr(func.obj, "_asyncfunc", None) + sync_func = getattr(func.obj, "_syncfunc", None) if async_func and sync_func: from utilmeta.utils import awaitable + if isinstance(self.field_name, classmethod): sync_func = 
classmethod(sync_func) async_func = classmethod(async_func) sync_wrapper = FunctionParser.apply_for(sync_func).wrap( - ignore_methods=True, - parse_params=True, - parse_result=True + ignore_methods=True, parse_params=True, parse_result=True ) async_wrapper = FunctionParser.apply_for(async_func).wrap( - ignore_methods=True, - parse_params=True, - parse_result=True + ignore_methods=True, parse_params=True, parse_result=True ) self.func = awaitable(sync_wrapper)(async_wrapper) else: self.func = func.wrap( - ignore_methods=True, - parse_params=True, - parse_result=True + ignore_methods=True, parse_params=True, parse_result=True ) self.func_multi = bool(func.pos_var) if not self.mode: - self.mode = 'r' + self.mode = "r" self.get_query_schema() return @@ -192,39 +207,50 @@ def setup(self, options: utype.Options): if self.model.check_subquery(self.field_name): self.subquery = self.field_name if not self.mode: - self.mode = 'r' + self.mode = "r" self.related_model = self.model.get_model(self.subquery) if not self.related_model: - raise ValueError(f'No model detected in queryset: {self.subquery}') + raise ValueError(f"No model detected in queryset: {self.subquery}") if self.queryset is not None: - raise ValueError(f'specify subquery field and queryset at the same time is not supported') + raise ValueError( + f"specify subquery field and queryset at the same time is not supported" + ) self.get_query_schema() if self.related_single is False: - warnings.warn(f'{self.model} schema field: {repr(self.name)} is a multi-relation with a subquery, ' - f'you need to make sure that only 1 row of the query is returned, ' - f'otherwise use query function instead') + warnings.warn( + f"{self.model} schema field: {repr(self.name)} is a multi-relation with a subquery, " + f"you need to make sure that only 1 row of the query is returned, " + f"otherwise use query function instead" + ) self.isolated = True # force isolated for queryset query (even without schema) return - self.model_field = self.model.get_field(self.field_name, allow_addon=True, silently=True) - self.related_model = self.model_field.related_model if self.model_field else None + self.model_field = self.model.get_field( + self.field_name, allow_addon=True, silently=True + ) + self.related_model = ( + self.model_field.related_model if self.model_field else None + ) # fix: get related model before get query schema self.get_query_schema() if self.model_field: - self.primary_key = self.model_field and self.model_field.is_pk and \ - self.model.is_sub_model(self.model_field.field_model) + self.primary_key = ( + self.model_field + and self.model_field.is_pk + and self.model.is_sub_model(self.model_field.field_model) + ) # use is sub model, because pk might be its base model if self.model_field.is_auto: if self.model_field.is_auto_now: if not self.no_input: - self.no_input = 'aw' + self.no_input = "aw" if self.default_factory is None: self.default_factory = time_now # handle auto_now differently @@ -232,25 +258,27 @@ def setup(self, options: utype.Options): if not self.mode: # accept 'w' to identify object if self.primary_key or self.model_field.is_writable: - mode = {'r', 'w'} + mode = {"r", "w"} if isinstance(self.no_input, str): mode.update(self.no_input) if isinstance(self.no_output, str): mode.update(self.no_output) # eg. 
id: int = orm.Field(no_input='a') # should have mode: 'raw' instead of 'rw - self.mode = ''.join(sorted(list(mode))) + self.mode = "".join(sorted(list(mode))) if self.required is True: - self.required = 'r' + self.required = "r" if not self.no_input: - self.no_input = 'a' + self.no_input = "a" - if not self.model_field.is_writable or self.model.cross_models(self.field_name): + if not self.model_field.is_writable or self.model.cross_models( + self.field_name + ): # read only if not self.mode and not self.primary_key: - self.mode = 'r' + self.mode = "r" # do not set primary key field to mode='r' # otherwise pk will not be settable in other mode @@ -275,24 +303,32 @@ def setup(self, options: utype.Options): if self.queryset is not None: if not self.related_model: - raise ValueError(f'Invalid queryset for field: {repr(self.model_field.name)}, ' - f'no related model') - if not self.related_model.check_queryset(self.queryset, check_model=True): - raise ValueError(f'Invalid queryset for field: {repr(self.model_field.name)}, ' - f'must be a queryset of model {self.related_model.model}') + raise ValueError( + f"Invalid queryset for field: {repr(self.model_field.name)}, " + f"no related model" + ) + if not self.related_model.check_queryset( + self.queryset, check_model=True + ): + raise ValueError( + f"Invalid queryset for field: {repr(self.model_field.name)}, " + f"must be a queryset of model {self.related_model.model}" + ) self.reverse_lookup, c = self.model.get_reverse_lookup(self.field_name) if c or not self.reverse_lookup: - raise ValueError(f'Invalid queryset for field: {repr(self.model_field.name)}, ' - f'invalid reverse lookup: {self.reverse_lookup}, {c}') + raise ValueError( + f"Invalid queryset for field: {repr(self.model_field.name)}, " + f"invalid reverse lookup: {self.reverse_lookup}, {c}" + ) if self.related_schema: # even for fk schema # is not writable by default # if self.related_model or self.many_included: if not self.mode: - self.mode = 'r' + self.mode = "r" - elif 'a' in self.mode or 'w' in self.mode: + elif "a" in self.mode or "w" in self.mode: # UPDATE ON RELATIONAL if options.mode and set(options.mode).issubset(self.mode): self.setup_relational_update(options) @@ -302,11 +338,16 @@ def setup(self, options: utype.Options): # 1. for a common field (say, JSONField) with related schema, we does not say mode to 'r' # 2. for serializing array field (pk_values) using related schema, isolated should be True else: - if self.mode and ('a' in self.mode or 'w' in self.mode): + if self.mode and ("a" in self.mode or "w" in self.mode): # update many fields # tags: [1, 4, 5] - if not self.model.cross_models(self.field_name) and not self.model_field.is_concrete: - if self.model_field.is_m2m or (self.model_field.is_o2 and self.model_field.is_2o): + if ( + not self.model.cross_models(self.field_name) + and not self.model_field.is_concrete + ): + if self.model_field.is_m2m or ( + self.model_field.is_o2 and self.model_field.is_2o + ): # 1. OneToOneRel # 2. 
ManyToManyField / ManyToManyRel self.relation_update_enabled = True @@ -318,26 +359,34 @@ def setup(self, options: utype.Options): self.type = rule.merge_type(self.type) # merge declared type and model field type except utype.exc.ConfigError as e: - warnings.warn(f'orm.Schema[{self.model.model}] got model field: [{repr(self.name)}] ' - f'with rule: {rule} ' - f'conflicted to the declared type: {self.type}, using the declared type,' - f'error: {e}') + warnings.warn( + f"orm.Schema[{self.model.model}] got model field: [{repr(self.name)}] " + f"with rule: {rule} " + f"conflicted to the declared type: {self.type}, using the declared type," + f"error: {e}" + ) # fixme: do not merge for ForwardRef else: if isinstance(self.field, QueryField) and self.field.field: - raise ValueError(f'orm.Field({repr(self.field.field)}) not exists in model: {self.model}') + raise ValueError( + f"orm.Field({repr(self.field.field)}) not exists in model: {self.model}" + ) # will not be queried (input of 'r' mode) if not self.no_input: - self.no_input = 'r' + self.no_input = "r" if not self.no_output: # no output for write / create - self.no_output = 'aw' + self.no_output = "aw" def override_required(self, options: utype.Options): if not self.type_override: - if self.model_field and self.related_schema and not self.model.cross_models(self.field_name): - if 'a' in self.mode or 'w' in self.mode: + if ( + self.model_field + and self.related_schema + and not self.model.cross_models(self.field_name) + ): + if "a" in self.mode or "w" in self.mode: # UPDATE ON RELATIONAL if options.mode and set(options.mode).issubset(self.mode): return True @@ -350,12 +399,11 @@ def setup_relational_update(self, options: utype.Options): # CROSS MODEL FIELDS CANNOT USED IN UPDATE self.no_output = self.no_output or options.mode return - remote_field_name = self.model_field.remote_field.column_name # +_id + remote_field_name = self.model_field.remote_field.column_name # +_id from utilmeta.core.orm import Schema self.related_schema = self.related_schema._get_relational_update_cls( - field=remote_field_name, - mode=options.mode + field=remote_field_name, mode=options.mode ) # can be cached @@ -364,26 +412,35 @@ def setup_relational_update(self, options: utype.Options): origin = None rule_args = [] rule_constraints = {} - if isinstance(self.type.__origin__, LogicalType) and self.type.__origin__.combinator: + if ( + isinstance(self.type.__origin__, LogicalType) + and self.type.__origin__.combinator + ): args = [] for arg in self.type.__origin__.args: if isinstance(arg, type) and issubclass(arg, Schema): - args.append(arg._get_relational_update_cls( - field=remote_field_name, - mode=options.mode - )) + args.append( + arg._get_relational_update_cls( + field=remote_field_name, mode=options.mode + ) + ) else: args.append(arg) origin = LogicalType.combine(self.type.__origin__.combinator, *args) rule_args = self.type.__args__ or [] else: - if self.type.__origin__ and issubclass(self.type.__origin__, list) and self.type.__args__: + if ( + self.type.__origin__ + and issubclass(self.type.__origin__, list) + and self.type.__args__ + ): arg = self.type.__args__[0] if isinstance(arg, type) and issubclass(arg, Schema): - rule_args.append(arg._get_relational_update_cls( - field=remote_field_name, - mode=options.mode - )) + rule_args.append( + arg._get_relational_update_cls( + field=remote_field_name, mode=options.mode + ) + ) else: rule_args.append(arg) origin = self.type.__origin__ @@ -398,18 +455,18 @@ def setup_relational_update(self, options: utype.Options): 
else: if isinstance(self.type, type) and issubclass(self.type, Schema): self.type = self.type._get_relational_update_cls( - field=remote_field_name, - mode=options.mode + field=remote_field_name, mode=options.mode ) else: if isinstance(self.type, LogicalType) and self.type.combinator: args = [] for arg in self.type.args: if isinstance(arg, type) and issubclass(arg, Schema): - args.append(arg._get_relational_update_cls( - field=remote_field_name, - mode=options.mode - )) + args.append( + arg._get_relational_update_cls( + field=remote_field_name, mode=options.mode + ) + ) else: args.append(arg) self.type = LogicalType.combine(self.type.combinator, *args) @@ -426,7 +483,7 @@ def readable(self): return True if not self.model_field: return False - return not self.always_no_input(utype.Options(mode='r')) + return not self.always_no_input(utype.Options(mode="r")) @property def writable(self): @@ -473,9 +530,11 @@ def is_sub_relation(self): # content.article is a sub relation if self.related_model: if issubclass(self.related_model.model, self.model.model): - if not self.model_field.multi_relations and \ - self.model_field.remote_field and \ - self.model_field.remote_field.is_pk: + if ( + not self.model_field.multi_relations + and self.model_field.remote_field + and self.model_field.remote_field.is_pk + ): return True return False @@ -492,19 +551,22 @@ def schema_annotations(self): class QueryField(Field): parser_field_cls = ParserQueryField - def __init__(self, field=None, *, - queryset=None, - fail_silently: bool = None, - auth: dict = None, - # filter=None, - # order_by: Union[str, List[str], Callable] = None, - # limit: Union[int, Callable] = None, - # distinct: bool = None, - isolated: bool = None, - **kwargs - # if module enabled result control (page / rows / limit / offset) and such params is provided - # this config is automatically turn to True to prevent result control the entire queryset - ): + def __init__( + self, + field=None, + *, + queryset=None, + fail_silently: bool = None, + auth: dict = None, + # filter=None, + # order_by: Union[str, List[str], Callable] = None, + # limit: Union[int, Callable] = None, + # distinct: bool = None, + isolated: bool = None, + **kwargs + # if module enabled result control (page / rows / limit / offset) and such params is provided + # this config is automatically turn to True to prevent result control the entire queryset + ): super().__init__(**kwargs) self.field = field diff --git a/utilmeta/core/orm/fields/filter.py b/utilmeta/core/orm/fields/filter.py index 17ed235..b4d6e3c 100644 --- a/utilmeta/core/orm/fields/filter.py +++ b/utilmeta/core/orm/fields/filter.py @@ -1,5 +1,6 @@ import inspect from utype import Field + # from utilmeta.conf import Preference from utype.parser.field import ParserField from utype.types import * @@ -9,18 +10,21 @@ class Filter(Field): - def __init__(self, - field=None, - # allow at most 1 operator in 1 Filter to provide clarity - *, - query=None, # expression to convert a input string to a Q object, - order: Union[str, list, Callable] = None, # use order only if this filter is provided - # like order_by [1, 4, 2] - # lambda val: Case(*[When(**{field: v, 'then': pos}) for pos, v in enumerate(val)]) - fail_silently: bool = False, - required: bool = False, - **kwargs - ): + def __init__( + self, + field=None, + # allow at most 1 operator in 1 Filter to provide clarity + *, + query=None, # expression to convert a input string to a Q object, + order: Union[ + str, list, Callable + ] = None, # use order only if this filter is 
provided + # like order_by [1, 4, 2] + # lambda val: Case(*[When(**{field: v, 'then': pos}) for pos, v in enumerate(val)]) + fail_silently: bool = False, + required: bool = False, + **kwargs, + ): self.field = field self.query = query self.order = order @@ -35,21 +39,18 @@ def __init__(self, @property def schema_annotations(self): return { - 'class': 'filter', + "class": "filter", } class ParserFilter(ParserField): - field: 'Filter' + field: "Filter" field_cls = Filter - def __init__( - self, - model: 'ModelAdaptor' = None, - **kwargs - ): + def __init__(self, model: "ModelAdaptor" = None, **kwargs): super().__init__(**kwargs) from ..backends.base import ModelAdaptor, ModelFieldAdaptor + self.model: Optional[ModelAdaptor] = None self.model_field: Optional[ModelFieldAdaptor] = None # self.query: Optional[Callable] = None @@ -60,13 +61,17 @@ def __init__( if isinstance(self.field, Filter): if self.field_name: - self.model_field = model.get_field(self.field_name, allow_addon=True, silently=True) + self.model_field = model.get_field( + self.field_name, allow_addon=True, silently=True + ) if self.model_field: self.validate_field() else: if not self.filter.query: - raise ValueError(f'Filter({repr(self.field_name)}) ' - f'not resolved to field in model: {model.model}') + raise ValueError( + f"Filter({repr(self.field_name)}) " + f"not resolved to field in model: {model.model}" + ) if not inspect.isfunction(self.query): self.model.check_query(self.query) diff --git a/utilmeta/core/orm/fields/order.py b/utilmeta/core/orm/fields/order.py index 76c39d7..8c99361 100644 --- a/utilmeta/core/orm/fields/order.py +++ b/utilmeta/core/orm/fields/order.py @@ -2,6 +2,7 @@ from utype import Field from utype.parser.field import ParserField from utype.types import * + # from utilmeta.util.error import Error if TYPE_CHECKING: @@ -13,23 +14,28 @@ class Random: class Order: - def __init__(self, field=None, *, - asc: bool = True, - desc: bool = True, - document: str = None, - distinct: bool = False, - nulls_first: bool = False, # in asc() / desc() - nulls_last: bool = False, - notnull: bool = False, - ): + def __init__( + self, + field=None, + *, + asc: bool = True, + desc: bool = True, + document: str = None, + distinct: bool = False, + nulls_first: bool = False, # in asc() / desc() + nulls_last: bool = False, + notnull: bool = False, + ): if not asc and not desc: - raise ValueError(f'Order({repr(field)}) must specify asc or desc') + raise ValueError(f"Order({repr(field)}) must specify asc or desc") if notnull: if nulls_first or nulls_last: - raise ValueError(f'Order({repr(field)}) that set ' - f'notnull=True cannot config nulls_first or nulls_last') + raise ValueError( + f"Order({repr(field)}) that set " + f"notnull=True cannot config nulls_first or nulls_last" + ) self.asc = asc self.desc = desc self.distinct = distinct @@ -41,18 +47,15 @@ def __init__(self, field=None, *, class ParserOrderBy(ParserField): - field: 'OrderBy' + field: "OrderBy" - def __init__( - self, - model: 'ModelAdaptor' = None, - **kwargs - ): + def __init__(self, model: "ModelAdaptor" = None, **kwargs): super().__init__(**kwargs) from ..backends.base import ModelAdaptor, ModelFieldAdaptor + self.model: Optional[ModelAdaptor] = None self.orders: Dict[str, Tuple[Order, ModelFieldAdaptor, int]] = {} - self.desc_prefix: str = '-' + self.desc_prefix: str = "-" if isinstance(model, ModelAdaptor) and isinstance(self.field, OrderBy): self.model = model @@ -64,14 +67,16 @@ def __init__( field = model.get_field(field_name, allow_addon=True) name = key 
if isinstance(key, str) else field.query_name if not name: - raise ValueError(f'Order field: {key} must have a valid name') + raise ValueError(f"Order field: {key} must have a valid name") if field.is_exp: model.check_expressions(field.field) else: model.check_order(field.query_name) if model.include_many_relates(field_name): - warnings.warn(f'Order for {model} field <{field_name}> contains multiple value, ' - f'make sure that is what your expected') + warnings.warn( + f"Order for {model} field <{field_name}> contains multiple value, " + f"make sure that is what your expected" + ) if order.asc: orders.setdefault(name, (order, field, 1)) @@ -83,7 +88,7 @@ def __init__( list(orders), item_type=str, # name=f'{self.model.ident}.{self.name}.enum', - unique=True + unique=True, ) # def parse_value(self, value, context): @@ -114,7 +119,7 @@ def schema_annotations(self): asc=order.asc, desc=order.desc, nulls_first=order.nulls_first, - nulls_last=order.nulls_last + nulls_last=order.nulls_last, ) data.update(orders=orders) return data @@ -123,37 +128,35 @@ def schema_annotations(self): class OrderBy(Field): parser_field_cls = ParserOrderBy - def __init__(self, orders: Union[list, Dict[Any, Order]], - *, - # orders can be a list of model fields, or a dict of order configuration - # key: str = None, - # max_length: int = None, - desc_prefix: str = '-', - ignore_invalids: bool = True, - ignore_conflicts: bool = True, # like if asc and desc is provided at the same time - required: bool = False, - description: str = None, - single: bool = False, - **kwargs, - ): + def __init__( + self, + orders: Union[list, Dict[Any, Order]], + *, + # orders can be a list of model fields, or a dict of order configuration + # key: str = None, + # max_length: int = None, + desc_prefix: str = "-", + ignore_invalids: bool = True, + ignore_conflicts: bool = True, # like if asc and desc is provided at the same time + required: bool = False, + description: str = None, + single: bool = False, + **kwargs, + ): if isinstance(orders, list): orders = {o: Order() for o in orders} order_docs = [] for key, order in orders.items(): if order.document: - order_docs.append(f'{key}: {order.document}') - order_doc = '\n'.join(order_docs) + order_docs.append(f"{key}: {order.document}") + order_doc = "\n".join(order_docs) if description: - description += '\n' + order_doc + description += "\n" + order_doc else: description = order_doc - super().__init__( - **kwargs, - description=description, - required=required - ) + super().__init__(**kwargs, description=description, required=required) self.orders = orders self.desc_prefix = desc_prefix @@ -164,5 +167,5 @@ def __init__(self, orders: Union[list, Dict[Any, Order]], @property def schema_annotations(self): return { - 'class': 'order_by', + "class": "order_by", } diff --git a/utilmeta/core/orm/fields/pagination.py b/utilmeta/core/orm/fields/pagination.py index cd5ca3d..9d1c192 100644 --- a/utilmeta/core/orm/fields/pagination.py +++ b/utilmeta/core/orm/fields/pagination.py @@ -10,7 +10,7 @@ def __init__(self, ge: int = 1, required: bool = False, **kwargs): @property def schema_annotations(self): - return {'class': 'page'} + return {"class": "page"} class Offset(Field): @@ -22,7 +22,7 @@ def __init__(self, ge: int = 0, required: bool = False, default=0, **kwargs): @property def schema_annotations(self): - return {'class': 'offset'} + return {"class": "offset"} class Limit(Field): @@ -34,4 +34,4 @@ def __init__(self, ge: int = 0, required: bool = False, **kwargs): @property def 
schema_annotations(self): - return {'class': 'limit'} + return {"class": "limit"} diff --git a/utilmeta/core/orm/fields/scope.py b/utilmeta/core/orm/fields/scope.py index f55a9a1..afc19ac 100644 --- a/utilmeta/core/orm/fields/scope.py +++ b/utilmeta/core/orm/fields/scope.py @@ -11,14 +11,15 @@ class Scope(Field): # TEMPLATE_ALIASES = ['template', 'includes', 'scope', 'fields'] # EXCLUDES_ALIASES = ['excludes', 'skip'] - def __init__(self, - excluded: bool = False, - max_depth: int = None, - ignore_invalids: bool = True, - allow_recursive: bool = True, - required: bool = False, - **kwargs - ): + def __init__( + self, + excluded: bool = False, + max_depth: int = None, + ignore_invalids: bool = True, + allow_recursive: bool = True, + required: bool = False, + **kwargs + ): super().__init__(**kwargs, required=required) self.max_depth = max_depth self.ignore_invalids = ignore_invalids @@ -37,10 +38,7 @@ def get_scope_value(cls, value): @property def schema_annotations(self): - return { - 'class': 'scope', - 'excluded': self.excluded - } + return {"class": "scope", "excluded": self.excluded} @property def default_type(self): diff --git a/utilmeta/core/orm/generator.py b/utilmeta/core/orm/generator.py index 0c60b93..c0921bd 100644 --- a/utilmeta/core/orm/generator.py +++ b/utilmeta/core/orm/generator.py @@ -14,13 +14,14 @@ class BaseQuerysetGenerator: - def __init__(self, - parser: QueryClassParser, - values: dict, - # --- config params - distinct: bool = None, - **kwargs - ): + def __init__( + self, + parser: QueryClassParser, + values: dict, + # --- config params + distinct: bool = None, + **kwargs + ): self.parser = parser self.model = parser.model self.values = values @@ -58,10 +59,7 @@ async def acount(self, base=None) -> int: raise NotImplementedError def get_context(self, **kwargs): - kwargs.update( - includes=self.includes, - excludes=self.excludes - ) + kwargs.update(includes=self.includes, excludes=self.excludes) return QueryContext(**kwargs) @property @@ -84,7 +82,9 @@ def slice(self) -> slice: def process_filter(self, field: ParserFilter, value): raise NotImplementedError - def process_order(self, order: Order, field: 'ModelFieldAdaptor', name: str, flag: int = 1): + def process_order( + self, order: Order, field: "ModelFieldAdaptor", name: str, flag: int = 1 + ): raise NotImplementedError def process_value(self, field, value): @@ -97,10 +97,7 @@ def process_value(self, field, value): if o in field.orders: order, f, flag = field.orders[o] self.process_order( - order, - field=f, - flag=flag, - name=str(o).lstrip(field.desc_prefix) + order, field=f, flag=flag, name=str(o).lstrip(field.desc_prefix) ) elif isinstance(field.field, Page): self.page = value diff --git a/utilmeta/core/orm/parser.py b/utilmeta/core/orm/parser.py index b3fc5c4..2731649 100644 --- a/utilmeta/core/orm/parser.py +++ b/utilmeta/core/orm/parser.py @@ -3,6 +3,7 @@ from .fields.filter import ParserFilter from . 
import exceptions from typing import TYPE_CHECKING + if TYPE_CHECKING: from .compiler import BaseQueryCompiler from .generator import BaseQuerysetGenerator @@ -12,8 +13,9 @@ class QueryClassParser(ClassParser): parser_field_cls = ParserFilter def __init__(self, obj, *args, **kwargs): - model = getattr(obj, '__model__', None) + model = getattr(obj, "__model__", None) from .backends.base import ModelAdaptor + self.model = ModelAdaptor.dispatch(model) if model else None super().__init__(obj, *args, **kwargs) @@ -21,7 +23,7 @@ def __init__(self, obj, *args, **kwargs): def kwargs(self): return dict(model=self.model) - def get_generator(self, values: dict, **kwargs) -> 'BaseQuerysetGenerator': + def get_generator(self, values: dict, **kwargs) -> "BaseQuerysetGenerator": return self.model.generator_cls(self, values, **kwargs) @property @@ -36,15 +38,18 @@ class SchemaClassParser(ClassParser): parser_field_cls = ParserQueryField def __init__(self, obj, *args, **kwargs): - model = getattr(obj, '__model__', None) + model = getattr(obj, "__model__", None) from .backends.base import ModelAdaptor + self.model = ModelAdaptor.dispatch(model) if model else None super().__init__(obj, *args, **kwargs) - serialize_options = getattr(obj, '__serialize_options__', None) - self.output_options = self.options_cls.generate_from( - serialize_options - ) if serialize_options else None + serialize_options = getattr(obj, "__serialize_options__", None) + self.output_options = ( + self.options_cls.generate_from(serialize_options) + if serialize_options + else None + ) pk_names = set() if self.model: @@ -61,8 +66,10 @@ def __init__(self, obj, *args, **kwargs): field = field.reconstruct(self.model) field.setup(self.options) except Exception as e: - raise e.__class__(f'{self.name}(orm.Schema): setup field [{repr(name)}] ' - f'for model: {self.model} failed with error: {e}') from e + raise e.__class__( + f"{self.name}(orm.Schema): setup field [{repr(name)}] " + f"for model: {self.model} failed with error: {e}" + ) from e self.fields[name] = field # if pk_names: @@ -78,14 +85,16 @@ def __init__(self, obj, *args, **kwargs): def kwargs(self): return dict(model=self.model) - def get_compiler(self, queryset, context=None) -> 'BaseQueryCompiler': + def get_compiler(self, queryset, context=None) -> "BaseQueryCompiler": if not self.model: - raise exceptions.ModelRequired(f'{self.name}: model is required for query execution') + raise exceptions.ModelRequired( + f"{self.name}: model is required for query execution" + ) return self.model.compiler_cls(self, queryset, context=context) def get_instance(self, data: dict): # pk = self.get_pk(data) - inst = dict(pk=getattr(data, 'pk', None)) + inst = dict(pk=getattr(data, "pk", None)) for key, val in data.items(): field = self.get_field(key) if isinstance(field, ParserQueryField): diff --git a/utilmeta/core/orm/plugins/atomic.py b/utilmeta/core/orm/plugins/atomic.py index 39ed41b..8f2e391 100644 --- a/utilmeta/core/orm/plugins/atomic.py +++ b/utilmeta/core/orm/plugins/atomic.py @@ -6,8 +6,14 @@ class AtomicPlugin(PluginBase): - def __init__(self, alias: str = 'default', savepoint: bool = True, durable: bool = False, - isolation=None, force_rollback: bool = False): + def __init__( + self, + alias: str = "default", + savepoint: bool = True, + durable: bool = False, + isolation=None, + force_rollback: bool = False, + ): super().__init__(locals()) self.alias = alias self.savepoint = savepoint @@ -23,7 +29,7 @@ def __enter__(self): self.transaction = self.db.transaction( savepoint=self.savepoint, 
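# --- editorial note (illustrative sketch, not part of the patch) -------------
# AtomicPlugin wraps a configured database alias in a transaction and can be
# used as a context manager, an async context manager, or a function decorator
# (the __call__ branch below). A minimal usage sketch, assuming the plugin is
# exposed as orm.Atomic and that a 'default' database is configured:
#
#     from utilmeta.core import orm
#
#     @orm.Atomic("default")           # decorator: the whole call is atomic
#     def transfer(src_id: int, dst_id: int, amount: int):
#         ...
#
#     def manual():
#         with orm.Atomic("default"):  # context manager: rolls back on error
#             ...
# ------------------------------------------------------------------------------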
isolation=self.isolation, - force_rollback=self.force_rollback + force_rollback=self.force_rollback, ) return self.transaction.__enter__() @@ -48,7 +54,7 @@ async def __aenter__(self): self.async_transaction = self.db.async_transaction( savepoint=self.savepoint, isolation=self.isolation, - force_rollback=self.force_rollback + force_rollback=self.force_rollback, ) return await self.async_transaction.__aenter__() @@ -62,20 +68,21 @@ def __call__(self, f, *_, **__): transaction = self.db.async_transaction( savepoint=self.savepoint, isolation=self.isolation, - force_rollback=self.force_rollback + force_rollback=self.force_rollback, ) @functools.wraps(f) async def wrapper(*args, **kwargs): async with transaction: return await f(*args, **kwargs) + return wrapper elif inspect.isfunction(f): transaction = self.db.transaction( savepoint=self.savepoint, isolation=self.isolation, - force_rollback=self.force_rollback + force_rollback=self.force_rollback, ) @functools.wraps(f) @@ -89,9 +96,11 @@ def wrapper(*args, **kwargs): pass else: if service.asynchronous: + @functools.wraps(f) def threaded_wrapper(*args, **kwargs): return service.pool.get_result(wrapper, *args, **kwargs) + return threaded_wrapper return wrapper diff --git a/utilmeta/core/orm/schema.py b/utilmeta/core/orm/schema.py index daa5821..58fb685 100644 --- a/utilmeta/core/orm/schema.py +++ b/utilmeta/core/orm/schema.py @@ -9,17 +9,17 @@ from .fields.field import ParserQueryField -T = TypeVar('T') +T = TypeVar("T") __caches__: dict = {} class Schema(utype.Schema): __serialize_options__ = utype.Options( - mode='r', + mode="r", addition=True, ignore_required=True, - ignore_constraints=True # skip constraints validation when querying from db + ignore_constraints=True # skip constraints validation when querying from db # no_default=True, # no default, but default can be calculated when attr is called ) @@ -34,10 +34,7 @@ def __class_getitem__(cls: T, item) -> T: k = (cls, item) if k in __caches__: return __caches__[k] - attrs = { - '__qualname__': cls.__qualname__, - '__module__': cls.__module__ - } + attrs = {"__qualname__": cls.__qualname__, "__module__": cls.__module__} options = None annotations = {} @@ -82,19 +79,19 @@ def pk(self, val): if name in self: _set = True if not _set: - self.__dict__['pk'] = val + self.__dict__["pk"] = val def get_instance(self, fresh: bool = True): if fresh: if self.pk is None: - raise exceptions.MissingPrimaryKey('pk is missing for query instance') + raise exceptions.MissingPrimaryKey("pk is missing for query instance") return self.__parser__.model.get_instance(pk=self.pk) return self.__parser__.get_instance(self) async def aget_instance(self, fresh: bool = True): if fresh: if self.pk is None: - raise exceptions.MissingPrimaryKey('pk is missing for query instance') + raise exceptions.MissingPrimaryKey("pk is missing for query instance") return await self.__parser__.model.aget_instance(pk=self.pk) return self.__parser__.get_instance(self) @@ -133,10 +130,10 @@ def _get_relational_update_cls(cls, field: str, mode: str): if k in __caches__: return __caches__[k] - suffix = f'_RELATIONAL_UPDATE_{field}' + suffix = f"_RELATIONAL_UPDATE_{field}" attrs = { - '__qualname__': cls.__qualname__ + suffix, - '__module__': cls.__module__, + "__qualname__": cls.__qualname__ + suffix, + "__module__": cls.__module__, } if isinstance(mode, str): attrs.update(__options__=utype.Options(mode=mode)) @@ -145,21 +142,32 @@ def _get_relational_update_cls(cls, field: str, mode: str): model_field = cls.__parser__.model.get_field(field) 
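# --- editorial note (illustrative sketch, not part of the patch) -------------
# _get_relational_update_cls derives a schema subclass for relational updates:
# every field bound to the given remote ForeignKey column becomes no_input for
# the write mode, so nested rows inherit that column from the parent object
# being saved. Illustrative declarations only; Author/Article are hypothetical
# Django models:
#
#     from utilmeta.core import orm
#     from typing import List
#
#     class ArticleSchema(orm.Schema[Article]):
#         id: int = orm.Field(no_input="a")
#         title: str
#
#     class AuthorUpdate(orm.Schema[Author]):
#         id: int
#         articles: List[ArticleSchema] = orm.Field(mode="aw")
#
#     # saving an AuthorUpdate in mode 'a'/'w' writes each nested article with
#     # article.author_id taken from the saved author's primary key
# ------------------------------------------------------------------------------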
if not model_field: - raise ValueError(f'Invalid relation remote_field: {repr(field)}, not exists') + raise ValueError( + f"Invalid relation remote_field: {repr(field)}, not exists" + ) if not model_field.is_fk: - raise ValueError(f'Invalid relation remote_field: {repr(field)}, must be ForeignKey') + raise ValueError( + f"Invalid relation remote_field: {repr(field)}, must be ForeignKey" + ) relational_fields = [] for name, parser_field in cls.__parser__.fields.items(): parser_field: SchemaClassParser.parser_field_cls - if parser_field.model_field and parser_field.model_field.name == model_field.name: - if name == 'pk': + if ( + parser_field.model_field + and parser_field.model_field.name == model_field.name + ): + if name == "pk": continue - attrs[parser_field.attname] = parser_field.field_cls(no_input=mode, mode=mode) + attrs[parser_field.attname] = parser_field.field_cls( + no_input=mode, mode=mode + ) if not parser_field.no_output: relational_fields.append(parser_field.attname) if not relational_fields: - attrs[field] = cls.__parser_cls__.parser_field_cls.field_cls(no_input=mode, mode=mode, no_output=False) + attrs[field] = cls.__parser_cls__.parser_field_cls.field_cls( + no_input=mode, mode=mode, no_output=False + ) relational_fields = [field] attrs.update( @@ -194,7 +202,7 @@ def init(cls: Type[T], queryset, context=None) -> T: cls: Type[Schema] values = cls._get_compiler(queryset, context=context, single=True).get_values() if not values: - raise exceptions.EmptyQueryset(f'Empty queryset') + raise exceptions.EmptyQueryset(f"Empty queryset") return cls.__from__(values[0], cls.__serialize_options__) @classmethod @@ -203,39 +211,46 @@ async def ainit(cls: Type[T], queryset, context=None) -> T: # initialize this schema with the given queryset (first element) # raise error if queryset is empty cls: Type[Schema] - values = await cls._get_compiler(queryset, context=context, single=True).get_values() + values = await cls._get_compiler( + queryset, context=context, single=True + ).get_values() if not values: - raise exceptions.EmptyQueryset(f'Empty queryset') + raise exceptions.EmptyQueryset(f"Empty queryset") return cls.__from__(values[0], cls.__serialize_options__) - def commit(self, queryset: T) -> T: # -> queryset + def commit(self, queryset: T) -> T: # -> queryset # commit the data in the schema to the queryset (update) # id is ignored here compiler = self._get_compiler(queryset) return compiler.commit_data(self) # @awaitable(commit) - async def acommit(self, queryset: T) -> T: # -> queryset + async def acommit(self, queryset: T) -> T: # -> queryset # commit the data in the schema to the queryset (update) # id is ignored here compiler = self._get_compiler(queryset) return await compiler.commit_data(self) - def save(self: T, - must_create: bool = None, - must_update: bool = None, - with_relations: bool = None, - ignore_relation_errors: Union[bool, Type[Exception], List[Type[Exception]]] = False, - transaction: Union[bool, str] = False, - ) -> T: # -> queryset + def save( + self: T, + must_create: bool = None, + must_update: bool = None, + with_relations: bool = None, + ignore_relation_errors: Union[ + bool, Type[Exception], List[Type[Exception]] + ] = False, + transaction: Union[bool, str] = False, + ) -> T: # -> queryset # no id: create # id: create -(integrityError)-> update if must_update and must_create: - raise ValueError(f'{__class__.__name__}.save(): must_create and must_update cannot both be True') + raise ValueError( + f"{__class__.__name__}.save(): must_create and must_update 
cannot both be True" + ) if must_create is None: if must_update: must_create = False - elif self.__options__.mode == 'a' and not self.pk: + elif self.__options__.mode == "a" and not self.pk: must_create = True # if with_relations is None: # with_relations = True @@ -250,26 +265,31 @@ def save(self: T, must_update=must_update, with_relations=with_relations, ignore_relation_errors=ignore_relation_errors, - transaction=transaction + transaction=transaction, ) return self # @awaitable(save) - async def asave(self: T, - must_create: bool = None, - must_update: bool = None, - with_relations: bool = None, - ignore_relation_errors: Union[bool, Type[Exception], List[Type[Exception]]] = False, - transaction: Union[bool, str] = False, - ) -> T: # -> queryset + async def asave( + self: T, + must_create: bool = None, + must_update: bool = None, + with_relations: bool = None, + ignore_relation_errors: Union[ + bool, Type[Exception], List[Type[Exception]] + ] = False, + transaction: Union[bool, str] = False, + ) -> T: # -> queryset # no id: create # id: create -(integrityError)-> update if must_update and must_create: - raise ValueError(f'{__class__.__name__}.asave(): must_create and must_update cannot both be True') + raise ValueError( + f"{__class__.__name__}.asave(): must_create and must_update cannot both be True" + ) if must_create is None: if must_update: must_create = False - elif self.__options__.mode == 'a' and not self.pk: + elif self.__options__.mode == "a" and not self.pk: must_create = True # if with_relations is None: # with_relations = True @@ -284,26 +304,29 @@ async def asave(self: T, must_update=must_update, with_relations=with_relations, ignore_relation_errors=ignore_relation_errors, - transaction=transaction + transaction=transaction, ) return self @classmethod - def bulk_save(cls: Type[T], - data: List[T], - must_create: bool = False, - must_update: bool = False, - with_relations: bool = None, - ignore_errors: Union[bool, Type[Exception], List[Type[Exception]]] = False, - ignore_relation_errors: Union[bool, Type[Exception], List[Type[Exception]]] = False, - transaction: Union[bool, str] = False, - ) -> List[T]: + def bulk_save( + cls: Type[T], + data: List[T], + must_create: bool = False, + must_update: bool = False, + with_relations: bool = None, + ignore_errors: Union[bool, Type[Exception], List[Type[Exception]]] = False, + ignore_relation_errors: Union[ + bool, Type[Exception], List[Type[Exception]] + ] = False, + transaction: Union[bool, str] = False, + ) -> List[T]: # the queryset is contained in the data, # data with id will be updated (try, and create after not exists) # data without id will be created compiler = cls._get_compiler(None) if not isinstance(data, list): - raise TypeError(f'Invalid data: {data}, must be list') + raise TypeError(f"Invalid data: {data}, must be list") # 1. transform data list to schema instance list values = [val if isinstance(val, cls) else cls.__from__(val) for val in data] # 2. 
bulk create @@ -315,30 +338,33 @@ def bulk_save(cls: Type[T], with_relations=with_relations, ignore_bulk_errors=ignore_errors, ignore_relation_errors=ignore_relation_errors, - transaction=transaction + transaction=transaction, ), - values + values, ): if pk: val.pk = pk return values @classmethod - async def abulk_save(cls: Type[T], - data: List[T], - must_create: bool = False, - must_update: bool = False, - with_relations: bool = None, - ignore_errors: Union[bool, Type[Exception], List[Type[Exception]]] = False, - ignore_relation_errors: Union[bool, Type[Exception], List[Type[Exception]]] = False, - transaction: Union[bool, str] = False, - ) -> List[T]: + async def abulk_save( + cls: Type[T], + data: List[T], + must_create: bool = False, + must_update: bool = False, + with_relations: bool = None, + ignore_errors: Union[bool, Type[Exception], List[Type[Exception]]] = False, + ignore_relation_errors: Union[ + bool, Type[Exception], List[Type[Exception]] + ] = False, + transaction: Union[bool, str] = False, + ) -> List[T]: # the queryset is contained in the data, # data with id will be updated (try, and create after not exists) # data without id will be created compiler = cls._get_compiler(None) if not isinstance(data, list): - raise TypeError(f'Invalid data: {data}, must be list') + raise TypeError(f"Invalid data: {data}, must be list") # 1. transform data list to schema instance list values = [val if isinstance(val, cls) else cls.__from__(val) for val in data] for pk, val in zip( @@ -349,9 +375,9 @@ async def abulk_save(cls: Type[T], with_relations=with_relations, ignore_bulk_errors=ignore_errors, ignore_relation_errors=ignore_relation_errors, - transaction=transaction + transaction=transaction, ), - values + values, ): if pk: val.pk = pk @@ -372,6 +398,7 @@ class Query(utype.Schema): def __class_getitem__(cls, item): class _class(cls): __model__ = item + return _class def get_generator(self): diff --git a/utilmeta/core/request/backends/base.py b/utilmeta/core/request/backends/base.py index ba65c25..a65f350 100644 --- a/utilmeta/core/request/backends/base.py +++ b/utilmeta/core/request/backends/base.py @@ -1,7 +1,15 @@ from urllib.parse import urlsplit, urlunsplit from typing import Optional -from utilmeta.utils import MetaMethod, CommonMethod, Header, get_request_ip, \ - RequestType, cached_property, time_now, parse_query_string +from utilmeta.utils import ( + MetaMethod, + CommonMethod, + Header, + get_request_ip, + RequestType, + cached_property, + time_now, + parse_query_string, +) from utilmeta.utils import exceptions as exc from utilmeta.utils import LOCAL_IP from utilmeta.core.file import File @@ -44,7 +52,7 @@ def route(self): @route.setter def route(self, route): - self._route = str(route or '').strip('/') + self._route = str(route or "").strip("/") def __contains__(self, item): return item in self._context @@ -72,7 +80,7 @@ def clear_context(self): self._context.clear() @classmethod - def reconstruct(cls, adaptor: 'RequestAdaptor'): + def reconstruct(cls, adaptor: "RequestAdaptor"): if isinstance(adaptor, cls): return adaptor.request raise NotImplementedError @@ -117,7 +125,7 @@ def url(self): # full url def encoded_path(self): parsed = urlsplit(self.url) if parsed.query: - return parsed.path + '?' + parsed.query + return parsed.path + "?" 
+ parsed.query return parsed.path @property @@ -130,11 +138,11 @@ def hostname(self): @property def origin(self): - origin_header = self.headers.get('origin') + origin_header = self.headers.get("origin") if origin_header: return origin_header s = urlsplit(self.url) - return urlunsplit((s.scheme, s.netloc, '', '', '')) + return urlunsplit((s.scheme, s.netloc, "", "", "")) @property def scheme(self): @@ -162,8 +170,8 @@ def content_type(self) -> Optional[str]: if not ct: return ct = str(ct) - if ';' in ct: - return ct.split(';')[0].strip() + if ";" in ct: + return ct.split(";")[0].strip() return ct @property @@ -203,17 +211,19 @@ def form_type(self): @property def text_type(self): content_type = self.content_type - return content_type.startswith('text') + return content_type.startswith("text") def get_json(self): if not self.content_length: # Empty content return None import json + return json.loads(self.body, cls=self.json_decoder_cls) def get_xml(self): from xml.etree.ElementTree import XMLParser + parser = XMLParser() parser.feed(self.body) return parser.close() @@ -260,7 +270,9 @@ def content(self): except NotImplementedError: raise except Exception as e: - raise exc.UnprocessableEntity(f'process request body failed with error: {e}') + raise exc.UnprocessableEntity( + f"process request body failed with error: {e}" + ) @property def body(self) -> Optional[bytes]: @@ -284,7 +296,9 @@ async def async_load(self): except NotImplementedError: raise except Exception as e: - raise exc.UnprocessableEntity(f'process request body failed with error: {e}') from e + raise exc.UnprocessableEntity( + f"process request body failed with error: {e}" + ) from e async def async_read(self): raise NotImplementedError diff --git a/utilmeta/core/request/backends/django.py b/utilmeta/core/request/backends/django.py index c84eebd..8e85053 100644 --- a/utilmeta/core/request/backends/django.py +++ b/utilmeta/core/request/backends/django.py @@ -1,9 +1,13 @@ -import io - from .base import RequestAdaptor from django.http.request import HttpRequest from django.middleware.csrf import CsrfViewMiddleware, get_token -from utilmeta.utils import parse_query_dict, cached_property, Header, LOCAL_IP, multi, exceptions, url_join +from utilmeta.utils import ( + parse_query_dict, + Header, + LOCAL_IP, + multi, + url_join, +) from ipaddress import ip_address from utilmeta.core.file.backends.django import DjangoFileAdaptor from utilmeta.core.file.base import File @@ -11,10 +15,12 @@ def get_request_ip(meta: dict): - ips = [*meta.get(Header.FORWARDED_FOR, '').replace(' ', '').split(','), - meta.get(Header.REMOTE_ADDR)] - if '' in ips: - ips.remove('') + ips = [ + *meta.get(Header.FORWARDED_FOR, "").replace(" ", "").split(","), + meta.get(Header.REMOTE_ADDR), + ] + if "" in ips: + ips.remove("") if LOCAL_IP in ips: ips.remove(LOCAL_IP) for ip in ips: @@ -33,7 +39,9 @@ def gen_csrf_token(self): return get_token(self.request) def check_csrf_token(self) -> bool: - err_resp = CsrfViewMiddleware(lambda *_: None).process_view(self.request, None, None, None) + err_resp = CsrfViewMiddleware(lambda *_: None).process_view( + self.request, None, None, None + ) return err_resp is None @property @@ -49,12 +57,13 @@ def address(self): return self._address def get_url(self): - if hasattr(self.request, 'get_raw_uri'): + if hasattr(self.request, "get_raw_uri"): return self.request.get_raw_uri() try: return self.request.build_absolute_uri() except KeyError: from utilmeta import service + return url_join(service.origin, self.path) @property @@ -64,19 
+73,19 @@ def path(self): @classmethod def load_form_data(cls, request): m = request.method - load_call = getattr(request, '_load_post_and_files') - if m in ('PUT', 'PATCH'): - if hasattr(request, '_post'): - delattr(request, '_post') - delattr(request, '_files') + load_call = getattr(request, "_load_post_and_files") + if m in ("PUT", "PATCH"): + if hasattr(request, "_post"): + delattr(request, "_post") + delattr(request, "_files") try: - request.method = 'POST' + request.method = "POST" load_call() request.method = m except AttributeError: - request.META['REQUEST_METHOD'] = 'POST' + request.META["REQUEST_METHOD"] = "POST" load_call() - request.META['REQUEST_METHOD'] = m + request.META["REQUEST_METHOD"] = m def get_form(self): self.load_form_data(self.request) @@ -85,7 +94,9 @@ def get_form(self): for key in self.request.FILES: files = self.request.FILES.getlist(key) if multi(files): - parsed_files[key] = [File(self.file_adaptor_cls(file)) for file in files] + parsed_files[key] = [ + File(self.file_adaptor_cls(file)) for file in files + ] else: parsed_files[key] = File(self.file_adaptor_cls(files)) data.update(parsed_files) diff --git a/utilmeta/core/request/backends/sanic.py b/utilmeta/core/request/backends/sanic.py index 20ac007..7d4a65a 100644 --- a/utilmeta/core/request/backends/sanic.py +++ b/utilmeta/core/request/backends/sanic.py @@ -5,7 +5,7 @@ from utilmeta.core.file.backends.sanic import SanicFileAdaptor from sanic.request.form import File as SanicFile from utilmeta.core.file.base import File -from utilmeta.utils import exceptions, multi, get_request_ip +from utilmeta.utils import multi, get_request_ip from utype import unprovided @@ -27,7 +27,7 @@ def address(self): return ipaddress.ip_address(ip) except (AttributeError, ValueError): pass - return get_request_ip(dict(self.headers)) or ipaddress.ip_address('127.0.0.1') + return get_request_ip(dict(self.headers)) or ipaddress.ip_address("127.0.0.1") @property def cookies(self): diff --git a/utilmeta/core/request/backends/starlette.py b/utilmeta/core/request/backends/starlette.py index 09cb8cb..6166407 100644 --- a/utilmeta/core/request/backends/starlette.py +++ b/utilmeta/core/request/backends/starlette.py @@ -14,12 +14,13 @@ class StarletteRequestAdaptor(RequestAdaptor): This adaptor can adapt starlette project and all frameworks based on it such as [FastAPI] """ + request: Request file_adaptor_cls = StarletteFileAdaptor backend = starlette @classmethod - def reconstruct(cls, adaptor: 'RequestAdaptor'): + def reconstruct(cls, adaptor: "RequestAdaptor"): pass def gen_csrf_token(self): @@ -41,7 +42,9 @@ def request_method(self) -> str: @property def url(self) -> str: - return str(self.request.url) # request.url is a URL structure, str will get the inner _url + return str( + self.request.url + ) # request.url is a URL structure, str will get the inner _url @property def cookies(self): @@ -51,7 +54,7 @@ def cookies(self): def query_params(self): query = {} for key, value in self.request.query_params.multi_items(): - query.setdefault(key.rstrip('[]'), []).append(value) + query.setdefault(key.rstrip("[]"), []).append(value) return {k: val[0] if len(val) == 1 else val for k, val in query.items()} @property @@ -70,7 +73,7 @@ def scheme(self): def encoded_path(self): path, query = self.path, self.query_string if query: - return path + '?' + query + return path + "?" 
+ query return path @property @@ -124,9 +127,7 @@ async def steam(): yield b"" return - form = await MultiPartParser( - self.headers, steam() - ).parse() + form = await MultiPartParser(self.headers, steam()).parse() else: form = await self.request.form() return self.process_form(form) diff --git a/utilmeta/core/request/backends/tornado.py b/utilmeta/core/request/backends/tornado.py index 76f2177..d480f68 100644 --- a/utilmeta/core/request/backends/tornado.py +++ b/utilmeta/core/request/backends/tornado.py @@ -1,9 +1,9 @@ from tornado.httpserver import HTTPRequest as ServerRequest + # from tornado.httpclient import HTTPRequest as ClientRequest from ..base import RequestAdaptor from utilmeta.core.file.backends.tornado import TornadoFileAdaptor from utilmeta.core.file.base import File -from utilmeta.utils import exceptions as exc import tornado @@ -13,7 +13,7 @@ class TornadoServerRequestAdaptor(RequestAdaptor): backend = tornado @classmethod - def reconstruct(cls, adaptor: 'RequestAdaptor') -> ServerRequest: + def reconstruct(cls, adaptor: "RequestAdaptor") -> ServerRequest: if isinstance(adaptor, cls): return adaptor.request raise ServerRequest( diff --git a/utilmeta/core/request/backends/werkzeug.py b/utilmeta/core/request/backends/werkzeug.py index eca7aad..c1c9f23 100644 --- a/utilmeta/core/request/backends/werkzeug.py +++ b/utilmeta/core/request/backends/werkzeug.py @@ -2,7 +2,7 @@ from werkzeug.wrappers import Request from utilmeta.core.file.backends.werkzeug import WerkzeugFileAdaptor from utilmeta.core.file.base import File -from utilmeta.utils import Headers, exceptions, HAS_BODY_METHODS +from utilmeta.utils import Headers, HAS_BODY_METHODS import werkzeug @@ -50,7 +50,7 @@ def query_params(self): @property def body(self): - return self._body or b'' + return self._body or b"" @property def headers(self): diff --git a/utilmeta/core/request/base.py b/utilmeta/core/request/base.py index 5c1cbd5..95d7671 100644 --- a/utilmeta/core/request/base.py +++ b/utilmeta/core/request/base.py @@ -9,7 +9,7 @@ __all__ = [ - 'Request', + "Request", ] @@ -29,23 +29,22 @@ def apply_for(cls, req): return req return cls(req) - def __init__(self, - request=None, *, - method: str = None, - url: str = None, - query: dict = None, - data=None, - headers: Union[Mapping, Dict[str, str]] = None, - backend=None, - ): + def __init__( + self, + request=None, + *, + method: str = None, + url: str = None, + query: dict = None, + data=None, + headers: Union[Mapping, Dict[str, str]] = None, + backend=None, + ): if not request: from .client import ClientRequest + request = ClientRequest( - method=method, - url=url, - query=query, - data=data, - headers=headers + method=method, url=url, query=query, data=data, headers=headers ) self.adaptor = RequestAdaptor.dispatch(request) @@ -75,7 +74,7 @@ def path(self) -> str: @property def traffic(self): - traffic = self.adaptor.get_context('traffic') + traffic = self.adaptor.get_context("traffic") if traffic: return traffic value = 12 # HTTP/1.1 200 OK \r\n @@ -117,12 +116,12 @@ def headers(self) -> dict: @property def authorization(self) -> Tuple[Optional[str], Optional[str]]: - auth: str = self.headers.get('authorization') + auth: str = self.headers.get("authorization") if not auth: return None, None - if ' ' in auth: + if " " in auth: lst = auth.split() - return lst[0], ' '.join(lst[1:]) + return lst[0], " ".join(lst[1:]) return None, auth @property diff --git a/utilmeta/core/request/client.py b/utilmeta/core/request/client.py index 14ef678..6c712dd 100644 --- 
a/utilmeta/core/request/client.py +++ b/utilmeta/core/request/client.py @@ -4,24 +4,35 @@ from .backends.base import RequestAdaptor from http.cookies import SimpleCookie from urllib.parse import urlsplit, urlencode, urlunsplit -from utilmeta.utils import Headers, pop, file_like, guess_mime_type, \ - RequestType, encode_multipart_form, json_dumps, multi,\ - parse_query_string, parse_query_dict +from utilmeta.utils import ( + Headers, + pop, + file_like, + guess_mime_type, + RequestType, + encode_multipart_form, + json_dumps, + multi, + parse_query_string, + parse_query_dict, +) from collections.abc import Mapping from utilmeta.core.file import File class ClientRequest: - def __init__(self, - method: str, - url: str, - query: dict = None, - data=None, - headers: Dict[str, str] = None): + def __init__( + self, + method: str, + url: str, + query: dict = None, + data=None, + headers: Dict[str, str] = None, + ): self._method = method self.headers = Headers(headers or {}) - cookie = SimpleCookie(self.headers.get('cookie', {})) + cookie = SimpleCookie(self.headers.get("cookie", {})) self.cookies = {k: v.value for k, v in cookie.items()} self._data = data @@ -30,7 +41,7 @@ def __init__(self, self._form: Optional[dict] = None self._json: Union[dict, list, None] = None - self._url = url or '' + self._url = url or "" self._query = parse_query_dict(query or {}) self.build_url() @@ -81,26 +92,28 @@ def build_url(self): url_parsed = urlsplit(self._url) url_query = parse_query_string(url_parsed.query) if url_parsed.query else {} url_query.update(self._query) - self._url = urlunsplit(( - url_parsed.scheme, - url_parsed.netloc, - url_parsed.path, - urlencode(url_query), - url_parsed.fragment - )) + self._url = urlunsplit( + ( + url_parsed.scheme, + url_parsed.netloc, + url_parsed.path, + urlencode(url_query), + url_parsed.fragment, + ) + ) @property def route(self): - return urlsplit(self.url).path.strip('/') + return urlsplit(self.url).path.strip("/") @property def content_type(self) -> Optional[str]: - return self.headers.get('content-type') + return self.headers.get("content-type") @content_type.setter def content_type(self, t): if t: - self.headers['content-type'] = t + self.headers["content-type"] = t @property def contains_files(self): @@ -140,7 +153,7 @@ def build_body(self): if self.data is None: return - pop(self.headers, 'content-length') + pop(self.headers, "content-length") # there are difference between JSON.stringify in js and json.dumps in python # while JSON.stringify leave not spaces and json.dumps leave space between a comma and next key # difference like {"a":1,"b":2} and {"a": 1, "b": 2} @@ -156,7 +169,11 @@ def build_body(self): if self.content_type: if self.content_type.startswith(RequestType.JSON): - self._json = json.loads(self.data) if isinstance(self.data, bytes) else json.load(self.data) + self._json = ( + json.loads(self.data) + if isinstance(self.data, bytes) + else json.load(self.data) + ) elif self.content_type.startswith(RequestType.FORM_URLENCODED): qs = self.data @@ -223,7 +240,7 @@ def build_body(self): self._json = list(self.data) elif file_like(self.data): - name = getattr(self.data, 'name', None) + name = getattr(self.data, "name", None) content_type = None if name: content_type, encoding = guess_mime_type(name) @@ -253,7 +270,8 @@ def qualify(cls, obj): @property def address(self): from ipaddress import ip_address - return ip_address('127.0.0.1') + + return ip_address("127.0.0.1") # @property # def content_type(self) -> Optional[str]: @@ -266,7 +284,7 @@ def 
diff --git a/utilmeta/core/request/properties.py b/utilmeta/core/request/properties.py
index 6401563..335d49d 100644
--- a/utilmeta/core/request/properties.py
+++ b/utilmeta/core/request/properties.py
@@ -12,30 +12,30 @@
 
 __all__ = [
-    'URL',
-    'Host',
-    'PathParam',
-    'FilePathParam',
-    'SlugPathParam',
-    'BodyParam',
-    'Body',
-    'Form',
-    'Json',
-    'EncodingField',
-    'Query',
-    'QueryParam',
-    'Headers',
-    'HeaderParam',
-    'Cookies',
-    'CookieParam',
-    'UserAgent',
-    'Address',
-    'Time'
+    "URL",
+    "Host",
+    "PathParam",
+    "FilePathParam",
+    "SlugPathParam",
+    "BodyParam",
+    "Body",
+    "Form",
+    "Json",
+    "EncodingField",
+    "Query",
+    "QueryParam",
+    "Headers",
+    "HeaderParam",
+    "Cookies",
+    "CookieParam",
+    "UserAgent",
+    "Address",
+    "Time",
 ]
 
 
 class Path(Property):
-    __ident__ = 'path'
+    __ident__ = "path"
 
     @classmethod
     def getter(cls, request: Request, *keys: str):
@@ -43,7 +43,7 @@ def getter(cls, request: Request, *keys: str):
 
 
 class URL(Property):
-    __ident__ = 'url'
+    __ident__ = "url"
 
     @classmethod
     def getter(cls, request: Request, field: ParserField = None):
@@ -55,13 +55,14 @@ class Host(Property):
     def getter(cls, request: Request, field: ParserField = None):
         return request.host
 
-    def __init__(self,
-                 allow_list: list = None,
-                 block_list: list = None,
-                 local_only: bool = False,
-                 private_only: bool = False,
-                 public_only: bool = False,
-                 ):
+    def __init__(
+        self,
+        allow_list: list = None,
+        block_list: list = None,
+        local_only: bool = False,
+        private_only: bool = False,
+        public_only: bool = False,
+    ):
         super().__init__()
         self.allow_list = allow_list
         self.block_list = block_list
@@ -71,15 +72,15 @@ def __init__(self,
 
 
 class Body(Property):
-    PLAIN = 'text/plain'
-    JSON = 'application/json'
-    FORM_URLENCODED = 'application/x-www-form-urlencoded'
-    FORM_DATA = 'multipart/form-data'
-    XML = 'application/xml'
-    OCTET_STREAM = 'application/octet-stream'
-
-    __ident__ = 'body'
-    __name_prefix__ = 'request'
+    PLAIN = "text/plain"
+    JSON = "application/json"
+    FORM_URLENCODED = "application/x-www-form-urlencoded"
+    FORM_DATA = "multipart/form-data"
+    XML = "application/xml"
+    OCTET_STREAM = "application/octet-stream"
+
+    __ident__ = "body"
+    __name_prefix__ = "request"
     __no_default__ = True
     content_type = None
@@ -106,24 +107,35 @@ async def getter(self, request: Request, field: ParserField = None):
 
     def validate_content_type(self, request: Request):
         if self.content_type and request.content_type != self.content_type:
-            raise exc.UnprocessableEntity(f'invalid content type: {request.content_type}')
+            raise exc.UnprocessableEntity(
+                f"invalid content type: {request.content_type}"
+            )
 
     def validate_max_length(self, request: Request):
-        if self.max_length and request.content_length and request.content_length > self.max_length:
+        if (
+            self.max_length
+            and request.content_length
+            and request.content_length > self.max_length
+        ):
             raise exc.RequestEntityTooLarge
 
-    def __init__(self, content_type: str = None, *,
-                 description: str = None,
-                 example: Any = None,
-                 # options=None,
-                 # default=unprovided,
-                 max_length: int = None, **kwargs):
+    def __init__(
+        self,
+        content_type: str = None,
+        *,
+        description: str = None,
+        example: Any = None,
+        # options=None,
+        # default=unprovided,
+        max_length: int = None,
+        **kwargs,
+    ):
         init_kwargs = dict(
             max_length=max_length,
             # options=options,
             description=description,
             example=example,
-            **kwargs
+            **kwargs,
         )
         super().__init__(**init_kwargs)
         # if content_type:
@@ -139,16 +151,11 @@ def __init__(self, content_type: str = None, *,
 
 
 class Json(Body):
-    content_type = 'application/json'
+    content_type = "application/json"
 
-    def __init__(self, *,
-                 description: str = None,
-                 example: Any = None, **kwargs):
+    def __init__(self, *, description: str = None, example: Any = None, **kwargs):
         super().__init__(
-            self.content_type,
-            description=description,
-            example=example,
-            **kwargs
+            self.content_type, description=description, example=example, **kwargs
         )
 
     # def getter(self, request: Request, field: ParserField = None):
@@ -164,21 +171,19 @@ def __init__(self, *,
 
 
 class Form(Body):
-    content_type = 'multipart/form-data'
+    content_type = "multipart/form-data"
 
-    def __init__(self, *,
-                 description: str = None,
-                 example: Any = None, **kwargs):
+    def __init__(self, *, description: str = None, example: Any = None, **kwargs):
         super().__init__(
-            self.content_type,
-            description=description,
-            example=example,
-            **kwargs
+            self.content_type, description=description, example=example, **kwargs
         )
 
     def validate_content_type(self, request: Request):
-        if self.content_type and request.content_type not in (Body.FORM_URLENCODED, Body.FORM_DATA):
-            raise exc.UnprocessableEntity('invalid content type')
+        if self.content_type and request.content_type not in (
+            Body.FORM_URLENCODED,
+            Body.FORM_DATA,
+        ):
+            raise exc.UnprocessableEntity("invalid content type")
 
     # def getter(self, request: Request, field: ParserField = None):
     #     if not request.adaptor.form_type:
@@ -196,18 +201,23 @@ class ObjectProperty(Property):
     def init(self, field: ParserField):
         t = field.type
         if not t or not isinstance(t, type) or isinstance(None, t):
-            raise TypeError(f'{self.__class__}: {repr(field.name)} should specify a valid object type')
+            raise TypeError(
+                f"{self.__class__}: {repr(field.name)} should specify a valid object type"
+            )
         if not issubclass(t, Mapping):
             from utype.parser.cls import ClassParser
+
-            parser = getattr(t, '__parser__', None)
+            parser = getattr(t, "__parser__", None)
             if not isinstance(parser, ClassParser):
-                raise TypeError(f'{self.__class__}: {repr(field.name)} should specify a valid object type, got {t}')
+                raise TypeError(
+                    f"{self.__class__}: {repr(field.name)} should specify a valid object type, got {t}"
+                )
         return super().init(field)
 
 
 class Query(ObjectProperty):
-    __ident__ = 'query'
-    __name_prefix__ = 'request'
+    __ident__ = "query"
+    __name_prefix__ = "request"
     __type__ = dict
     __no_default__ = True
@@ -253,8 +263,8 @@ def getter(cls, request: Request, field: ParserField = None):
 
 
 class Headers(ObjectProperty):
-    __ident__ = 'header'  # according to OpenAPI, not "headers"
-    __name_prefix__ = 'request'
+    __ident__ = "header"  # according to OpenAPI, not "headers"
+    __name_prefix__ = "request"
     __type__ = dict
     __no_default__ = True
@@ -282,8 +292,8 @@ def getter(cls, request: Request, field: ParserField = None):
 
 
 class Cookies(ObjectProperty):
-    __ident__ = 'cookie'  # according to OpenAPI, not "cookies"
-    __name_prefix__ = 'request'
+    __ident__ = "cookie"  # according to OpenAPI, not "cookies"
+    __name_prefix__ = "request"
 
     @classmethod
     def getter(cls, request: Request, field: ParserField = None):
@@ -291,7 +301,7 @@ def getter(cls, request: Request, field: ParserField = None):
 
 
 class RequestParam(Property):
-    __name_prefix__ = 'request'
+    __name_prefix__ = "request"
 
     def get_value(self, data: Mapping, field: ParserField):
         if isinstance(data, Mapping):
@@ -304,14 +314,14 @@ def get_value(self, data: Mapping, field: ParserField):
 
     def getter(self, request: Request, field: ParserField = None):
         if not field:
-            raise ValueError(f'field required')
+            raise ValueError(f"field required")
         data = self.get_mapping(request)
         return self.get_value(data, field)
 
     @awaitable(getter)
     async def getter(self, request: Request, field: ParserField = None):
         if not field:
-            raise ValueError(f'field required')
+            raise ValueError(f"field required")
         data = self.get_mapping(request)
         if inspect.isawaitable(data):
             data = await data
@@ -326,26 +336,26 @@ def get_mapping(cls, request: Request) -> Optional[Mapping]:
     async def get_mapping(cls, request: Request) -> Optional[Mapping]:
         raise NotImplementedError
 
-    def __init__(self,
-                 alias: str = None,
-                 default=unprovided,
-                 required: bool = None,
-                 style: str = None,
-                 **kwargs):
+    def __init__(
+        self,
+        alias: str = None,
+        default=unprovided,
+        required: bool = None,
+        style: str = None,
+        **kwargs,
+    ):
         if required:
             default = unprovided
         super().__init__(
             alias=alias or self.alias_generator,
             default=default,
             required=required,
-            **kwargs
+            **kwargs,
        )
         self.style = style
         # refer to https://swagger.io/specification/
         try:
-            self._update_spec(
-                style=style
-            )
+            self._update_spec(style=style)
         except AttributeError:
             pass
@@ -383,20 +393,28 @@ def get_mapping(cls, request: Request) -> Optional[Mapping]:
     __in__ = Path
     __no_default__ = True
 
-    regex = '[^/]+'  # default regex, can be override
-
-    def __init__(self, regex: str = None, *, min_length: str = None, max_length: str = None,
-                 required: bool = True, default=unprovided, **kwargs):
+    regex = "[^/]+"  # default regex, can be overridden
+
+    def __init__(
+        self,
+        regex: str = None,
+        *,
+        min_length: str = None,
+        max_length: str = None,
+        required: bool = True,
+        default=unprovided,
+        **kwargs,
+    ):
         if not regex:
             if min_length:
                 if max_length:
-                    regex = '(.{%s,%s})' % (min_length, max_length)
+                    regex = "(.{%s,%s})" % (min_length, max_length)
                 elif min_length == 1:
-                    regex = '(.+)'
+                    regex = "(.+)"
                 else:
-                    regex = '(.{%s,})' % min_length
+                    regex = "(.{%s,})" % min_length
             elif max_length:
-                regex = '(.{0,%s})' % max_length
+                regex = "(.{0,%s})" % max_length
         if regex:
             self.regex = regex
         self.min_length = min_length
@@ -407,12 +425,7 @@ def __init__(self, regex: str = None, *, min_length: str = None, max_length: str
 
         self.required = required
 
-        super().__init__(
-            regex=self.regex,
-            required=required,
-            default=default,
-            **kwargs
-        )
+        super().__init__(regex=self.regex, required=required, default=default, **kwargs)
 
 
 class SlugPathParam(PathParam):
@@ -420,7 +433,7 @@ class SlugPathParam(PathParam):
 
 
 class FilePathParam(PathParam):
-    regex = r'(.*)'
+    regex = r"(.*)"
 
 
 class QueryParam(RequestParam):
@@ -444,7 +457,7 @@ def get_mapping(cls, request: Request):
         elif request.adaptor.form_type:
             mp = request.adaptor.get_form()
         else:
-            raise exc.UnprocessableEntity(f'invalid content type, must be json or form')
+            raise exc.UnprocessableEntity(f"invalid content type, must be json or form")
         data.set(mp)
         return mp
@@ -457,39 +470,38 @@ async def get_mapping(cls, request: Request):
         if request.adaptor.json_type or request.adaptor.form_type:
             mp = await request.adaptor.async_load()
         else:
-            raise exc.UnprocessableEntity(f'invalid content type, must be json or form')
+            raise exc.UnprocessableEntity(f"invalid content type, must be json or form")
         data.set(mp)
         return mp
 
 
 class EncodingField(Field):
-    def __init__(self,
-                 content_type: str = None,
-                 # support for mixed encoding body (multipart/mixed)
-                 # will integrate into encoding in requestBody
-                 *,
-                 description: str = None,
-                 example: Any = None,
-                 # options=None,
-                 max_length: int = None,
-                 headers: dict = None,
-                 **kwargs):
+    def __init__(
+        self,
+        content_type: str = None,
+        # support for mixed encoding body (multipart/mixed)
+        # will integrate into encoding in requestBody
+        *,
+        description: str = None,
+        example: Any = None,
+        # options=None,
+        max_length: int = None,
+        headers: dict = None,
+        **kwargs,
+    ):
         super().__init__(
             max_length=max_length,
             # options=options,
             description=description,
             example=example,
-            **kwargs
+            **kwargs,
         )
         self.content_type = content_type
         # self.options = options
         self.max_length = max_length  # Body Too Long
         self.headers = headers
         try:
-            self._update_spec(
-                content_type=content_type,
-                headers=headers
-            )
+            self._update_spec(content_type=content_type, headers=headers)
         except AttributeError:
             pass
@@ -503,7 +515,8 @@ def get_mapping(cls, request: Request):
 
     @classmethod
     def alias_generator(cls, key: str):
-        return key.replace('_', '-')
+        return key.replace("_", "-")
+
 
 #
 # class Authorization(HeaderParam):
@@ -550,30 +563,37 @@ def get_mapping(cls, request: Request):
 
 class UserAgent(Property):
     __in__ = Headers
-    __key__ = 'user-agent'
+    __key__ = "user-agent"
 
     @classmethod
     def getter(cls, request: Request, field: ParserField = None):
-        return request.headers.get('User-Agent', unprovided)
-
-    def __init__(self,
-                 regex: str = None,
-                 os_regex: str = None,
-                 device_regex: str = None,
-                 browser_regex: str = None,
-                 bot: bool = None,
-                 pc: bool = None,
-                 mobile: bool = None,
-                 tablet: bool = None):
+        return request.headers.get("User-Agent", unprovided)
+
+    def __init__(
+        self,
+        regex: str = None,
+        os_regex: str = None,
+        device_regex: str = None,
+        browser_regex: str = None,
+        bot: bool = None,
+        pc: bool = None,
+        mobile: bool = None,
+        tablet: bool = None,
+    ):
         super().__init__(regex=regex)
         # None: no restriction whether or not agent is match
         # True: request agent must be ...
         # False: request agent must not be ...
         if bot is False:
             assert not pc and not mobile and not tablet
-        assert [pc, mobile, tablet].count(True) <= 1, f'Request Agent cannot specify multiple platform'
-        assert {bot, pc, mobile, tablet} != {None}, f'Request Agent must specify some rules'
+        assert [pc, mobile, tablet].count(
+            True
+        ) <= 1, f"Request Agent cannot specify multiple platforms"
+        assert {bot, pc, mobile, tablet} != {
+            None
+        }, f"Request Agent must specify some rules"
         import re
+
         self.regex = re.compile(regex) if regex else None
         self.os_regex = re.compile(os_regex) if os_regex else None
         self.device_regex = re.compile(device_regex) if device_regex else None
@@ -585,43 +605,45 @@ def __init__(self,
 
     def runtime_validate(self, user_agent):
         try:
-            from user_agents.parsers import UserAgent  # noqa
+            from user_agents.parsers import UserAgent  # noqa
         except ModuleNotFoundError:
-            raise ModuleNotFoundError('UserAgent validation requires to install [user_agents] package')
+            raise ModuleNotFoundError(
+                "UserAgent validation requires to install [user_agents] package"
+            )
         user_agent: UserAgent
         if self.regex:
             if not self.regex.search(user_agent.ua_string):
-                raise exc.PermissionDenied('Request Agent is denied')
+                raise exc.PermissionDenied("Request Agent is denied")
         if self.os_regex:
             if not self.os_regex.search(user_agent.os):
-                raise exc.PermissionDenied('Request Agent is denied')
+                raise exc.PermissionDenied("Request Agent is denied")
         if self.device_regex:
             if not self.device_regex.search(user_agent.device):
-                raise exc.PermissionDenied('Request Agent is denied')
+                raise exc.PermissionDenied("Request Agent is denied")
         if self.browser_regex:
             if not self.browser_regex.search(user_agent.browser):
-                raise exc.PermissionDenied('Request Agent is denied')
+                raise exc.PermissionDenied("Request Agent is denied")
         if self.bot is not None:
             if self.bot ^ user_agent.is_bot:
-                raise exc.PermissionDenied('Request Agent is denied')
+                raise exc.PermissionDenied("Request Agent is denied")
         if self.pc is not None:
             if self.pc ^ user_agent.is_pc:
-                raise exc.PermissionDenied('Request Agent is denied')
+                raise exc.PermissionDenied("Request Agent is denied")
         if self.mobile is not None:
             if self.mobile ^ user_agent.is_mobile:
-                raise exc.PermissionDenied('Request Agent is denied')
+                raise exc.PermissionDenied("Request Agent is denied")
         if self.tablet is not None:
             if self.tablet ^ user_agent.is_tablet:
-                raise exc.PermissionDenied('Request Agent is denied')
+                raise exc.PermissionDenied("Request Agent is denied")
 
 
 class Address(Property):
@@ -631,14 +653,16 @@ class Address(Property):
     def getter(cls, request: Request, field: ParserField = None):
         return request.ip_address
 
-    def __init__(self,
-                 block_list: List[Union[IPv4Network, IPv6Network, str]] = None,
-                 allow_list: List[Union[IPv4Network, IPv6Network, str]] = None,
-                 ipv4_only: bool = None,
-                 ipv6_only: bool = None,
-                 local_only: bool = None,  # for micro-service integration
-                 private_only: bool = None,
-                 public_only: bool = None):
+    def __init__(
+        self,
+        block_list: List[Union[IPv4Network, IPv6Network, str]] = None,
+        allow_list: List[Union[IPv4Network, IPv6Network, str]] = None,
+        ipv4_only: bool = None,
+        ipv6_only: bool = None,
+        local_only: bool = None,  # for micro-service integration
+        private_only: bool = None,
+        public_only: bool = None,
+    ):
         super().__init__()
         self.block_list = block_list
         self.allow_list = allow_list
@@ -656,11 +680,12 @@ class Time(Property):
     def getter(cls, request: Request, field: ParserField = None):
         return request.time
 
-    def __init__(self,
-                 not_before: datetime = None,  # open time
-                 not_after: datetime = None,  # close time
-                 time_zone: str = None,
-                 ):
+    def __init__(
+        self,
+        not_before: datetime = None,  # open time
+        not_after: datetime = None,  # close time
+        time_zone: str = None,
+    ):
         super().__init__(required=False)
         self.not_before = not_before
         self.not_after = not_after
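PathParam above derives its route regex from min_length/max_length when no explicit pattern is given. A standalone sketch of that derivation (path_param_regex is a hypothetical helper mirroring the logic in PathParam.__init__):

def path_param_regex(regex=None, min_length=None, max_length=None):
    # translate length bounds into a capture-group pattern
    if not regex:
        if min_length:
            if max_length:
                regex = "(.{%s,%s})" % (min_length, max_length)
            elif min_length == 1:
                regex = "(.+)"
            else:
                regex = "(.{%s,})" % min_length
        elif max_length:
            regex = "(.{0,%s})" % max_length
    return regex or "[^/]+"

assert path_param_regex(min_length=2, max_length=8) == "(.{2,8})"
assert path_param_regex(max_length=5) == "(.{0,5})"
assert path_param_regex() == "[^/]+"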
diff --git a/utilmeta/core/request/var.py b/utilmeta/core/request/var.py
index e214795..e11b225 100644
--- a/utilmeta/core/request/var.py
+++ b/utilmeta/core/request/var.py
@@ -13,11 +13,17 @@
 
 
 class RequestContextVar(Property):
-    def __init__(self, key: str, cached: bool = False, static: bool = False,
-                 default=None, factory: Callable = None):
+    def __init__(
+        self,
+        key: str,
+        cached: bool = False,
+        static: bool = False,
+        default=None,
+        factory: Callable = None,
+    ):
         super().__init__(
             default_factory=default if callable(default) else None,
-            default=default if not callable(default) else unprovided
+            default=default if not callable(default) else unprovided,
         )
         self.key = key
         self.default = default
@@ -25,7 +31,7 @@ def __init__(self, key: str, cached: bool = False, static: bool = False,
         self.cached = cached
         self.static = static
 
-    def setup(self, request: 'Request'):
+    def setup(self, request: "Request"):
         class c:
             @staticmethod
             def contains():
@@ -50,10 +56,10 @@ def delete():
 
         return c
 
-    def contains(self, request: 'Request'):
+    def contains(self, request: "Request"):
         return request.adaptor.in_context(self.key)
 
-    def getter(self, request: 'Request', field=None, default=unprovided):
+    def getter(self, request: "Request", field=None, default=unprovided):
         r = default
         if self.contains(request):
             r = request.adaptor.get_context(self.key)
@@ -70,7 +76,7 @@ def getter(self, request: 'Request', field=None, default=unprovided):
         return r
 
     @awaitable(getter)
-    async def getter(self, request: 'Request', field=None, default=unprovided):
+    async def getter(self, request: "Request", field=None, default=unprovided):
         r = default
         if self.contains(request):
             r = request.adaptor.get_context(self.key)
@@ -92,12 +98,12 @@ async def getter(self, request: 'Request', field=None, default=unprovided):
             self.setter(request, r)
         return r
 
-    def setter(self, request: 'Request', value, field=None):
+    def setter(self, request: "Request", value, field=None):
         if self.static and self.contains(request):
             return
         request.adaptor.update_context(**{self.key: value})
 
-    def deleter(self, request: 'Request', field=None):
+    def deleter(self, request: "Request", field=None):
         if self.static and self.contains(request):
             return
         request.adaptor.delete_context(self.key)
@@ -106,24 +112,28 @@ def register_factory(self, func, force: bool = False):
         if self.factory:
             if self.factory != func:
                 if force:
-                    raise ValueError(f'factory conflicted: {func}, {self.factory}')
+                    raise ValueError(f"factory conflicted: {func}, {self.factory}")
                 else:
-                    warnings.warn(f'factory conflicted: {func}, {self.factory}')
+                    warnings.warn(f"factory conflicted: {func}, {self.factory}")
                 return
         self.factory = func
 
 
 # cached context var
-user = RequestContextVar('_user', cached=True)
-user_id = RequestContextVar('_user_id', cached=True)
-scopes = RequestContextVar('_scopes', cached=True)
-data = RequestContextVar('_data', cached=True)  # parsed str/dict data
+user = RequestContextVar("_user", cached=True)
+user_id = RequestContextVar("_user_id", cached=True)
+scopes = RequestContextVar("_scopes", cached=True)
+data = RequestContextVar("_data", cached=True)  # parsed str/dict data
 # variable context var
-time = RequestContextVar('_time', factory=lambda request: request.adaptor.time, static=True)
-path_params = RequestContextVar('_path_params', default=dict)
-allow_methods = RequestContextVar('_allow_methods', default=list)
-allow_headers = RequestContextVar('_allow_headers', default=list)
-unmatched_route = RequestContextVar('_unmatched_route', factory=lambda request: request.adaptor.route)
-operation_names = RequestContextVar('_operation_names', default=list)
+time = RequestContextVar(
+    "_time", factory=lambda request: request.adaptor.time, static=True
+)
+path_params = RequestContextVar("_path_params", default=dict)
+allow_methods = RequestContextVar("_allow_methods", default=list)
+allow_headers = RequestContextVar("_allow_headers", default=list)
+unmatched_route = RequestContextVar(
+    "_unmatched_route", factory=lambda request: request.adaptor.route
+)
+operation_names = RequestContextVar("_operation_names", default=list)
 # all the passing-by route's name, to combine the endpoint operationId
-endpoint_ref = RequestContextVar('_endpoint_ref', default=None)
+endpoint_ref = RequestContextVar("_endpoint_ref", default=None)
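RequestContextVar above resolves a value from the request adaptor's context, falls back to a factory or a (possibly callable) default, and writes the value back when cached=True. A simplified model of that lookup, with a plain dict standing in for the adaptor context (ContextVarSketch is illustrative, not the real class):

from typing import Callable, Optional

class ContextVarSketch:
    def __init__(self, key: str, default=None, factory: Optional[Callable] = None,
                 cached: bool = False):
        self.key = key
        self.default = default
        self.factory = factory
        self.cached = cached

    def get(self, ctx: dict):
        if self.key in ctx:
            return ctx[self.key]
        if self.factory:
            value = self.factory(ctx)
        else:
            value = self.default() if callable(self.default) else self.default
        if self.cached:
            ctx[self.key] = value  # write back so later reads hit the context
        return value

path_params = ContextVarSketch("_path_params", default=dict)
print(path_params.get({}))  # {} -- freshly built from the callable default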
diff --git a/utilmeta/core/response/backends/aiohttp.py b/utilmeta/core/response/backends/aiohttp.py
index 756e706..2cab3a3 100644
--- a/utilmeta/core/response/backends/aiohttp.py
+++ b/utilmeta/core/response/backends/aiohttp.py
@@ -1,5 +1,6 @@
 from aiohttp.client_reqrep import ClientResponse
 from aiohttp.web_response import Response as ServerResponse
+
 # from utilmeta.utils import async_to_sync
 from .base import ResponseAdaptor
@@ -25,7 +26,7 @@ def headers(self):
 
     @property
     def body(self) -> bytes:
-        return getattr(self.response, '_body', None)
+        return getattr(self.response, "_body", None)
 
     async def async_read(self) -> bytes:
         return await self.response.read()
@@ -35,7 +36,7 @@ async def async_load(self):
             return await self.response.text()
         elif self.json_type:
             return await self.response.json()
-        self.__dict__['body'] = await self.async_read()
+        self.__dict__["body"] = await self.async_read()
         return self.get_content()
 
     @property
diff --git a/utilmeta/core/response/backends/base.py b/utilmeta/core/response/backends/base.py
index ab8edc3..751336c 100644
--- a/utilmeta/core/response/backends/base.py
+++ b/utilmeta/core/response/backends/base.py
@@ -38,7 +38,7 @@ def headers(self):
 
     @property
     def cookies(self):
-        return SimpleCookie(self.headers.get('set-cookie'))
+        return SimpleCookie(self.headers.get("set-cookie"))
 
     @property
     def body(self) -> bytes:
@@ -50,9 +50,9 @@ def charset(self) -> Optional[str]:
         if not ct:
             return None
         ct = str(ct)
-        for value in ct.split(';'):
-            if value.strip().startswith('charset='):
-                return value.split('=')[1].strip()
+        for value in ct.split(";"):
+            if value.strip().startswith("charset="):
+                return value.split("=")[1].strip()
         return None
 
     @utils.cached_property
@@ -61,8 +61,8 @@ def content_type(self) -> Optional[str]:
         if not ct:
             return
         ct = str(ct)
-        if ';' in ct:
-            return ct.split(';')[0].strip()
+        if ";" in ct:
+            return ct.split(";")[0].strip()
         return ct
 
     @property
@@ -96,7 +96,7 @@ def text_type(self):
         content_type = self.content_type
         if not content_type:
             return False
-        return content_type.startswith('text')
+        return content_type.startswith("text")
 
     # @property
     # def file_type(self):
@@ -135,16 +135,18 @@ def get_content(self):
     def get_file(self):
         from io import BytesIO
         from utilmeta.core.file import File
+
         return File(BytesIO(self.body))
         # from utilmeta.utils.media import File
         # return File(file=BytesIO(self.body))
 
     def get_text(self) -> str:
-        return self.body.decode(encoding=self.charset or 'utf-8', errors='replace')
+        return self.body.decode(encoding=self.charset or "utf-8", errors="replace")
 
     def get_json(self) -> Union[dict, list, None]:
         text = self.get_text()
         import json
+
         try:
             return json.loads(text, cls=self.json_decoder_cls)
         except json.decoder.JSONDecodeError:
@@ -152,12 +154,13 @@ def get_json(self) -> Union[dict, list, None]:
 
     def get_xml(self):
         from xml.etree.ElementTree import XMLParser
+
         parser = XMLParser()
         parser.feed(self.body)
         return parser.close()
 
     async def async_load(self):
-        self.__dict__['body'] = await self.async_read()
+        self.__dict__["body"] = await self.async_read()
         return self.get_content()
 
     async def async_read(self):
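ResponseAdaptor.charset above scans the parameters of the Content-Type header for a charset value. A standalone sketch of the same parsing (parse_charset is a hypothetical helper, not utilmeta API):

from typing import Optional

def parse_charset(content_type: Optional[str]) -> Optional[str]:
    # scan ";"-separated parameters of the Content-Type header
    if not content_type:
        return None
    for value in str(content_type).split(";"):
        if value.strip().startswith("charset="):
            return value.split("=")[1].strip()
    return None

assert parse_charset("text/html; charset=utf-8") == "utf-8"
assert parse_charset("application/json") is None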
diff --git a/utilmeta/core/response/backends/django.py b/utilmeta/core/response/backends/django.py
index b3f5dd7..bd07d5a 100644
--- a/utilmeta/core/response/backends/django.py
+++ b/utilmeta/core/response/backends/django.py
@@ -1,4 +1,9 @@
-from django.http.response import StreamingHttpResponse, HttpResponse, HttpResponseBase, FileResponse
+from django.http.response import (
+    StreamingHttpResponse,
+    HttpResponse,
+    HttpResponseBase,
+    FileResponse,
+)
 from typing import Union, TYPE_CHECKING
 from .base import ResponseAdaptor
 import django
@@ -17,11 +22,12 @@ def qualify(cls, obj):
         return isinstance(obj, HttpResponseBase)
 
     @classmethod
-    def reconstruct(cls, resp: Union['ResponseAdaptor', 'Response']):
+    def reconstruct(cls, resp: Union["ResponseAdaptor", "Response"]):
         if isinstance(resp, (HttpResponse, StreamingHttpResponse)):
             return resp
 
         from utilmeta.core.response import Response
+
         if isinstance(resp, ResponseAdaptor):
             resp = Response(response=resp)
         elif not isinstance(resp, Response):
diff --git a/utilmeta/core/response/backends/httpx.py b/utilmeta/core/response/backends/httpx.py
index d939de4..82ff467 100644
--- a/utilmeta/core/response/backends/httpx.py
+++ b/utilmeta/core/response/backends/httpx.py
@@ -35,12 +35,12 @@ async def async_load(self):
             return self.response.text
         elif self.json_type:
             return self.response.json()
-        self.__dict__['body'] = self.response.content
+        self.__dict__["body"] = self.response.content
         return self.get_content()
 
     @property
     def cookies(self):
-        set_cookie = self.headers.get_list('set-cookie')
+        set_cookie = self.headers.get_list("set-cookie")
         cookies = SimpleCookie()
         if set_cookie:
             for cookie in set_cookie:
diff --git a/utilmeta/core/response/backends/sanic.py b/utilmeta/core/response/backends/sanic.py
index d0e01cb..49e4d1f 100644
--- a/utilmeta/core/response/backends/sanic.py
+++ b/utilmeta/core/response/backends/sanic.py
@@ -23,7 +23,7 @@ def qualify(cls, obj):
         return isinstance(obj, HTTPResponse)
 
     @classmethod
-    def reconstruct(cls, resp: Union['ResponseAdaptor', 'Response']):
+    def reconstruct(cls, resp: Union["ResponseAdaptor", "Response"]):
         if isinstance(resp, HTTPResponse):
             return resp
@@ -38,7 +38,7 @@ def reconstruct(cls, resp: Union['ResponseAdaptor', 'Response']):
             resp.prepare_body(),
             status=resp.status,
             headers=Header(resp.prepare_headers()),
-            content_type=resp.content_type
+            content_type=resp.content_type,
         )
         return response
diff --git a/utilmeta/core/response/backends/starlette.py b/utilmeta/core/response/backends/starlette.py
index ebcd978..ab88974 100644
--- a/utilmeta/core/response/backends/starlette.py
+++ b/utilmeta/core/response/backends/starlette.py
@@ -16,7 +16,7 @@ def qualify(cls, obj):
         return isinstance(obj, HttpResponse)
 
     @classmethod
-    def reconstruct(cls, resp: Union['ResponseAdaptor', 'Response']):
+    def reconstruct(cls, resp: Union["ResponseAdaptor", "Response"]):
         if isinstance(resp, HttpResponse):
             return resp
@@ -27,10 +27,7 @@ def reconstruct(cls, resp: Union['ResponseAdaptor', 'Response']):
         elif not isinstance(resp, Response):
             resp = Response(resp)
 
-        kwargs = dict(
-            status_code=resp.status,
-            media_type=resp.content_type
-        )
+        kwargs = dict(status_code=resp.status, media_type=resp.content_type)
         # file will not be closed if using this
         # file = resp.file
         # if file:
@@ -60,13 +57,14 @@ def headers(self):
 
     @property
     def body(self):
         # StreamResponse does not have body attribute
-        return getattr(self.response, 'body', b'')
+        return getattr(self.response, "body", b"")
 
     @property
     def cookies(self):
         from http.cookies import SimpleCookie
+
         cookies = SimpleCookie()
-        for cookie in self.response.headers.getlist('set-cookie'):
+        for cookie in self.response.headers.getlist("set-cookie"):
             cookies.load(cookie)
         return cookies
diff --git a/utilmeta/core/response/backends/urllib.py b/utilmeta/core/response/backends/urllib.py
index 8bd757d..1e72e07 100644
--- a/utilmeta/core/response/backends/urllib.py
+++ b/utilmeta/core/response/backends/urllib.py
@@ -34,8 +34,9 @@ def body(self):
 
     @property
     def cookies(self):
         from http.cookies import SimpleCookie
+
         cookies = SimpleCookie()
-        for cookie in self.response.headers.get_all('Set-Cookie') or []:
+        for cookie in self.response.headers.get_all("Set-Cookie") or []:
             # use get_all, cause Set-Cookie can be multiple
             cookies.load(cookie)
         return cookies
diff --git a/utilmeta/core/response/backends/werkzeug.py b/utilmeta/core/response/backends/werkzeug.py
index f57afe1..96efd68 100644
--- a/utilmeta/core/response/backends/werkzeug.py
+++ b/utilmeta/core/response/backends/werkzeug.py
@@ -7,7 +7,7 @@ class WerkzeugResponseAdaptor(ResponseAdaptor):
     response: WerkzeugResponse
 
     @classmethod
-    def reconstruct(cls, resp: Union['ResponseAdaptor', 'WerkzeugResponse']):
+    def reconstruct(cls, resp: Union["ResponseAdaptor", "WerkzeugResponse"]):
         if isinstance(resp, WerkzeugResponse):
             return resp
@@ -36,7 +36,7 @@ def status(self):
 
     @property
     def reason(self):
-        return ''
+        return ""
 
     @property
     def headers(self):
diff --git a/utilmeta/core/response/base.py b/utilmeta/core/response/base.py
index 48c4baa..41580e5 100644
--- a/utilmeta/core/response/base.py
+++ b/utilmeta/core/response/base.py
@@ -7,9 +7,18 @@
 from utilmeta.core.request import Request
 from utype.types import *
 
-from utilmeta.utils import Header, \
-    get_generator_result, get_doc, is_hop_by_hop, http_time, file_like, \
-    STATUS_WITHOUT_BODY, time_now, multi, guess_mime_type
+from utilmeta.utils import (
+    Header,
+    get_generator_result,
+    get_doc,
+    is_hop_by_hop,
+    http_time,
+    file_like,
+    STATUS_WITHOUT_BODY,
+    time_now,
+    multi,
+    guess_mime_type,
+)
 from utilmeta.utils import exceptions as exc
 from utilmeta.utils import Headers
 from utilmeta.conf import Preference
@@ -22,21 +31,22 @@
 import re
 from ..file.base import File
 from ..file.backends.base import FileAdaptor
+
 # from utype.parser.rule import LogicalType
 
 
 class ResponseClassParser(ClassParser):
-    NAMES = ('result', 'headers')
+    NAMES = ("result", "headers")
 
     @classmethod
     def validate_field_name(cls, name: str):
         return name in cls.NAMES
 
 
-PLAIN = 'text/plain'
-JSON = 'application/json'
-XML = 'text/xml'
-OCTET_STREAM = 'application/octet-stream'
+PLAIN = "text/plain"
+JSON = "application/json"
+XML = "text/xml"
+OCTET_STREAM = "application/octet-stream"
 
 
 class Response:
@@ -67,7 +77,7 @@ class Response:
     reason: str = None
     charset: str = None
     content_type: Optional[str] = None
-    headers: Headers  # can be any inherited map, or assign to a HeadersSchema
+    headers: Headers  # can be any inherited map, or assign to a HeadersSchema
     cookies: SimpleCookie
     name: str = None
     description: str = None
@@ -106,7 +116,7 @@ class OperationAResponse(Response):
                 name = item
             else:
                 name = get_obj_name(item)
-            response_name = f'{cls.__name__}_{name}'
+            response_name = f"{cls.__name__}_{name}"
 
             class _response(cls):
                 if isinstance(item, int):
@@ -115,14 +125,18 @@ class _response(cls):
                     result: item
 
             _response.__name__ = response_name
-            _response.__qualname__ = '.'.join(cls.__qualname__.split('.')[:-1] + [response_name])
+            _response.__qualname__ = ".".join(
+                cls.__qualname__.split(".")[:-1] + [response_name]
+            )
             return _response
 
     def __init_subclass__(cls, **kwargs):
         cls.__parser__ = cls.__parser_cls__.apply_for(cls)
         cls.description = cls.description or get_doc(cls)
-        cls.wrapped = bool(cls.result_key or cls.count_key or cls.message_key or cls.state_key)
+        cls.wrapped = bool(
+            cls.result_key or cls.count_key or cls.message_key or cls.state_key
+        )
         if not cls.content_type and cls.wrapped:
             cls.content_type = JSON
@@ -130,39 +144,38 @@ def __init_subclass__(cls, **kwargs):
         keys = [cls.result_key, cls.message_key, cls.count_key, cls.state_key]
         wrap_keys = [k for k in keys if k is not None]
         if len(set(wrap_keys)) < len(wrap_keys):
-            raise ValueError(f'{cls.__name__}: conflict response keys: {wrap_keys}')
-
-    def __init__(self,
-                 result=None,
-                 *,
-                 state=None,
-                 message=None,  # can be str or error or dict/list of messages
-                 count: int = None,
-                 reason: str = None,
-                 status: int = None,
-                 extra: dict = None,
-
-                 content: Union[bytes, dict, list, str] = None,
-                 content_type: str = None,
-                 charset: str = None,
-                 headers=None,
-                 cookies=None,
-
-                 # store the original context
-                 request: Request = None,
-                 response=None,
-                 error: Union[Error, Exception] = None,
-                 file=None,
-                 attachment=None,
-                 # metadata
-                 mocked: bool = False,
-                 cached: bool = False,
-                 timeout: bool = False,
-                 aborted: bool = False,
-                 # when timeout set to True, raw_response is None
-                 stack: list = None,
-                 strict: bool = None
-                 ):
+            raise ValueError(f"{cls.__name__}: conflict response keys: {wrap_keys}")
+
+    def __init__(
+        self,
+        result=None,
+        *,
+        state=None,
+        message=None,  # can be str or error or dict/list of messages
+        count: int = None,
+        reason: str = None,
+        status: int = None,
+        extra: dict = None,
+        content: Union[bytes, dict, list, str] = None,
+        content_type: str = None,
+        charset: str = None,
+        headers=None,
+        cookies=None,
+        # store the original context
+        request: Request = None,
+        response=None,
+        error: Union[Error, Exception] = None,
+        file=None,
+        attachment=None,
+        # metadata
+        mocked: bool = False,
+        cached: bool = False,
+        timeout: bool = False,
+        aborted: bool = False,
+        # when timeout set to True, raw_response is None
+        stack: list = None,
+        strict: bool = None,
+    ):
 
         self.adaptor = None
@@ -285,16 +298,14 @@ def close(self, fail_silently=True):
         except Exception as e:
             if not fail_silently:
                 raise
-            warnings.warn(f'close response: {Self} failed with error: {e}')
+            warnings.warn(f"close response: {Self} failed with error: {e}")
 
     def parse_content(self):
         if self.result is not None:
             return
         if isinstance(self._content, dict) and self.wrapped:
             if self.result_key:
-                self.init_result(
-                    self._content.get(self.result_key, self._content)
-                )
+                self.init_result(self._content.get(self.result_key, self._content))
             if self.message_key:
                 self.message = self._content.get(self.message_key)
             if self.state_key:
@@ -311,7 +322,10 @@ def match(self):
             return False
         if self.__class__.state and self.__class__.state != self.state:
             return False
-        if self.__class__.content_type and self.__class__.content_type != self.content_type:
+        if (
+            self.__class__.content_type
+            and self.__class__.content_type != self.content_type
+        ):
             return False
         return True
@@ -321,26 +335,28 @@ def is_cls(cls, r):
 
     @classmethod
     def response_like(cls, resp):
-        status = getattr(resp, 'status', getattr(resp, 'status_code', None))
+        status = getattr(resp, "status", getattr(resp, "status_code", None))
         if status and isinstance(status, int):
             return True
         return False
 
     @property
     def schema_parser(self) -> Optional[ClassParser]:
-        return getattr(self, '__parser__', None)
+        return getattr(self, "__parser__", None)
 
     def init_headers(self, headers):
         if self.strict and self.schema_parser:
-            field = self.schema_parser.fields.get('headers')
+            field = self.schema_parser.fields.get("headers")
             if field:
                 # resolve before parse
                 self.schema_parser.resolve_forward_refs()
-                headers = field.parse_value(headers or {}, context=self.schema_parser.options.make_context())
+                headers = field.parse_value(
+                    headers or {}, context=self.schema_parser.options.make_context()
+                )
         self.headers = Headers(headers or {})
 
     def init_result(self, result):
-        if hasattr(result, '__next__'):
+        if hasattr(result, "__next__"):
             # convert generator yield result into list
             # result = list(result)
             result = get_generator_result(result)
@@ -350,11 +366,13 @@ def init_result(self, result):
             result = self.result
 
         if self.strict and self.schema_parser:
-            field = self.schema_parser.fields.get('result')
+            field = self.schema_parser.fields.get("result")
             if field:
                 # resolve before parse
                 self.schema_parser.resolve_forward_refs()
-                result = field.parse_value(result, context=self.schema_parser.options.make_context())
+                result = field.parse_value(
+                    result, context=self.schema_parser.options.make_context()
+                )
 
         if not self.adaptor and self.response_like(result):
             try:
@@ -381,10 +399,11 @@ def init_file(self, file):
             return
         from utilmeta.core.file.backends.base import FileAdaptor
         from pathlib import Path
+
         if isinstance(file, (str, Path)):
             self._filepath = str(file)
             self._filename = os.path.basename(str(file))
-            self._file = FileAdaptor.dispatch(open(self._filepath, 'rb'))
+            self._file = FileAdaptor.dispatch(open(self._filepath, "rb"))
             return
         if file_like(file):
             self._file = FileAdaptor.dispatch(file)
@@ -408,7 +427,7 @@ def init_error(self, error: Union[Error, Exception]):
                 self.result = error.result
             if self.headers is None:
                 self.headers = error.headers
-            if not self.message:    # empty string ''
+            if not self.message:  # empty string ''
                 self.message = str(error.exception)
             if not self.is_aborted:
                 error.log(console=True)
@@ -421,7 +440,7 @@ def build_data(self):
         if self.result_key:
             data[self.result_key] = self.result
         if self.message_key:
-            data[self.message_key] = self.message or ''
+            data[self.message_key] = self.message or ""
         if self.state_key:
             data[self.state_key] = self.state
         if self.count_key:
@@ -449,7 +468,7 @@ def _make_bytes(self, value):
             return value
         if isinstance(value, (memoryview, bytearray)):
             return bytes(value)
-        charset = self.charset or 'utf-8'
+        charset = self.charset or "utf-8"
         if isinstance(value, str):
             return value.encode(charset)
         # Handle non-string types.
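build_data and parse_content above implement the wrapped-body convention driven by result_key/message_key/state_key/count_key: only configured keys are emitted or unwrapped. A minimal model of the wrapping side (wrap_body and the key names are illustrative, not utilmeta defaults):

def wrap_body(result, message="", state=None, count=None, *,
              result_key="result", message_key=None, state_key=None, count_key=None):
    # only configured keys are emitted, mirroring Response.build_data above
    data = {result_key: result}
    if message_key:
        data[message_key] = message or ""
    if state_key:
        data[state_key] = state
    if count_key:
        data[count_key] = count
    return data

print(wrap_body([1, 2], message_key="msg"))
# {'result': [1, 2], 'msg': ''}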
@@ -470,40 +489,45 @@ def build_content(self):
             if self._filename:
                 # if there is file path and no content-disposition is set
                 # we set it
-                content_disposition = self.headers.get('content-disposition')
+                content_disposition = self.headers.get("content-disposition")
                 if not content_disposition:
                     # set
                     from urllib.parse import quote
                     from pathlib import Path
-                    disp = 'attachment' if self._as_attachment else 'inline'
-                    self.set_header('content-disposition',
-                                    f'{disp}; filename="{quote(self._filename)}"')
+
+                    disp = "attachment" if self._as_attachment else "inline"
+                    self.set_header(
+                        "content-disposition",
+                        f'{disp}; filename="{quote(self._filename)}"',
+                    )
         else:
             data = self.build_data()
-            if hasattr(data, '__iter__'):
+            if hasattr(data, "__iter__"):
                 if multi(data) and not isinstance(data, list):
                     data = list(data)
-            elif not isinstance(data, (bytes, memoryview, str, list, dict, set, tuple)):
+            elif not isinstance(
+                data, (bytes, memoryview, str, list, dict, set, tuple)
+            ):
                 # must convert to list iterable
                 # this data is guarantee that not file_like
                 data = b"".join(self._make_bytes(chunk) for chunk in data)
                 if hasattr(data, "close"):
                     try:
                         data.close()
-                    except Exception:   # noqa
+                    except Exception:  # noqa
                         pass
             # self._data = data
-            if data is None or data == '':
-                data = b''
+            if data is None or data == "":
+                data = b""
             self._content = data
         self.build_content_type()
 
     def build_content_type(self):
-        content_type = self.headers.get('content-type')
+        content_type = self.headers.get("content-type")
         if content_type:
             self.content_type = content_type
             return
-        if hasattr(self._content, 'content_type'):
+        if hasattr(self._content, "content_type"):
             # like File
             self.content_type = self._content.content_type
             return
@@ -532,14 +556,16 @@ def build_content_type(self):
     def filename(self):
         if self._filename:
             return self._filename
-        content_disposition = self.headers.get('content-disposition')
+        content_disposition = self.headers.get("content-disposition")
         if not content_disposition:
             return
         from urllib.parse import unquote
+
         for part in unquote(content_disposition).split('filename="')[1:]:
             return part.strip('"')
         if self._filepath:
             from pathlib import Path
+
             return Path(self._filepath).name
         return None
@@ -576,24 +602,28 @@ def data(self):
     #     return body
 
     def __str__(self):
-        reason = f' {self.reason}' if self.reason else ''
-        return f'{self.__class__.__name__} [{self.status}{reason}] ' \
-               f'"{self.request.method.upper()} /%s"' % self.request.encoded_path.strip('/') \
-            if self.request else f'{self.__class__.__name__} [{self.status}{reason}]'
+        reason = f" {self.reason}" if self.reason else ""
+        return (
+            f"{self.__class__.__name__} [{self.status}{reason}] "
+            f'"{self.request.method.upper()} /%s"'
+            % self.request.encoded_path.strip("/")
+            if self.request
+            else f"{self.__class__.__name__} [{self.status}{reason}]"
+        )
 
     def __repr__(self):
         return self.__str__()
 
     def _print(self, print_f):
         print(str(self))
-        content_type = self.content_type or self.headers.get('content-type')
+        content_type = self.content_type or self.headers.get("content-type")
         if content_type:
             data = self.data
             content_length = self.content_length or len(str(data))
-            print(f'{content_type} ({content_length or 0})')
+            print(f"{content_type} ({content_length or 0})")
             if data:
                 print_f(data)
-        print('')
+        print("")
 
     def print(self):
         self._print(print)
@@ -604,6 +634,7 @@ def pprint(self):
     @classmethod
     def dump_json(cls, content, encoder=None, ensure_ascii: bool = False, **kwargs):
         import json
+
         kwargs.update(ensure_ascii=ensure_ascii)
         return json.dumps(content, cls=encoder or cls.__json_encoder_cls__, **kwargs)
@@ -660,16 +691,16 @@ def raw_response(self, resp):
         self.adaptor = ResponseAdaptor.dispatch(resp)
 
     @property
-    def original_response(self) -> Optional['Response']:
+    def original_response(self) -> Optional["Response"]:
         # from 3xx redirect response, original_response is that 3xx response
         # including cached 304 responses
         if self._stack:
             return self._stack[0]
         return None
 
-    def push_response_stack(self, resp: 'Response'):
+    def push_response_stack(self, resp: "Response"):
         if not isinstance(resp, Response):
-            raise TypeError(f'Invalid response: {resp}')
+            raise TypeError(f"Invalid response: {resp}")
         self._stack.append(resp)
 
     @property
@@ -763,7 +794,7 @@ def text(self) -> str:
             return self.adaptor.get_text()
         if isinstance(self._content, str):
             return self._content
-        return ''
+        return ""
 
     def set_header(self, name: str, value):
         self.headers[name] = value
@@ -781,7 +812,7 @@ def set_cookie(
         domain: str = None,
         secure: bool = False,
         httponly: bool = False,
-        samesite: str = None
+        samesite: str = None,
     ):
         self.cookies[key] = value
         if expires is not None:
@@ -790,28 +821,30 @@ def set_cookie(
                 expires = http_time(expires)
             elif isinstance(expires, (int, float)):
                 expires = http_time(datetime.utcfromtimestamp(expires), to_utc=False)
-            self.cookies[key]['expires'] = expires
+            self.cookies[key]["expires"] = expires
         else:
-            self.cookies[key]['expires'] = ''
+            self.cookies[key]["expires"] = ""
         if max_age is not None:
-            self.cookies[key]['max-age'] = int(max_age)
+            self.cookies[key]["max-age"] = int(max_age)
             if not expires:
                 # IE requires expires, so set it if hasn't been already.
-                self.cookies[key]['expires'] = http_time(datetime.now() + timedelta(seconds=max_age))
+                self.cookies[key]["expires"] = http_time(
+                    datetime.now() + timedelta(seconds=max_age)
+                )
         if path is not None:
-            self.cookies[key]['path'] = path
+            self.cookies[key]["path"] = path
         if domain is not None:
-            self.cookies[key]['domain'] = domain
+            self.cookies[key]["domain"] = domain
         if secure:
-            self.cookies[key]['secure'] = True
+            self.cookies[key]["secure"] = True
         if httponly:
-            self.cookies[key]['httponly'] = True
+            self.cookies[key]["httponly"] = True
         if samesite:
-            if samesite.lower() not in ('lax', 'none', 'strict'):
+            if samesite.lower() not in ("lax", "none", "strict"):
                 raise ValueError('samesite must be "lax", "none", or "strict".')
-            self.cookies[key]['samesite'] = samesite.lower()
+            self.cookies[key]["samesite"] = samesite.lower()
 
     def delete_cookie(self, key: str, path: str = "/", domain: str = None) -> None:
         self.set_cookie(key, expires=0, max_age=0, path=path, domain=domain)
@@ -821,16 +854,16 @@ def prepare_headers(self, with_content_type: bool = False) -> List[Tuple[str, st
         for key, val in self.headers.items():
             if self.adaptor and is_hop_by_hop(key):
                 continue
-            if str(key).lower() == 'content-type':
+            if str(key).lower() == "content-type":
                 with_content_type = False
             header_values.append((str(key), str(val)))
-        if with_content_type and self.content_type and self._content:   # non empty
+        if with_content_type and self.content_type and self._content:  # non empty
             content_type = self.content_type
             if content_type and self.charset:
-                content_type = f'{content_type}; charset={self.charset}'
-            header_values.append(('Content-Type', content_type))
+                content_type = f"{content_type}; charset={self.charset}"
+            header_values.append(("Content-Type", content_type))
         for cookie in self.cookies.values():
-            header_values.append(('Set-Cookie', cookie.OutputString()))
+            header_values.append(("Set-Cookie", cookie.OutputString()))
         return header_values
 
     def prepare_body(self):
header_values.append(("Set-Cookie", cookie.OutputString())) return header_values def prepare_body(self): @@ -852,7 +885,7 @@ def prepare_body(self): body = self._content if not body: - return b'' + return b"" if self.is_json and not isinstance(body, (str, bytes)): try: return self.dump_json(body) @@ -874,7 +907,7 @@ def body(self) -> bytes: return body if not isinstance(body, str): body = str(body) - return body.encode(self.charset or 'utf-8', errors='replace') + return body.encode(self.charset or "utf-8", errors="replace") @property def error(self) -> Optional[Error]: @@ -889,15 +922,17 @@ def get_error(self): return self._error if self.success: return None - e = exc.HttpError.STATUS_EXCEPTIONS.get(self.status, exc.ServerError)(self.message) + e = exc.HttpError.STATUS_EXCEPTIONS.get(self.status, exc.ServerError)( + self.message + ) return Error(e, request=self.request) @property def traffic(self): if self._traffic: return self._traffic - value = 12 # HTTP/1.1 200 OK - value += len(str(self.status)) + len(str(self.reason or 'ok')) + value = 12 # HTTP/1.1 200 OK + value += len(str(self.status)) + len(str(self.reason or "ok")) value += self.content_length or 0 for key, val in self.headers.items(): value += len(str(key)) + len(str(val)) + 4 @@ -917,8 +952,10 @@ def mock(cls): try: from utype.utils.example import get_example_from_parser except ImportError: - raise NotImplementedError(f'Response.mock() not implemented, please upgrade utype') - parser = getattr(cls, '__parser__', None) + raise NotImplementedError( + f"Response.mock() not implemented, please upgrade utype" + ) + parser = getattr(cls, "__parser__", None) kwargs = {} if parser: kwargs = get_example_from_parser(parser) @@ -938,11 +975,11 @@ def validate(self, *_, **__): return self.success @classmethod - def server_error(cls, message=''): + def server_error(cls, message=""): return cls(message=message, status=500) @classmethod - def permission_denied(cls, message=''): + def permission_denied(cls, message=""): return cls(message=message, status=500) @classmethod @@ -962,11 +999,11 @@ def accepted(cls): return cls(status=202) @classmethod - def bad_request(cls, message=''): + def bad_request(cls, message=""): return cls(message=message, status=400) @classmethod - def not_found(cls, message=''): + def not_found(cls, message=""): return cls(message=message, status=404) @classmethod diff --git a/utilmeta/core/server/backends/apiflask.py b/utilmeta/core/server/backends/apiflask.py index 2757853..774c407 100644 --- a/utilmeta/core/server/backends/apiflask.py +++ b/utilmeta/core/server/backends/apiflask.py @@ -10,15 +10,23 @@ class APIFlaskServerAdaptor(FlaskServerAdaptor): application_cls = APIFlask app: APIFlask - def generate(self, spec: str = 'openapi'): - if spec == 'openapi': + def generate(self, spec: str = "openapi"): + if spec == "openapi": app: APIFlask = self.application() - return app._get_spec('json', force_update=True) + return app._get_spec("json", force_update=True) - def add_api(self, app: APIFlask, utilmeta_api_class, route: str = '', asynchronous: bool = False): - f = super().add_api(app, utilmeta_api_class, route=route, asynchronous=asynchronous) + def add_api( + self, + app: APIFlask, + utilmeta_api_class, + route: str = "", + asynchronous: bool = False, + ): + f = super().add_api( + app, utilmeta_api_class, route=route, asynchronous=asynchronous + ) # spec = getattr(f, '_spec', None) - f._spec = {'hide': True} + f._spec = {"hide": True} return f @property diff --git a/utilmeta/core/server/backends/base.py 
diff --git a/utilmeta/core/server/backends/base.py b/utilmeta/core/server/backends/base.py
index e2cc4f3..c570494 100644
--- a/utilmeta/core/server/backends/base.py
+++ b/utilmeta/core/server/backends/base.py
@@ -1,4 +1,5 @@
 from typing import TYPE_CHECKING, List
+
 if TYPE_CHECKING:
     from utilmeta import UtilMeta
     from utilmeta.core.api import API
@@ -14,6 +15,7 @@ def __init__(self, *args, **kwargs):
         self.args = args
         self.kwargs = kwargs
         from utilmeta import service
+
         self.service = service
 
     def process_request(self, request: Request):
@@ -27,24 +29,27 @@ class ServerAdaptor(BaseAdaptor):
     # __backends_route__ = 'backends'
 
     @classmethod
-    def reconstruct(cls, adaptor: 'BaseAdaptor'):
+    def reconstruct(cls, adaptor: "BaseAdaptor"):
         pass
 
-    def adapt(self, api: 'API', route: str, asynchronous: bool = None):
+    def adapt(self, api: "API", route: str, asynchronous: bool = None):
         raise NotImplementedError
 
     @classmethod
-    def get_module_name(cls, obj: 'UtilMeta'):
+    def get_module_name(cls, obj: "UtilMeta"):
         if inspect.ismodule(obj):
             # maybe the backend
             return obj.__name__
         return super().get_module_name(obj.backend)
 
     @classmethod
-    def qualify(cls, obj: 'UtilMeta'):
+    def qualify(cls, obj: "UtilMeta"):
         if not cls.backend or not obj.backend:
             return False
-        return cls.get_module_name(obj.backend).lower() == cls.get_module_name(cls.backend).lower()
+        return (
+            cls.get_module_name(obj.backend).lower()
+            == cls.get_module_name(cls.backend).lower()
+        )
 
     backend = None
     default_asynchronous = False
@@ -57,7 +62,7 @@ def qualify(cls, obj: 'UtilMeta'):
     async_cache_adaptor_cls = None
     DEFAULT_PORT = 8000
 
-    def __init__(self, config: 'UtilMeta'):
+    def __init__(self, config: "UtilMeta"):
         self.root = None
         self.config = config
         self.background = config.background
@@ -74,29 +79,29 @@ def __init__(self, config: 'UtilMeta'):
     def root_pattern(self):
         if not self.config.root_url:
             return None
-        return re.compile('%s/(.*)' % self.config.root_url.strip('/'))
+        return re.compile("%s/(.*)" % self.config.root_url.strip("/"))
 
     @property
     def root_path(self) -> str:
-        return ''
+        return ""
 
     @property
     def version(self) -> str:
-        return ''
+        return ""
 
     @property
     def production(self) -> bool:
         return False
 
     def load_route(self, path: str):
-        path = (path or '').strip('/')
+        path = (path or "").strip("/")
         if not self.config.root_url:
             return path
         match = self.root_pattern.match(path)
         if match:
             return match.groups()[0]
         if path == self.config.root_url:
-            return ''
+            return ""
         if self.config.preference.strict_root_route:
             raise exceptions.NotFound
         return path
@@ -113,6 +118,7 @@ def worker_post_fork(self):
     def apply_fork(self):
         try:
             import uwsgidecorators  # noqa
+
             uwsgidecorators.postfork(self.worker_post_fork)
         except ModuleNotFoundError:
             pass
@@ -133,7 +139,7 @@ def mount(self, app, route: str):
     def application(self):
         pass
 
-    def generate(self, spec: str = 'openapi'):
+    def generate(self, spec: str = "openapi"):
         raise NotImplementedError
 
     @property
@@ -147,7 +153,7 @@ def async_startup(self) -> bool:
     @classmethod
     def is_asgi(cls, app):
         if not inspect.isfunction(app):
-            app = getattr(app, '__call__', None)
+            app = getattr(app, "__call__", None)
             if not app:
                 return False
         return inspect.iscoroutinefunction(app)
@@ -163,4 +169,4 @@ def add_middleware(self, middleware):
         if isinstance(middleware, ServiceMiddleware):
             self.middlewares.append(middleware)
         else:
-            raise NotImplementedError(f'middleware of {middleware} no implemented')
+            raise NotImplementedError(f"middleware of {middleware} not implemented")
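ServerAdaptor.load_route above strips surrounding slashes and the configured root_url prefix from an incoming path. A standalone sketch (load_route here is a free function for illustration; the strict_root_route branch that raises NotFound is reduced to a passthrough):

import re

def load_route(path: str, root_url: str = None) -> str:
    # strip surrounding slashes, then the configured root_url prefix
    path = (path or "").strip("/")
    if not root_url:
        return path
    match = re.compile("%s/(.*)" % root_url.strip("/")).match(path)
    if match:
        return match.groups()[0]
    if path == root_url:
        return ""
    return path  # the real method may raise NotFound here under strict_root_route

assert load_route("/api/v1/user/", "api/v1") == "user"
assert load_route("/api/v1/", "api/v1") == ""
assert load_route("/user/") == "user"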
diff --git a/utilmeta/core/server/backends/django/adaptor.py b/utilmeta/core/server/backends/django/adaptor.py
index 3932b9c..30dea67 100644
--- a/utilmeta/core/server/backends/django/adaptor.py
+++ b/utilmeta/core/server/backends/django/adaptor.py
@@ -27,13 +27,14 @@
 from django.core.handlers.asgi import ASGIHandler
 from django.core.handlers.wsgi import WSGIHandler
 
-_current_request = contextvars.ContextVar('_django.request')
-_current_response = contextvars.ContextVar('_django.response')
+_current_request = contextvars.ContextVar("_django.request")
+_current_response = contextvars.ContextVar("_django.response")
 
 try:
     from django.utils.decorators import sync_and_async_middleware
 except ImportError:
+
     def sync_and_async_middleware(func):
         """
         Mark a middleware factory as returning a hybrid middleware supporting both
@@ -50,9 +51,9 @@ def process_response(self, request, response: HttpResponseBase):
         if origin and localhost(origin):
             for key in response.cookies.keys():
                 cookie = response.cookies[key]
-                pop(cookie, 'domain')
+                pop(cookie, "domain")
                 if not localhost(request.get_host()):
-                    cookie['samesite'] = 'None'
+                    cookie["samesite"] = "None"
         return response
 
 
@@ -63,9 +64,9 @@ class DjangoServerAdaptor(ServerAdaptor):
     sync_db_adaptor_cls = DjangoDatabaseAdaptor
     sync_cache_adaptor_cls = DjangoCacheAdaptor
     default_asynchronous = False
-    URLPATTERNS = 'urlpatterns'
+    URLPATTERNS = "urlpatterns"
     DEFAULT_PORT = 8000
-    DEFAULT_HOST = '127.0.0.1'
+    DEFAULT_HOST = "127.0.0.1"
     settings_cls = DjangoSettings
     ASYNC_SUPPORTED = django.VERSION >= (3, 1)
 
@@ -73,7 +74,11 @@ def __init__(self, config: UtilMeta):
         super().__init__(config)
         self._ready = False
         self.settings = config.get_config(self.settings_cls) or self.settings_cls()
-        self.app = config._application if isinstance(config._application, BaseHandler) else None
+        self.app = (
+            config._application
+            if isinstance(config._application, BaseHandler)
+            else None
+        )
         if self.app:
             if isinstance(self.app, ASGIHandler):
                 self.asynchronous = self.config.asynchronous = True
@@ -83,7 +88,7 @@ def __init__(self, config: UtilMeta):
         self._mounts = {}
 
     def load_route(self, path: str):
-        return (path or '').strip('/')
+        return (path or "").strip("/")
 
     def setup(self):
         if self._ready:
@@ -93,7 +98,7 @@ def setup(self):
         root_api = self.config.resolve()
 
         if self.config.root_url:
-            url_pattern = rf'^{self.config.root_url}(\/.*)?$'
+            url_pattern = rf"^{self.config.root_url}(\/.*)?$"
         else:
             url_pattern = root_api._get_route_pattern()
 
@@ -101,7 +106,7 @@ def setup(self):
             root_api,
             route=url_pattern,
             asynchronous=self.asynchronous,
-            top=bool(url_pattern)
+            top=bool(url_pattern),
         )
         self.setup_middlewares()
         self.apply_fork()
@@ -120,12 +125,12 @@ def setup_middlewares(self):
             if not hasattr(self.settings.module, func.__name__):
                 setattr(self.settings.module, func.__name__, func)
             self.settings.merge_list_settings(
-                'MIDDLEWARE', [f'{self.settings.module_name}.{func.__name__}']
+                "MIDDLEWARE", [f"{self.settings.module_name}.{func.__name__}"]
             )
             if self.app:
                 self.app.load_middleware(is_async=self.asynchronous)
             else:
-                raise ValueError(f'setup django middleware failed: settings not loaded')
+                raise ValueError(f"setup django middleware failed: settings not loaded")
 
     @property
     def backend_views_empty(self) -> bool:
@@ -134,7 +139,7 @@ def backend_views_empty(self) -> bool:
         urls = getattr(self.settings.url_conf, self.URLPATTERNS, [])
         for url in urls:
             if isinstance(url, URLPattern):
-                wrapped = getattr(url.callback, '__wrapped__', None)
+                wrapped = getattr(url.callback, "__wrapped__", None)
                 if wrapped and isinstance(wrapped, type) and issubclass(wrapped, API):
                     pass
                 else:
@@ -184,14 +189,11 @@ def process_response(self, django_response):
             if not isinstance(response, Response):
                 response = Response(
-                    response=response_adaptor_cls(django_response),
-                    request=request
+                    response=response_adaptor_cls(django_response), request=request
                 )
             else:
                 if not response.adaptor:
-                    response.adaptor = response_adaptor_cls(
-                        django_response
-                    )
+                    response.adaptor = response_adaptor_cls(django_response)
 
             response_updated = False
             for middleware in middlewares:
@@ -220,13 +222,20 @@ def __call__(self, request):
         @sync_and_async_middleware
         def utilmeta_middleware(get_response):
             # One-time configuration and initialization goes here.
-            if self.asynchronous and inspect.iscoroutinefunction(get_response) and self.ASYNC_SUPPORTED:
+            if (
+                self.asynchronous
+                and inspect.iscoroutinefunction(get_response)
+                and self.ASYNC_SUPPORTED
+            ):
+
                 async def middleware_func(request):
                     middleware = UtilMetaMiddleware(get_response)
                     # Do something here!
                     response = await middleware.__acall__(request)
                     return response
+
             else:
+
                 def middleware_func(request):
                     middleware = UtilMetaMiddleware(get_response)
                     # Do something here!
@@ -241,41 +250,49 @@ def check_application(self):
         wsgi_app = self.settings.wsgi_app
         if not wsgi_app:
             if self.config.production:
-                raise ValueError(f'Django wsgi application not specified, you should use '
-                                 f'{self.settings.wsgi_app_attr or "app"} = service.application() '
-                                 f'in {self.settings.wsgi_module_ref or "your service file"}')
+                raise ValueError(
+                    f"Django wsgi application not specified, you should use "
+                    f'{self.settings.wsgi_app_attr or "app"} = service.application() '
+                    f'in {self.settings.wsgi_module_ref or "your service file"}'
+                )
             else:
                 wsgi_module = self.settings.wsgi_module
                 if not wsgi_module:
                     if self.settings.wsgi_application:
-                        raise ValueError(f'Invalid Django WSGI_APPLICATION: '
-                                         f'{repr(self.settings.wsgi_application)}')
-                    raise ValueError(f'Django WSGI_APPLICATION not specified')
+                        raise ValueError(
+                            f"Invalid Django WSGI_APPLICATION: "
+                            f"{repr(self.settings.wsgi_application)}"
+                        )
+                    raise ValueError(f"Django WSGI_APPLICATION not specified")
                 if not self.settings.wsgi_app_attr:
-                    raise ValueError(f'Django WSGI_APPLICATION not specified or invalid:'
-                                     f' {repr(self.settings.wsgi_application)}')
-                warnings.warn('Django application not specified, auto-assigning, you should use '
-                              f'{self.settings.wsgi_app_attr or "app"} = service.application() '
-                              f'in {self.settings.wsgi_module_ref or "your service file"} at production')
+                    raise ValueError(
+                        f"Django WSGI_APPLICATION not specified or invalid:"
+                        f" {repr(self.settings.wsgi_application)}"
+                    )
+                warnings.warn(
+                    "Django application not specified, auto-assigning, you should use "
+                    f'{self.settings.wsgi_app_attr or "app"} = service.application() '
+                    f'in {self.settings.wsgi_module_ref or "your service file"} at production'
+                )
                 setattr(wsgi_module, self.settings.wsgi_app_attr, self.application())
 
     def mount(self, app, route: str):
-        urls_attr = getattr(app, 'urls', None)
+        urls_attr = getattr(app, "urls", None)
         if not urls_attr or not isinstance(urls_attr, (list, tuple)):
-            raise TypeError('Invalid application to mount to django, anyone with "urls" attribute is supported, '
-                            'such as NinjaAPI in django-ninja or DefaultRouter in django-rest-framework')
+            raise TypeError(
+                'Invalid application to mount to django, anyone with "urls" attribute is supported, '
+                "such as NinjaAPI in django-ninja or DefaultRouter in django-rest-framework"
+            )
urls_attr): - urls_attr = include((urls_attr, route.strip('/'))) + urls_attr = include((urls_attr, route.strip("/"))) # to mount django-ninja app or django-rest-framework router urls = getattr(self.settings.url_conf, self.URLPATTERNS, []) - urls.append( - path(route.strip('/') + '/', urls_attr) - ) + urls.append(path(route.strip("/") + "/", urls_attr)) setattr(self.settings.url_conf, self.URLPATTERNS, urls) self._mounts[route] = app - def adapt(self, api: 'API', route: str, asynchronous: bool = None): + def adapt(self, api: "API", route: str, asynchronous: bool = None): if asynchronous is None: asynchronous = self.default_asynchronous # func = self._get_api(api, asynchronous=asynchronous) @@ -283,16 +300,25 @@ def adapt(self, api: 'API', route: str, asynchronous: bool = None): # return re_path(path, func) self.add_api(api, route=route, asynchronous=asynchronous) - def add_api(self, utilmeta_api_class, route: str = '', asynchronous: bool = False, top: bool = False): + def add_api( + self, + utilmeta_api_class, + route: str = "", + asynchronous: bool = False, + top: bool = False, + ): api = self._get_api(utilmeta_api_class, asynchronous=asynchronous) urls = getattr(self.settings.url_conf, self.URLPATTERNS, []) find = False for url in urls: if isinstance(url, URLPattern): if str(url.pattern) == str(route): - wrapped = getattr(url.callback, '__wrapped__', None) + wrapped = getattr(url.callback, "__wrapped__", None) if wrapped: - if wrapped == utilmeta_api_class or wrapped.__qualname__ == utilmeta_api_class.__qualname__: + if ( + wrapped == utilmeta_api_class + or wrapped.__qualname__ == utilmeta_api_class.__qualname__ + ): find = True break if find: @@ -310,18 +336,22 @@ def _get_api(self, utilmeta_api_class, asynchronous: bool = False): make sure it is called after all your fastapi route is set """ from utilmeta.core.api.base import API + if not issubclass(utilmeta_api_class, API): - raise TypeError(f'Invalid api class: {utilmeta_api_class}') + raise TypeError(f"Invalid api class: {utilmeta_api_class}") if asynchronous and self.ASYNC_SUPPORTED: - async def f(request, route: str = '', *args, **kwargs): + + async def f(request, route: str = "", *args, **kwargs): req = None try: req = _current_request.get(None) route = self.load_route(route) if not isinstance(req, Request): - req = Request(self.request_adaptor_cls(request, route, *args, **kwargs)) + req = Request( + self.request_adaptor_cls(request, route, *args, **kwargs) + ) else: req.adaptor.route = route req.adaptor.request = request @@ -329,18 +359,24 @@ async def f(request, route: str = '', *args, **kwargs): root = utilmeta_api_class(req) resp = await root() except Exception as e: - resp = getattr(utilmeta_api_class, 'response', Response)(error=e, request=req) + resp = getattr(utilmeta_api_class, "response", Response)( + error=e, request=req + ) _current_response.set(resp) return self.response_adaptor_cls.reconstruct(resp) + else: - def f(request, route: str = '', *args, **kwargs): + + def f(request, route: str = "", *args, **kwargs): req = None try: req = _current_request.get(None) route = self.load_route(route) if not isinstance(req, Request): - req = Request(self.request_adaptor_cls(request, route, *args, **kwargs)) + req = Request( + self.request_adaptor_cls(request, route, *args, **kwargs) + ) else: req.adaptor.route = route req.adaptor.request = request @@ -348,7 +384,9 @@ def f(request, route: str = '', *args, **kwargs): root = utilmeta_api_class(req) resp = root() except Exception as e: - resp = getattr(utilmeta_api_class, 
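# Both handler variants built by _get_api() here convert any raised exception into the
# API class's declared response type. That fallback, extracted as a sketch (Response
# being the class imported at the top of this module):
def error_response(api_cls, exc, request):
    # prefer a custom `response` declared on the API class, else the base Response
    response_cls = getattr(api_cls, "response", Response)
    return response_cls(error=exc, request=request)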
'response', Response)(error=e, request=req) + resp = getattr(utilmeta_api_class, "response", Response)( + error=e, request=req + ) _current_response.set(resp) return self.response_adaptor_cls.reconstruct(resp) @@ -379,11 +417,11 @@ def run(self): self.config.startup() if self.asynchronous: try: - from daphne.server import Server # noqa + from daphne.server import Server # noqa except ModuleNotFoundError: pass else: - print('using [daphne] as asgi server') + print("using [daphne] as asgi server") try: Server( application=self.application(), @@ -395,9 +433,11 @@ def run(self): return from utilmeta.utils import requires - requires('uvicorn') + + requires("uvicorn") import uvicorn - print('using [uvicorn] as asgi server') + + print("using [uvicorn] as asgi server") try: uvicorn.run( self.application(), @@ -409,11 +449,19 @@ def run(self): return if self.config.production: - server = 'asgi (like uvicorn/daphne)' if self.asynchronous else 'wsgi (like uwsgi/gunicorn)' - raise ValueError(f'django in production cannot use service.run(), please use an {server} server') + server = ( + "asgi (like uvicorn/daphne)" + if self.asynchronous + else "wsgi (like uwsgi/gunicorn)" + ) + raise ValueError( + f"django in production cannot use service.run(), please use an {server} server" + ) else: if self.asynchronous: - raise ValueError(f'django debug runserver does not support asgi, please use an asgi server') + raise ValueError( + f"django debug runserver does not support asgi, please use an asgi server" + ) try: self.runserver() finally: @@ -422,33 +470,39 @@ def run(self): @property def location(self): - return f'{self.config.host or self.DEFAULT_HOST}:{self.config.port}' + return f"{self.config.host or self.DEFAULT_HOST}:{self.config.port}" @property def production(self) -> bool: - return getattr(self.settings.django_settings, 'DEBUG', None) is False + return getattr(self.settings.django_settings, "DEBUG", None) is False @property def daphne_endpoint(self): - return f"tcp:{self.config.port}:interface={self.config.host or self.DEFAULT_HOST}" + return ( + f"tcp:{self.config.port}:interface={self.config.host or self.DEFAULT_HOST}" + ) def runserver(self): # debug server - argv = [sys.argv[0], 'runserver', self.location] # if len(sys.argv) == 1 else sys.argv + argv = [ + sys.argv[0], + "runserver", + self.location, + ] # if len(sys.argv) == 1 else sys.argv # if 'runserver' in argv: if not self.config.auto_reload: - argv.append('--noreload') + argv.append("--noreload") execute_from_command_line(argv) @classmethod - def get_drf_openapi( - cls, - title=None, url=None, description=None, version=None - ): + def get_drf_openapi(cls, title=None, url=None, description=None, version=None): from rest_framework.schemas.openapi import SchemaGenerator - generator = SchemaGenerator(title=title, url=url, description=description, version=version) - def generator_func(service: 'UtilMeta'): + generator = SchemaGenerator( + title=title, url=url, description=description, version=version + ) + + def generator_func(service: "UtilMeta"): return generator.get_schema(public=True) return generator_func @@ -458,24 +512,27 @@ def get_django_ninja_openapi(cls): from ninja.openapi.schema import get_schema from ninja import NinjaAPI - def generator_func(service: 'UtilMeta'): + def generator_func(service: "UtilMeta"): app = service.application() if isinstance(app, NinjaAPI): return get_schema(app) - raise TypeError(f'Invalid application: {app} for django ninja. 
NinjaAPI() instance expected') + raise TypeError( + f"Invalid application: {app} for django ninja. NinjaAPI() instance expected" + ) return generator_func - def generate(self, spec: str = 'openapi'): - if spec == 'openapi': + def generate(self, spec: str = "openapi"): + if spec == "openapi": if self.settings.django_settings: - if 'rest_framework' in self.settings.django_settings.INSTALLED_APPS: + if "rest_framework" in self.settings.django_settings.INSTALLED_APPS: # 1. try drf from rest_framework.schemas.openapi import SchemaGenerator + generator = SchemaGenerator( title=self.config.title, description=self.config.description, - version=self.config.version + version=self.config.version, ) return generator.get_schema(public=True) diff --git a/utilmeta/core/server/backends/django/cmd.py b/utilmeta/core/server/backends/django/cmd.py index 8bf1e37..f95750a 100644 --- a/utilmeta/core/server/backends/django/cmd.py +++ b/utilmeta/core/server/backends/django/cmd.py @@ -10,17 +10,17 @@ from .settings import DjangoSettings -initial_file = '0001_initial' +initial_file = "0001_initial" class DjangoCommand(BaseServiceCommand): - name = 'django' + name = "django" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.settings = self.service.get_config(DjangoSettings) # self.settings.setup(self.service) - self.service.setup() # setup here + self.service.setup() # setup here @command def add(self, name: str): @@ -30,48 +30,58 @@ def add(self, name: str): """ path = self.settings.apps_path or self.cwd app = os.path.join(path, name) - migrations = os.path.join(app, 'migrations') + migrations = os.path.join(app, "migrations") init = os.path.join(app, INIT_FILE) m_init = os.path.join(migrations, INIT_FILE) - api = os.path.join(app, 'api.py') - models = os.path.join(app, 'models.py') - schema = os.path.join(app, 'schema.py') + api = os.path.join(app, "api.py") + models = os.path.join(app, "models.py") + schema = os.path.join(app, "schema.py") if os.path.exists(app): - print(RED % f'meta Error: target application directory: {app} is already exists') + print( + RED + % f"meta Error: target application directory: {app} is already exists" + ) exit(1) os.makedirs(migrations) - write_to(m_init, content='') - write_to(init, content='') - write_to(api, content='from utilmeta.core import api, orm, request, response\n') - write_to(models, content='from django.db import models\n') - write_to(schema, content='import utype\nfrom utilmeta.core import orm') + write_to(m_init, content="") + write_to(init, content="") + write_to(api, content="from utilmeta.core import api, orm, request, response\n") + write_to(models, content="from django.db import models\n") + write_to(schema, content="import utype\nfrom utilmeta.core import orm") - print(f"meta: django application: <{BLUE % name}> successfully added to path: {app}") + print( + f"meta: django application: <{BLUE % name}> successfully added to path: {app}" + ) @command - def makemigrations(self, app_label: str = None, all: bool = Arg('--all', '-a', default=False)): + def makemigrations( + self, app_label: str = None, all: bool = Arg("--all", "-a", default=False) + ): """ execute django makemigrations command to generate migrations files """ - if app_label == '*': + if app_label == "*": all = True - args = ['meta', 'makemigrations'] + args = ["meta", "makemigrations"] execute_from_command_line(args) if all: for app in self.settings.app_labels(): db = self.settings.get_db(app) for d in db.dbs: # include master and replicas - if d.alias == 'default': + if 
d.alias == "default": continue - execute_from_command_line([*args, app, f'--database={d.alias}']) + execute_from_command_line([*args, app, f"--database={d.alias}"]) return @command - def fixmigrations(self, app_name: str, database: str = Arg('--database', default=None)): + def fixmigrations( + self, app_name: str, database: str = Arg("--database", default=None) + ): # if the db's state does not correspond to the django_migrations records from django.apps.registry import apps, AppConfig + dbs = self.service.get_config(DatabaseConnections) for alias, db in dbs.items(): if database and database != alias: @@ -80,11 +90,12 @@ def fixmigrations(self, app_name: str, database: str = Arg('--database', default from django.db.migrations.executor import MigrationExecutor from django.db import connections from django.db import DatabaseError + conn = connections[alias] executor = MigrationExecutor(conn) - targets = sorted([ - key[1] for key in executor.loader.graph.nodes if key[0] == app_name - ]) + targets = sorted( + [key[1] for key in executor.loader.graph.nodes if key[0] == app_name] + ) created_migrations = [] try: @@ -92,9 +103,13 @@ def fixmigrations(self, app_name: str, database: str = Arg('--database', default migrated = 0 for i, target in enumerate(targets): try: - if not executor.recorder.migration_qs.filter(app=app_name, name=target).exists(): + if not executor.recorder.migration_qs.filter( + app=app_name, name=target + ).exists(): migrated += 1 - print(f'[{app_name}] migrating {repr(target)} at database [{repr(alias)}]') + print( + f"[{app_name}] migrating {repr(target)} at database [{repr(alias)}]" + ) # target_executor = MigrationExecutor(conn) target_executor = MigrationExecutor(conn) # reload from db @@ -102,37 +117,49 @@ def fixmigrations(self, app_name: str, database: str = Arg('--database', default if not plan: return target_executor.migrate(targets, plan) - print(f'[{app_name} migrating {repr(target)} completed at database [{repr(alias)}]') + print( + f"[{app_name}] migrating {repr(target)} completed at database [{repr(alias)}]" + ) except DatabaseError as e: - print(f'[{app_name}] migrate unrecorded migration: {repr(target)} ' - f'failed with error: {e}, save migration to [{repr(alias)}] and skip to next') + print( + f"[{app_name}] migrating unrecorded migration: {repr(target)} " + f"failed with error: {e}, saving its record to [{repr(alias)}] and skipping to the next" + ) # close and reconnect the database (or the migration record won't be saved) conn.close() conn.connect() - created_migrations.append(executor.recorder.migration_qs.create( - name=target, - app=app_name, - )) + created_migrations.append( + executor.recorder.migration_qs.create( + name=target, + app=app_name, + ) + ) if i == len(targets) - 1: cfg: AppConfig = apps.get_app_config(app_name) - print('testing schemas of all models...') + print("testing schemas of all models...") for model in cfg.get_models(): try: _ = model.objects.values()[:1] except DatabaseError as _e: - print(f'model: {model} load failed: {_e}') + print(f"model: {model} load failed: {_e}") raise e - print(F'all missing migrations recorded: {created_migrations}') + print( + f"all missing migrations recorded: {created_migrations}" + ) break # if all models are ok, we stop trying if not migrated: - print(f'[{app_name}] migrations is clean at database [{repr(alias)}]') + print( + f"[{app_name}] migrations are clean at database [{repr(alias)}]" + ) # 2. if the recorded migration is not migrated
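# The loop above reconciles migration files on disk with rows in the django_migrations
# table. A condensed sketch of the same detection, assuming a configured Django project
# and an existing database alias:
from django.db import connections
from django.db.migrations.executor import MigrationExecutor

def unrecorded_migrations(app_label: str, alias: str = "default"):
    executor = MigrationExecutor(connections[alias])
    # graph.nodes holds (app_label, name) pairs discovered on disk;
    # applied_migrations() reflects what django_migrations has recorded
    on_disk = {name for app, name in executor.loader.graph.nodes if app == app_label}
    applied = {name for app, name in executor.recorder.applied_migrations() if app == app_label}
    return sorted(on_disk - applied)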
except Exception as e: - print(f'[{app_name}] fix migrations at database: {repr(alias)} failed: {repr(e)}]') + print( + f"[{app_name}] fix migrations at database [{repr(alias)}] failed: {repr(e)}" + ) if created_migrations: - print(f'[{app_name}] deleting created migrations') + print(f"[{app_name}] deleting created migrations") for migration in created_migrations: try: migration.delete() @@ -145,10 +172,11 @@ def mergemigrations(self, app_name: str): merge all migration files to the minimum for django apps """ from django.apps.registry import apps, AppConfig + dbs = self.service.get_config(DatabaseConnections) - if app_name == '*' or app_name == '__all__': - print('merge all apps:') + if app_name == "*" or app_name == "__all__": + print("merge all apps:") # for key, cfg in apps.app_configs.items(): # cfg: AppConfig # self.mergemigrations(cfg.label) @@ -158,7 +186,7 @@ def mergemigrations(self, app_name: str): continue cfg: AppConfig # if cfg.label == app_name: - migrations_path = os.path.join(cfg.path, 'migrations') + migrations_path = os.path.join(cfg.path, "migrations") files = next(os.walk(migrations_path))[2] for file in files: if file.startswith(SEG): @@ -167,53 +195,60 @@ for alias, db in dbs.items(): from django.db import connections + with connections[alias].cursor() as cursor: cursor.execute("DELETE FROM django_migrations WHERE name != ''") - execute_from_command_line(['meta', 'makemigrations']) - print(f'migrations for all app has merged') + execute_from_command_line(["meta", "makemigrations"]) + print("migrations for all apps have been merged") for key, cfg in apps.app_configs.items(): cfg: AppConfig # if cfg.label == app_name: - migrations_path = os.path.join(cfg.path, 'migrations') + migrations_path = os.path.join(cfg.path, "migrations") files = next(os.walk(migrations_path))[2] for file in files: if file.startswith(SEG): continue - file_name = str(file).rstrip('.py') + file_name = file[:-3] if file.endswith(".py") else file # strip the ".py" suffix exactly for alias, db in dbs.items(): from django.db import connections + with connections[alias].cursor() as cursor: - cursor.execute("INSERT INTO django_migrations " - "(app, name, applied) values ('%s', '%s', '%s')" - % (cfg.label, file_name, str(datetime.datetime.now()))) + cursor.execute( + "INSERT INTO django_migrations " + "(app, name, applied) values ('%s', '%s', '%s')" + % (cfg.label, file_name, str(datetime.datetime.now())) + ) # os.remove(os.path.join(migrations_path, file)) return - print('merging for app:', app_name) + print("merging for app:", app_name) for alias, db in dbs.items(): from django.db import connections + with connections[alias].cursor() as cursor: - cursor.execute("DELETE FROM django_migrations WHERE app='%s' and name != '%s'" # noqa - % (app_name, initial_file)) + cursor.execute( + "DELETE FROM django_migrations WHERE app='%s' and name != '%s'" # noqa + % (app_name, initial_file) + ) for key, cfg in apps.app_configs.items(): cfg: AppConfig if cfg.label == app_name: - migrations_path = os.path.join(cfg.path, 'migrations') + migrations_path = os.path.join(cfg.path, "migrations") files = next(os.walk(migrations_path))[2] for file in files: if file.startswith(SEG): continue os.remove(os.path.join(migrations_path, file)) - execute_from_command_line(['meta', 'makemigrations']) - print(f'migrations for app: <{app_name}> has merged') + execute_from_command_line(["meta", "makemigrations"]) + print(f"migrations for app: <{app_name}> have been merged") # @property # def exec_args(self): @@ -221,4
+256,5 @@ def mergemigrations(self, app_name: str): def fallback(self): import sys + execute_from_command_line(sys.argv) diff --git a/utilmeta/core/server/backends/django/settings.py b/utilmeta/core/server/backends/django/settings.py index 658ea06..4a8d82b 100644 --- a/utilmeta/core/server/backends/django/settings.py +++ b/utilmeta/core/server/backends/django/settings.py @@ -33,62 +33,63 @@ # "django.contrib.staticfiles", ] DEFAULT_DB_ENGINE = { - 'sqlite': 'django.db.backends.sqlite3', - 'oracle': 'django.db.backends.oracle', - 'mysql': 'django.db.backends.mysql', - 'postgres': 'django.db.backends.postgresql' + "sqlite": "django.db.backends.sqlite3", + "oracle": "django.db.backends.oracle", + "mysql": "django.db.backends.mysql", + "postgres": "django.db.backends.postgresql", } WSGI_APPLICATION = "WSGI_APPLICATION" ASGI_APPLICATION = "ASGI_APPLICATION" ROOT_URLCONF = "ROOT_URLCONF" DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" -SETTINGS_MODULE = 'DJANGO_SETTINGS_MODULE' +SETTINGS_MODULE = "DJANGO_SETTINGS_MODULE" DEFAULT_LANGUAGE_CODE = "en-us" DEFAULT_TIME_ZONE = "UTC" DEFAULT_USE_I18N = True DEFAULT_USE_TZ = True -DB = 'django.core.cache.backends.db.DatabaseCache' -FILE = 'django.core.cache.backends.filebased.FileBasedCache' -DUMMY = 'django.core.cache.backends.dummy.DummyCache' -LOCMEM = 'django.core.cache.backends.locmem.LocMemCache' -MEMCACHED = 'django.core.cache.backends.memcached.MemcachedCache' -PYLIBMC = 'django.core.cache.backends.memcached.PyLibMCCache' -DJANGO_REDIS = 'django.core.cache.backends.redis.RedisCache' +DB = "django.core.cache.backends.db.DatabaseCache" +FILE = "django.core.cache.backends.filebased.FileBasedCache" +DUMMY = "django.core.cache.backends.dummy.DummyCache" +LOCMEM = "django.core.cache.backends.locmem.LocMemCache" +MEMCACHED = "django.core.cache.backends.memcached.MemcachedCache" +PYLIBMC = "django.core.cache.backends.memcached.PyLibMCCache" +DJANGO_REDIS = "django.core.cache.backends.redis.RedisCache" CACHE_BACKENDS = { - 'db': DB, - 'database': DB, - 'file': FILE, - 'locmem': LOCMEM, - 'memcached': MEMCACHED, - 'redis': DJANGO_REDIS, - 'pylibmc': PYLIBMC + "db": DB, + "database": DB, + "file": FILE, + "locmem": LOCMEM, + "memcached": MEMCACHED, + "redis": DJANGO_REDIS, + "pylibmc": PYLIBMC, } class DjangoSettings(Config): def __init__( - self, - module_name: str = None, *, - # current settings module for django project - root_urlconf: str = None, - # current url conf (if there is an exists django project) - secret_key: str = None, - apps_package: str = None, - # package ref (such as 'domain' / 'service.applications') - apps: Union[tuple, List[str]] = (), - database_routers: tuple = (), - allowed_hosts: list = (), - middleware: Union[tuple, List[str]] = (), - default_autofield: str = None, - wsgi_application: str = None, - # time_zone: str = None, - # use_tz: bool = None, - user_i18n: bool = None, - language: str = None, - append_slash: bool = False, - extra: dict = None, - # urlpatterns: list = None, + self, + module_name: str = None, + *, + # current settings module for django project + root_urlconf: str = None, + # current url conf (if there is an exists django project) + secret_key: str = None, + apps_package: str = None, + # package ref (such as 'domain' / 'service.applications') + apps: Union[tuple, List[str]] = (), + database_routers: tuple = (), + allowed_hosts: list = (), + middleware: Union[tuple, List[str]] = (), + default_autofield: str = None, + wsgi_application: str = None, + # time_zone: str = None, + # use_tz: bool = None, + 
user_i18n: bool = None, + language: str = None, + append_slash: bool = False, + extra: dict = None, + # urlpatterns: list = None, ): super().__init__(locals()) self.module_name = module_name @@ -124,16 +125,17 @@ def __init__( self.load_apps() def register(self, plugin): - getter = getattr(plugin, 'as_django', None) + getter = getattr(plugin, "as_django", None) if callable(getter): plugin_settings = getter() if not isinstance(plugin_settings, dict): - raise TypeError(f'Invalid settings: {plugin_settings}') + raise TypeError(f"Invalid settings: {plugin_settings}") self._plugin_settings.update(plugin_settings) if self.module: # already set self._settings.update(plugin_settings) from django.conf import settings + for attr, value in plugin_settings.items(): setattr(self.module, attr, value) setattr(settings, attr, value) @@ -148,6 +150,7 @@ def apps_path(self): @classmethod def app_labels(cls) -> List[str]: from django.apps.registry import apps + labels = [] for key, cfg in apps.app_configs.items(): labels.append(cfg.label) @@ -155,7 +158,7 @@ def app_labels(cls) -> List[str]: def get_db(self, app_label: str): # TODO - return 'default' + return "default" def get_secret(self, service: UtilMeta): if self.secret_key: @@ -166,13 +169,16 @@ def get_secret(self, service: UtilMeta): import hashlib import warnings import utilmeta + if service.production: - raise ValueError(f'django: secret_key not set for production') + raise ValueError(f"django: secret_key not set for production") else: - warnings.warn('django: secret_key not set, auto generating') - tag = f'{service.name}:{service.description}:{service.version_str}' \ - f'{service.backend_name}:{service.module_name}' \ - f'{django.__version__}{utilmeta.__version__}{sys.version}{platform.platform()}'.encode() + warnings.warn("django: secret_key not set, auto generating") + tag = ( + f"{service.name}:{service.description}:{service.version_str}" + f"{service.backend_name}:{service.module_name}" + f"{django.__version__}{utilmeta.__version__}{sys.version}{platform.platform()}".encode() + ) return hashlib.sha256(tag).hexdigest() def load_apps(self): @@ -181,9 +187,9 @@ def load_apps(self): if self.apps_package: apps_path = self.apps_path - hosted_labels = [p for p in next(os.walk(apps_path))[1] if '__' not in p] + hosted_labels = [p for p in next(os.walk(apps_path))[1] if "__" not in p] for app in hosted_labels: - label = f'{self.apps_package}.{app}' + label = f"{self.apps_package}.{app}" if label not in installed_apps: installed_apps.append(label) @@ -193,57 +199,57 @@ def load_apps(self): @classmethod def get_cache(cls, cache: Cache): return { - 'BACKEND': CACHE_BACKENDS.get(cache.engine) or cache.engine, - 'LOCATION': cache.get_location(), - 'OPTIONS': cache.options or {}, - 'KEY_FUNCTION': cache.key_function, - 'KEY_PREFIX': cache.prefix, - 'TIMEOUT': cache.timeout, - 'MAX_ENTRIES': cache.max_entries, + "BACKEND": CACHE_BACKENDS.get(cache.engine) or cache.engine, + "LOCATION": cache.get_location(), + "OPTIONS": cache.options or {}, + "KEY_FUNCTION": cache.key_function, + "KEY_PREFIX": cache.prefix, + "TIMEOUT": cache.timeout, + "MAX_ENTRIES": cache.max_entries, } @classmethod def get_time(cls, time_config: Time): return { - 'DATETIME_FORMAT': time_config.datetime_format, - 'DATE_FORMAT': time_config.date_format, - 'TIME_ZONE': time_config.time_zone or DEFAULT_TIME_ZONE, - 'USE_TZ': time_config.use_tz, + "DATETIME_FORMAT": time_config.datetime_format, + "DATE_FORMAT": time_config.date_format, + "TIME_ZONE": time_config.time_zone or DEFAULT_TIME_ZONE, + 
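# For illustration (assumed values): a service configured with
# Time(time_zone="Asia/Shanghai", use_tz=True) would surface here as
# {"TIME_ZONE": "Asia/Shanghai", "USE_TZ": True} plus the two format keys,
# so a single utilmeta Time config drives the corresponding Django settings.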
"USE_TZ": time_config.use_tz, } @classmethod def get_database(cls, db: Database, service: UtilMeta): engine = db.engine - if '.' not in db.engine: + if "." not in db.engine: for name, eg in DEFAULT_DB_ENGINE.items(): if name.lower() in engine.lower(): - if name == 'postgres' and service.asynchronous and django.VERSION >= (4, 2): + if ( + name == "postgres" + and service.asynchronous + and django.VERSION >= (4, 2) + ): # COMPAT DJANGO > 4.2 - engine = 'utilmeta.core.server.backends.django.postgresql' + engine = "utilmeta.core.server.backends.django.postgresql" else: engine = eg break options = {} if db.ssl: - options['sslmode'] = 'require' - if 'sqlite' in engine: - return { - 'ENGINE': engine, - 'NAME': str(db.name), - 'OPTIONS': options - } + options["sslmode"] = "require" + if "sqlite" in engine: + return {"ENGINE": engine, "NAME": str(db.name), "OPTIONS": options} return { - 'ENGINE': engine, - 'HOST': db.host, - 'PORT': db.port, - 'NAME': db.name, - 'USER': db.user, - 'TIME_ZONE': db.time_zone, - 'PASSWORD': db.password, - 'CONN_MAX_AGE': db.max_age, - 'DISABLE_SERVER_SIDE_CURSORS': db.pooled, - 'OPTIONS': options, + "ENGINE": engine, + "HOST": db.host, + "PORT": db.port, + "NAME": db.name, + "USER": db.user, + "TIME_ZONE": db.time_zone, + "PASSWORD": db.password, + "CONN_MAX_AGE": db.max_age, + "DISABLE_SERVER_SIDE_CURSORS": db.pooled, + "OPTIONS": options, # 'ATOMIC_REQUESTS': False, # 'AUTOCOMMIT': True, } @@ -251,23 +257,27 @@ def get_database(cls, db: Database, service: UtilMeta): def hook(self, service: UtilMeta): from .cmd import DjangoCommand from .adaptor import DjangoServerAdaptor + service.register_command(DjangoCommand) if isinstance(service.adaptor, DjangoServerAdaptor): service.adaptor.settings = self # replace settings - def apply_settings(self, service: UtilMeta, django_settings: Union[Settings, LazySettings]): + def apply_settings( + self, service: UtilMeta, django_settings: Union[Settings, LazySettings] + ): self.django_settings = django_settings adaptor = service.adaptor from .adaptor import DjangoServerAdaptor + if isinstance(adaptor, DjangoServerAdaptor): adaptor.settings = self - databases = getattr(django_settings, 'DATABASES', {}) + databases = getattr(django_settings, "DATABASES", {}) if not isinstance(databases, dict): databases = {} - caches = getattr(django_settings, 'CACHES', {}) + caches = getattr(django_settings, "CACHES", {}) if not isinstance(caches, dict): caches = {} @@ -283,43 +293,61 @@ def apply_settings(self, service: UtilMeta, django_settings: Union[Settings, Laz if databases: for key, val in databases.items(): - val = {str(k).lower(): v for k, v in val.items()} if isinstance(val, dict) else {} + val = ( + {str(k).lower(): v for k, v in val.items()} + if isinstance(val, dict) + else {} + ) if not val: continue if key not in db_config.databases: - db_config.add_database(service, alias=key, database=Database( - name=val.get('name'), - user=val.get('user'), - password=val.get('password'), - engine=val.get('engine'), - host=val.get('host'), - port=val.get('port'), - options=val.get('options'), - )) + db_config.add_database( + service, + alias=key, + database=Database( + name=val.get("name"), + user=val.get("user"), + password=val.get("password"), + engine=val.get("engine"), + host=val.get("host"), + port=val.get("port"), + options=val.get("options"), + ), + ) if caches: for key, val in caches.items(): - val = {str(k).lower(): v for k, v in val.items()} if isinstance(val, dict) else {} + val = ( + {str(k).lower(): v for k, v in val.items()} + if 
isinstance(val, dict) + else {} + ) if not val: continue if key not in cache_config.caches: - options = val.get('options') or {} - cache_config.add_cache(service, alias=key, cache=Cache( - engine=val.get('backend'), - host=val.get('host'), - port=val.get('port'), - options=val.get('options'), - timeout=val.get('timeout'), - location=val.get('location'), - prefix=val.get('key_prefix'), - max_entries=val.get('max_entries') or options.get('MAX_ENTRIES'), - key_function=val.get('key_function') - )) + options = val.get("options") or {} + cache_config.add_cache( + service, + alias=key, + cache=Cache( + engine=val.get("backend"), + host=val.get("host"), + port=val.get("port"), + options=val.get("options"), + timeout=val.get("timeout"), + location=val.get("location"), + prefix=val.get("key_prefix"), + max_entries=val.get("max_entries") + or options.get("MAX_ENTRIES"), + key_function=val.get("key_function"), + ), + ) db_changed = False cached_changed = False from utilmeta.core.orm.backends.django.database import DjangoDatabaseAdaptor + for name, db in db_config.databases.items(): if not db.sync_adaptor_cls: db.sync_adaptor_cls = DjangoDatabaseAdaptor @@ -328,6 +356,7 @@ def apply_settings(self, service: UtilMeta, django_settings: Union[Settings, Laz databases[name] = self.get_database(db, service) from utilmeta.core.cache.backends.django import DjangoCacheAdaptor + for name, cache in cache_config.caches.items(): if not cache.sync_adaptor_cls: cache.sync_adaptor_cls = DjangoCacheAdaptor @@ -336,12 +365,15 @@ def apply_settings(self, service: UtilMeta, django_settings: Union[Settings, Laz caches[name] = self.get_cache(cache) if db_changed: - self.change_settings('DATABASES', databases, force=True) + self.change_settings("DATABASES", databases, force=True) from django.db import connections - connections._settings = connections.settings = connections.configure_settings(None) + + connections._settings = ( + connections.settings + ) = connections.configure_settings(None) if cached_changed: - self.change_settings('CACHES', caches, force=True) + self.change_settings("CACHES", caches, force=True) # ------------------ patch_model_fields(service) @@ -350,26 +382,29 @@ def apply_settings(self, service: UtilMeta, django_settings: Union[Settings, Laz hosts = list(self.allowed_hosts) if service.origin: from urllib.parse import urlparse + hosts.append(urlparse(service.origin).hostname) - self.merge_list_settings('MIDDLEWARE', self.middleware) - self.merge_list_settings('ALLOWED_HOSTS', hosts) - self.merge_list_settings('DATABASE_ROUTERS', self.database_routers) + self.merge_list_settings("MIDDLEWARE", self.middleware) + self.merge_list_settings("ALLOWED_HOSTS", hosts) + self.merge_list_settings("DATABASE_ROUTERS", self.database_routers) if self.append_slash: - self.change_settings('APPEND_SLASH', self.append_slash, force=True) + self.change_settings("APPEND_SLASH", self.append_slash, force=True) try: - if not getattr(django_settings, 'SECRET_KEY', None): - self.change_settings('SECRET_KEY', self.get_secret(service), force=False) + if not getattr(django_settings, "SECRET_KEY", None): + self.change_settings( + "SECRET_KEY", self.get_secret(service), force=False + ) except ImproperlyConfigured: - self.change_settings('SECRET_KEY', self.get_secret(service), force=False) + self.change_settings("SECRET_KEY", self.get_secret(service), force=False) if service.production: # elsewhere we keep the original settings - self.change_settings('DEBUG', False, force=True) + self.change_settings("DEBUG", False, force=True) else: - 
if getattr(django_settings, 'DEBUG', None) is False: + if getattr(django_settings, "DEBUG", None) is False: service.production = True time_config = Time.config() @@ -379,30 +414,36 @@ def apply_settings(self, service: UtilMeta, django_settings: Union[Settings, Laz else: # the default django DATETIME_FORMAT is N j, Y, P # which is not a valid datetime string - service.use(Time( - time_zone=getattr(django_settings, 'TIME_ZONE', None), - use_tz=getattr(django_settings, 'USE_TZ', True), - # date_format=getattr(django_settings, 'DATE_FORMAT', Time.DATE_DEFAULT), - # datetime_format=getattr(django_settings, 'DATETIME_FORMAT', Time.DATETIME_DEFAULT), - # time_format=getattr(django_settings, 'TIME_FORMAT', Time.TIME_DEFAULT), - )) + service.use( + Time( + time_zone=getattr(django_settings, "TIME_ZONE", None), + use_tz=getattr(django_settings, "USE_TZ", True), + # date_format=getattr(django_settings, 'DATE_FORMAT', Time.DATE_DEFAULT), + # datetime_format=getattr(django_settings, 'DATETIME_FORMAT', Time.DATETIME_DEFAULT), + # time_format=getattr(django_settings, 'TIME_FORMAT', Time.TIME_DEFAULT), + ) + ) # set DEFAULT_AUTO_FIELD before a (probably) apps reload - explicit_settings = getattr(self.django_settings, '_explicit_settings', None) - self.change_settings('DEFAULT_AUTO_FIELD', - self.default_autofield or DEFAULT_AUTO_FIELD, force=True) + explicit_settings = getattr(self.django_settings, "_explicit_settings", None) + self.change_settings( + "DEFAULT_AUTO_FIELD", + self.default_autofield or DEFAULT_AUTO_FIELD, + force=True, + ) if isinstance(explicit_settings, set): - explicit_settings.add('DEFAULT_AUTO_FIELD') + explicit_settings.add("DEFAULT_AUTO_FIELD") # this is to prevent django W042: Auto-created primary key used when not defining a if self.language: - self.change_settings('LANGUAGE_CODE', self.language, force=True) + self.change_settings("LANGUAGE_CODE", self.language, force=True) if self.use_i18n: - self.change_settings('USE_I18N', self.use_i18n, force=True) + self.change_settings("USE_I18N", self.use_i18n, force=True) if self.apps: - new_apps = self.merge_list_settings('INSTALLED_APPS', self.apps) + new_apps = self.merge_list_settings("INSTALLED_APPS", self.apps) from django.apps import apps + if apps.ready: # apps already setup apps.ready = False @@ -428,11 +469,18 @@ def apply_settings(self, service: UtilMeta, django_settings: Union[Settings, Laz self.module = service.module os.environ[SETTINGS_MODULE] = self.module_name - self.wsgi_application = (getattr(django_settings, WSGI_APPLICATION, None) or - self.wsgi_application or self.get_service_wsgi_app(service)) - self.root_urlconf = getattr(django_settings, ROOT_URLCONF, None) or self.root_urlconf + self.wsgi_application = ( + getattr(django_settings, WSGI_APPLICATION, None) + or self.wsgi_application + or self.get_service_wsgi_app(service) + ) + self.root_urlconf = ( + getattr(django_settings, ROOT_URLCONF, None) or self.root_urlconf + ) if self.root_urlconf: - self.url_conf = sys.modules.get(self.root_urlconf) or import_obj(self.root_urlconf) + self.url_conf = sys.modules.get(self.root_urlconf) or import_obj( + self.root_urlconf + ) else: # raise ValueError(f'Invalid root urlconf: {self.root_urlconf}') self.root_urlconf = service.module_name or self.module_name @@ -448,8 +496,10 @@ def change_settings(self, settings_name, value, force=False): try: if not force and hasattr(self.django_settings, settings_name): return - if (hasattr(self.django_settings, settings_name) and - getattr(self.django_settings, settings_name) != value): + if ( 
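# change_settings() finishes by emitting Django's setting_changed signal (the send()
# call further down) so that anything caching settings values can react. The bare
# mechanism, sketched against an already-configured Django project:
from django.core.signals import setting_changed

def push_setting(settings_obj, name, value):
    setattr(settings_obj, name, value)
    # enter=True marks this as applying an override rather than reverting one
    setting_changed.send(sender=None, setting=name, value=value, enter=True)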
+ hasattr(self.django_settings, settings_name) + and getattr(self.django_settings, settings_name) != value + ): pass else: return @@ -457,6 +507,7 @@ def change_settings(self, settings_name, value, force=False): pass setattr(self.django_settings, settings_name, value) from django.core.signals import setting_changed + setting_changed.send( sender=self.__class__, setting=settings_name, @@ -483,10 +534,10 @@ def merge_list_settings(self, settings_name: str, settings_list: list): @classmethod def get_service_wsgi_app(cls, service: UtilMeta): - app = service.meta_config.get('app') + app = service.meta_config.get("app") if not app: - return f'{service.module_name}.app' - return str(app).replace(':', '.') + return f"{service.module_name}.app" + return str(app).replace(":", ".") def setup(self, service: UtilMeta): # django_settings = None @@ -494,6 +545,7 @@ def setup(self, service: UtilMeta): module_name = os.environ.get(SETTINGS_MODULE) try: from django.conf import settings + _ = settings.INSTALLED_APPS # if the settings is not configured, this will trigger ImproperlyConfigured except (ImportError, ImproperlyConfigured): @@ -532,16 +584,21 @@ def setup(self, service: UtilMeta): caches = {} if db_config: - if db_config.databases and 'default' not in db_config.databases: + if db_config.databases and "default" not in db_config.databases: # often: a no-db service add Operations() # we need to define a '__ops' db, but django will force us to # define a 'default' db - db_config.add_database(service, 'default', database=Database( - name=os.path.join(service.project_dir, '__default_db'), - engine='sqlite3' - )) + db_config.add_database( + service, + "default", + database=Database( + name=os.path.join(service.project_dir, "__default_db"), + engine="sqlite3", + ), + ) from utilmeta.core.orm.backends.django.database import DjangoDatabaseAdaptor + for name, db in db_config.databases.items(): if not db.sync_adaptor_cls: db.sync_adaptor_cls = DjangoDatabaseAdaptor @@ -549,6 +606,7 @@ def setup(self, service: UtilMeta): if cache_config: from utilmeta.core.cache.backends.django import DjangoCacheAdaptor + for name, cache in cache_config.caches.items(): if not cache.sync_adaptor_cls: cache.sync_adaptor_cls = DjangoCacheAdaptor @@ -557,50 +615,56 @@ def setup(self, service: UtilMeta): middleware = list(self.middleware or DEFAULT_MIDDLEWARE) adaptor = service.adaptor from .adaptor import DjangoServerAdaptor + if isinstance(adaptor, DjangoServerAdaptor): adaptor.settings = self middleware_func = adaptor.middleware_func if middleware_func: setattr(self.module, middleware_func.__name__, middleware_func) - middleware.append(f'{self.module_name}.{middleware_func.__name__}') + middleware.append(f"{self.module_name}.{middleware_func.__name__}") hosts = list(self.allowed_hosts) if service.origin: from urllib.parse import urlparse + hosts.append(urlparse(service.origin).hostname) - self.wsgi_application = self.wsgi_application or self.get_service_wsgi_app(service) + self.wsgi_application = self.wsgi_application or self.get_service_wsgi_app( + service + ) settings = { - 'DEBUG': not service.production, - 'SECRET_KEY': self.get_secret(service), - 'BASE_DIR': service.project_dir, - 'MIDDLEWARE': middleware, - 'INSTALLED_APPS': self.apps, - 'ALLOWED_HOSTS': hosts, - 'DATABASE_ROUTERS': self.database_routers, - 'APPEND_SLASH': self.append_slash, - 'LANGUAGE_CODE': self.language, - 'USE_I18N': self.use_i18n, - 'DEFAULT_AUTO_FIELD': self.default_autofield or DEFAULT_AUTO_FIELD, + "DEBUG": not service.production, + "SECRET_KEY": 
self.get_secret(service), + "BASE_DIR": service.project_dir, + "MIDDLEWARE": middleware, + "INSTALLED_APPS": self.apps, + "ALLOWED_HOSTS": hosts, + "DATABASE_ROUTERS": self.database_routers, + "APPEND_SLASH": self.append_slash, + "LANGUAGE_CODE": self.language, + "USE_I18N": self.use_i18n, + "DEFAULT_AUTO_FIELD": self.default_autofield or DEFAULT_AUTO_FIELD, # 'DATABASES': databases, # 'CACHES': caches, ROOT_URLCONF: self.root_urlconf or service.module_name, - WSGI_APPLICATION: self.wsgi_application + WSGI_APPLICATION: self.wsgi_application, } if databases: - settings.update({'DATABASES': databases}) + settings.update({"DATABASES": databases}) if caches: - settings.update({'CACHES': caches}) + settings.update({"CACHES": caches}) time_config = Time.config() if time_config: settings.update(self.get_time(time_config)) else: # mandatory - settings.update({ - 'TIME_ZONE': DEFAULT_TIME_ZONE, - 'USE_TZ': True, - }) + settings.update( + { + "TIME_ZONE": DEFAULT_TIME_ZONE, + "USE_TZ": True, + } + ) if self._plugin_settings: settings.update(self._plugin_settings) @@ -623,7 +687,7 @@ def setup(self, service: UtilMeta): # at most circumstances, there are module name os.environ[SETTINGS_MODULE] = settings_name else: - os.environ.setdefault(SETTINGS_MODULE, '__main__') + os.environ.setdefault(SETTINGS_MODULE, "__main__") # not using setdefault to prevent IDE set the wrong value by default django.setup(set_prefix=False) @@ -631,26 +695,32 @@ def setup(self, service: UtilMeta): # import root url conf after the django setup if self.root_urlconf: - self.url_conf = sys.modules.get(self.root_urlconf) or import_obj(self.root_urlconf) + self.url_conf = sys.modules.get(self.root_urlconf) or import_obj( + self.root_urlconf + ) else: self.url_conf = service.module - urlpatterns = getattr(self.url_conf, 'urlpatterns', []) + urlpatterns = getattr(self.url_conf, "urlpatterns", []) # if self.urlpatterns: # urlpatterns = urlpatterns + self.urlpatterns - setattr(self.url_conf, 'urlpatterns', urlpatterns or []) + setattr(self.url_conf, "urlpatterns", urlpatterns or []) # this set is required, otherwise url_conf.urlpatterns is not exists try: from django.conf import settings except (ImportError, ImproperlyConfigured) as e: - raise ImproperlyConfigured(f'DjangoSettings: configure django failed: {e}') from e + raise ImproperlyConfigured( + f"DjangoSettings: configure django failed: {e}" + ) from e else: self.django_settings = settings - explicit_settings = getattr(self.django_settings, '_explicit_settings', None) + explicit_settings = getattr( + self.django_settings, "_explicit_settings", None + ) if isinstance(explicit_settings, set): - explicit_settings.add('DEFAULT_AUTO_FIELD') + explicit_settings.add("DEFAULT_AUTO_FIELD") # this is to prevent django W042: Auto-created primary key used when not defining a @property @@ -658,15 +728,15 @@ def wsgi_module_ref(self): wsgi_app_ref = self.wsgi_application if not wsgi_app_ref: return None - if ':' in wsgi_app_ref: - return wsgi_app_ref.split(':')[0] - return '.'.join(wsgi_app_ref.split('.')[:-1]) + if ":" in wsgi_app_ref: + return wsgi_app_ref.split(":")[0] + return ".".join(wsgi_app_ref.split(".")[:-1]) @property def wsgi_app_attr(self): wsgi_app_ref = self.wsgi_application - if isinstance(wsgi_app_ref, str) and '.' in wsgi_app_ref: - return wsgi_app_ref.split('.')[-1] + if isinstance(wsgi_app_ref, str) and "." 
in wsgi_app_ref: + return wsgi_app_ref.split(".")[-1] return None @property diff --git a/utilmeta/core/server/backends/django/utils.py b/utilmeta/core/server/backends/django/utils.py index 901bed8..e96121c 100644 --- a/utilmeta/core/server/backends/django/utils.py +++ b/utilmeta/core/server/backends/django/utils.py @@ -5,12 +5,13 @@ def patch_model_fields(service): - if sys.version_info >= (3, 9) or os.name != 'nt': + if sys.version_info >= (3, 9) or os.name != "nt": if django.VERSION >= (3, 1): return # sqlite in-compat from utilmeta.core.orm import DatabaseConnections + dbs = service.get_config(DatabaseConnections) if dbs: has_sqlite = False @@ -24,15 +25,19 @@ def patch_model_fields(service): return if has_not_sqlite: if django.VERSION >= (3, 1): - warnings.warn(f'You are using mixed database engines with sqlite3 in Windows under Python 3.9, ' - f'JSONField cannot operate properly') + warnings.warn( + f"You are using mixed database engines with sqlite3 in Windows under Python 3.9, " + f"JSONField cannot operate properly" + ) return from django.db import models from utype import JSONEncoder class RawJSONField(models.Field): - def __init__(self, verbose_name=None, name=None, encoder=None, decoder=None, **kwargs): + def __init__( + self, verbose_name=None, name=None, encoder=None, decoder=None, **kwargs + ): self.encoder = encoder or JSONEncoder self.decoder = decoder super().__init__(verbose_name, name, **kwargs) @@ -40,10 +45,10 @@ def __init__(self, verbose_name=None, name=None, encoder=None, decoder=None, **k def get_internal_type(self): # act like TextField # return 'JSONField' - return 'TextField' + return "TextField" def db_type(self, connection): - return 'text' + return "text" def from_db_value(self, value, expression, connection): if value is not None: @@ -52,6 +57,7 @@ def from_db_value(self, value, expression, connection): def to_python(self, value): import json + if value is not None: try: return json.loads(value) @@ -61,6 +67,7 @@ def to_python(self, value): def get_prep_value(self, value): import json + if value is not None: return json.dumps(value, cls=self.encoder, ensure_ascii=False) return value @@ -70,6 +77,7 @@ def value_to_string(self, obj): def get_db_prep_value(self, value, connection, prepared=False): import json + if value is not None: return json.dumps(value, cls=self.encoder, ensure_ascii=False) return value @@ -77,4 +85,5 @@ def get_db_prep_value(self, value, connection, prepared=False): models.JSONField = RawJSONField if django.VERSION < (3, 1): from django.db.models import PositiveIntegerField + models.PositiveBigIntegerField = PositiveIntegerField diff --git a/utilmeta/core/server/backends/fastapi.py b/utilmeta/core/server/backends/fastapi.py index d1f3719..77b336f 100644 --- a/utilmeta/core/server/backends/fastapi.py +++ b/utilmeta/core/server/backends/fastapi.py @@ -8,19 +8,19 @@ class FastAPIServerAdaptor(StarletteServerAdaptor): application_cls = FastAPI app: FastAPI - def generate(self, spec: str = 'openapi'): - if spec == 'openapi': + def generate(self, spec: str = "openapi"): + if spec == "openapi": app: FastAPI = self.application() - app.openapi_schema = None # clear cache + app.openapi_schema = None # clear cache return app.openapi() @property def root_path(self) -> str: - return str(getattr(self.app, 'root_path', '') or '').strip('/') + return str(getattr(self.app, "root_path", "") or "").strip("/") @property def version(self) -> str: - return getattr(self.app, 'version', '') + return getattr(self.app, "version", "") # def load_route(self, 
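# The FastAPI generate() above leans on FastAPI's schema cache: app.openapi() stores
# its result on app.openapi_schema, so resetting that attribute forces a rebuild from
# the routes currently registered. In isolation (hypothetical app):
from fastapi import FastAPI

app = FastAPI()

@app.get("/ping")
def ping():
    return {"ok": True}

first = app.openapi()      # built once, then cached on app.openapi_schema
app.openapi_schema = None  # drop the cache, as the adaptor does
fresh = app.openapi()      # regenerated, including any routes added since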
request): # path = request.path_params.get('path') or request.url.path diff --git a/utilmeta/core/server/backends/flask.py b/utilmeta/core/server/backends/flask.py index 93cbd93..6e0a74a 100644 --- a/utilmeta/core/server/backends/flask.py +++ b/utilmeta/core/server/backends/flask.py @@ -9,8 +9,8 @@ import sys import contextvars -_current_request = contextvars.ContextVar('_flask.request') -_current_response = contextvars.ContextVar('_flask.response') +_current_request = contextvars.ContextVar("_flask.request") +_current_response = contextvars.ContextVar("_flask.response") class FlaskServerAdaptor(ServerAdaptor): @@ -20,7 +20,7 @@ class FlaskServerAdaptor(ServerAdaptor): application_cls = Flask default_asynchronous = False HANDLED_METHODS = ("DELETE", "HEAD", "GET", "OPTIONS", "PATCH", "POST", "PUT") - DEFAULT_HOST = '127.0.0.1' + DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 5000 # REQUEST_ATTR = '_utilmeta_request' @@ -28,15 +28,18 @@ class FlaskServerAdaptor(ServerAdaptor): def __init__(self, config): super().__init__(config) - self.app = self.config._application if isinstance(self.config._application, self.application_cls) \ + self.app = ( + self.config._application + if isinstance(self.config._application, self.application_cls) else self.application_cls(self.config.module_name) + ) self._ready = False # def init_application(self): # return self.config._application if isinstance(self.config._application, self.application_cls) \ # else self.application_cls(self.config.module_name) - def adapt(self, api: 'API', route: str, asynchronous: bool = None): + def adapt(self, api: "API", route: str, asynchronous: bool = None): if asynchronous is None: asynchronous = self.default_asynchronous self.add_api(self.app, api, asynchronous=asynchronous, route=route) @@ -52,11 +55,7 @@ def setup_middlewares(self): def setup(self): if self._ready: return - self.add_api( - self.app, - self.resolve(), - asynchronous=self.asynchronous - ) + self.add_api(self.app, self.resolve(), asynchronous=self.asynchronous) self.setup_middlewares() self.apply_fork() @@ -71,7 +70,7 @@ def run(self, **kwargs): host=self.config.host or self.DEFAULT_HOST, port=self.config.port, debug=not self.config.production, - **kwargs + **kwargs, ) finally: self.config.shutdown() @@ -79,7 +78,7 @@ def run(self, **kwargs): @property def backend_views_empty(self) -> bool: for val in self.app.view_functions.values(): - wrapped = getattr(val, '__wrapped__', None) + wrapped = getattr(val, "__wrapped__", None) if wrapped and isinstance(wrapped, type) and issubclass(wrapped, API): pass else: @@ -94,26 +93,35 @@ def async_startup(self) -> bool: def production(self) -> bool: return not self.app.debug - def add_api(self, app: Flask, utilmeta_api_class, route: str = '', asynchronous: bool = False): + def add_api( + self, + app: Flask, + utilmeta_api_class, + route: str = "", + asynchronous: bool = False, + ): """ Mount a API class make sure it is called after all your fastapi route is set """ from utilmeta.core.api.base import API + if not issubclass(utilmeta_api_class, API): - raise TypeError(f'Invalid api class: {utilmeta_api_class}') + raise TypeError(f"Invalid api class: {utilmeta_api_class}") - if route and route.strip('/'): - route = '/' + route.strip('/') - prepend = route + '/' + if route and route.strip("/"): + route = "/" + route.strip("/") + prepend = route + "/" else: - prepend = route = '/' + prepend = route = "/" if asynchronous: - @app.route(route, defaults={'path': ''}, methods=self.HANDLED_METHODS) - @app.route('%s' % prepend, 
methods=self.HANDLED_METHODS) + + @app.route(route, defaults={"path": ""}, methods=self.HANDLED_METHODS) + @app.route("%s" % prepend, methods=self.HANDLED_METHODS) async def f(path: str): from flask import request + req = None try: req = _current_request.get(None) @@ -127,14 +135,19 @@ async def f(path: str): resp = await utilmeta_api_class(req)() except Exception as e: - resp = getattr(utilmeta_api_class, 'response', Response)(error=e, request=req) + resp = getattr(utilmeta_api_class, "response", Response)( + error=e, request=req + ) _current_response.set(resp) return self.response_adaptor_cls.reconstruct(resp) + else: - @app.route(route, defaults={'path': ''}, methods=self.HANDLED_METHODS) - @app.route('%s' % prepend, methods=self.HANDLED_METHODS) + + @app.route(route, defaults={"path": ""}, methods=self.HANDLED_METHODS) + @app.route("%s" % prepend, methods=self.HANDLED_METHODS) def f(path: str): from flask import request + req = None try: req = _current_request.get(None) @@ -148,9 +161,12 @@ def f(path: str): resp = utilmeta_api_class(req)() except Exception as e: - resp = getattr(utilmeta_api_class, 'response', Response)(error=e, request=req) + resp = getattr(utilmeta_api_class, "response", Response)( + error=e, request=req + ) _current_response.set(resp) return self.response_adaptor_cls.reconstruct(resp) + f.__wrapped__ = utilmeta_api_class return f @@ -187,6 +203,7 @@ def wsgi_app(self, environ: dict, start_response): try: ctx.push() from flask import request as flask_request + # ----------------------- response = None request = Request(self.request_adaptor_cls(flask_request)) @@ -205,26 +222,20 @@ def wsgi_app(self, environ: dict, start_response): _current_response.set(None) if not isinstance(response, Response): response = Response( - response=self.response_adaptor_cls( - flask_response - ), - request=request + response=self.response_adaptor_cls(flask_response), + request=request, ) else: if not response.adaptor: - response.adaptor = self.response_adaptor_cls( - flask_response - ) + response.adaptor = self.response_adaptor_cls(flask_response) except Exception as e: error = e flask_response = self.app.handle_exception(e) response = Response( - response=self.response_adaptor_cls( - flask_response - ), + response=self.response_adaptor_cls(flask_response), error=e, - request=request + request=request, ) except: # noqa: B001 error = sys.exc_info()[1] @@ -244,6 +255,7 @@ def wsgi_app(self, environ: dict, start_response): finally: if "werkzeug.debug.preserve_context" in environ: from flask.app import _cv_app, _cv_request + environ["werkzeug.debug.preserve_context"](_cv_app.get()) environ["werkzeug.debug.preserve_context"](_cv_request.get()) diff --git a/utilmeta/core/server/backends/sanic.py b/utilmeta/core/server/backends/sanic.py index 40ac945..47e6f27 100644 --- a/utilmeta/core/server/backends/sanic.py +++ b/utilmeta/core/server/backends/sanic.py @@ -8,8 +8,8 @@ from utilmeta.core.api import API import contextvars -_current_request = contextvars.ContextVar('_sanic.request') -_current_response = contextvars.ContextVar('_sanic.response') +_current_request = contextvars.ContextVar("_sanic.request") +_current_response = contextvars.ContextVar("_sanic.response") class SanicServerAdaptor(ServerAdaptor): @@ -17,23 +17,26 @@ class SanicServerAdaptor(ServerAdaptor): request_adaptor_cls = SanicRequestAdaptor response_adaptor_cls = SanicResponseAdaptor application_cls = Sanic - DEFAULT_NAME = 'sanic_application' + DEFAULT_NAME = "sanic_application" default_asynchronous = True HANDLED_METHODS = 
("DELETE", "HEAD", "GET", "OPTIONS", "PATCH", "POST", "PUT") def __init__(self, config): super().__init__(config) - self.app = self.config._application if isinstance(self.config._application, self.application_cls) \ + self.app = ( + self.config._application + if isinstance(self.config._application, self.application_cls) else self.application_cls(self.config.name or self.DEFAULT_NAME) + ) self._ready = False self._extenstion = None @property def root_path(self) -> str: - server_name = getattr(self.app.config, 'SERVER_NAME', '') - if server_name and '/' in server_name: - return server_name.split('/')[1] - return '' + server_name = getattr(self.app.config, "SERVER_NAME", "") + if server_name and "/" in server_name: + return server_name.split("/")[1] + return "" @property def production(self) -> bool: @@ -43,7 +46,7 @@ def application(self): self.setup() return self.app - def adapt(self, api: 'API', route: str, asynchronous: bool = None): + def adapt(self, api: "API", route: str, asynchronous: bool = None): if asynchronous is None: asynchronous = self.default_asynchronous self.add_api(self.app, api, asynchronous=asynchronous, route=route) @@ -65,19 +68,16 @@ def on_request(self, sanic_request): def on_response(self, sanic_request, sanic_response): response = _current_response.get(None) - request = _current_request.get(None) or Request(self.request_adaptor_cls(sanic_request)) + request = _current_request.get(None) or Request( + self.request_adaptor_cls(sanic_request) + ) if not isinstance(response, Response): response = Response( - response=self.response_adaptor_cls( - sanic_response - ), - request=request + response=self.response_adaptor_cls(sanic_response), request=request ) else: if not response.adaptor: - response.adaptor = self.response_adaptor_cls( - sanic_response - ) + response.adaptor = self.response_adaptor_cls(sanic_response) response_updated = False for middleware in self.middlewares: @@ -102,10 +102,10 @@ def setup_middlewares(self): @property def backend_views_empty(self) -> bool: for val in self.app.router.routes: - handler = getattr(val, 'handler', None) + handler = getattr(val, "handler", None) if not handler: continue - wrapped = getattr(handler, '__wrapped__', None) + wrapped = getattr(handler, "__wrapped__", None) if wrapped and isinstance(wrapped, type) and issubclass(wrapped, API): pass else: @@ -115,11 +115,7 @@ def backend_views_empty(self) -> bool: def setup(self): if self._ready: return - self.add_api( - self.app, - self.resolve(), - asynchronous=self.asynchronous - ) + self.add_api(self.app, self.resolve(), asynchronous=self.asynchronous) self.setup_middlewares() @@ -139,27 +135,34 @@ def run(self, **kwargs): host=self.config.host, port=self.config.port, debug=not self.config.production, - **kwargs + **kwargs, ) - def add_api(self, app: Sanic, utilmeta_api_class, route: str = '', asynchronous: bool = False): + def add_api( + self, + app: Sanic, + utilmeta_api_class, + route: str = "", + asynchronous: bool = False, + ): """ Mount a API class make sure it is called after all your fastapi route is set """ from utilmeta.core.api.base import API + if not issubclass(utilmeta_api_class, API): - raise TypeError(f'Invalid api class: {utilmeta_api_class}') + raise TypeError(f"Invalid api class: {utilmeta_api_class}") - if route and route.strip('/'): - route = '/' + route.strip('/') - prepend = route + '/' + if route and route.strip("/"): + route = "/" + route.strip("/") + prepend = route + "/" else: - prepend = '/' + prepend = "/" if asynchronous: # @app.route('%s' % prepend, 
methods=self.HANDLED_METHODS, static=True) - async def f(request, path: str = ''): + async def f(request, path: str = ""): req = None try: req = _current_request.get(None) @@ -173,12 +176,15 @@ async def f(request, path: str = ''): resp = await utilmeta_api_class(req)() except Exception as e: - resp = getattr(utilmeta_api_class, 'response', Response)(error=e, request=req) + resp = getattr(utilmeta_api_class, "response", Response)( + error=e, request=req + ) _current_response.set(resp) return self.response_adaptor_cls.reconstruct(resp) + else: # @app.route('%s' % prepend, methods=self.HANDLED_METHODS, static=True) - def f(request, path: str = ''): + def f(request, path: str = ""): req = None try: req = _current_request.get(None) @@ -192,7 +198,9 @@ def f(request, path: str = ''): resp = utilmeta_api_class(req)() except Exception as e: - resp = getattr(utilmeta_api_class, 'response', Response)(error=e, request=req) + resp = getattr(utilmeta_api_class, "response", Response)( + error=e, request=req + ) _current_response.set(resp) return self.response_adaptor_cls.reconstruct(resp) @@ -200,19 +208,20 @@ def f(request, path: str = ''): # app.route(route, methods=self.HANDLED_METHODS, name='core_methods')(f) f.__wrapped__ = utilmeta_api_class return app.route( - '%s' % prepend, + "%s" % prepend, methods=self.HANDLED_METHODS, - name=getattr(utilmeta_api_class, '__ref__', utilmeta_api_class.__name__), + name=getattr(utilmeta_api_class, "__ref__", utilmeta_api_class.__name__), # or there might be "Duplicate route names detected" - static=True + static=True, )(f) - def generate(self, spec: str = 'openapi'): - if spec == 'openapi': + def generate(self, spec: str = "openapi"): + if spec == "openapi": app = self.app # from sanic_ext import Extend # setup = not hasattr(app, "_ext") from sanic_routing.exceptions import FinalizationError + try: _ = app.ext except RuntimeError: @@ -226,15 +235,19 @@ def generate(self, spec: str = 'openapi'): try: from sanic_ext.extensions.openapi.builders import SpecificationBuilder from sanic_ext.extensions.openapi.blueprint import blueprint_factory - bp = app.blueprints.get('openapi') or blueprint_factory(app.config) + + bp = app.blueprints.get("openapi") or blueprint_factory(app.config) for listener in bp._future_listeners: - if listener.listener.__name__ == 'build_spec': + if listener.listener.__name__ == "build_spec": listener.listener(app, None) return SpecificationBuilder().build(app).serialize() except (ModuleNotFoundError, ImportError): try: - from sanic_openapi.openapi3.builders import SpecificationBuilder # noqa + from sanic_openapi.openapi3.builders import ( + SpecificationBuilder, + ) # noqa + return SpecificationBuilder().build(app).serialize() except (ModuleNotFoundError, ImportError): pass diff --git a/utilmeta/core/server/backends/starlette.py b/utilmeta/core/server/backends/starlette.py index a8c3acc..2635438 100644 --- a/utilmeta/core/server/backends/starlette.py +++ b/utilmeta/core/server/backends/starlette.py @@ -16,7 +16,7 @@ from typing import Optional from urllib.parse import urlparse -_current_request = contextvars.ContextVar('_starlette.request') +_current_request = contextvars.ContextVar("_starlette.request") # _current_response = contextvars.ContextVar('_starlette.response') # starlette 's response may cross the context, so that cannot be picked @@ -29,24 +29,32 @@ class StarletteServerAdaptor(ServerAdaptor): default_asynchronous = True DEFAULT_PORT = 8000 RECORD_RESPONSE_BODY_STATUS_GTE = 400 - RECORD_RESPONSE_BODY_LENGTH_LTE = 1024 ** 2 - 
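# The RECORD_* constants here gate log capture in the middleware below: bodies are only
# retained for error-ish responses that are reasonably small. The predicate, extracted
# as a sketch of how the thresholds combine:
def should_record_body(status_code: int, content_length) -> bool:
    # mirrors RECORD_RESPONSE_BODY_STATUS_GTE (400) and
    # RECORD_RESPONSE_BODY_LENGTH_LTE (1 MiB = 1024 ** 2 bytes)
    return status_code >= 400 and (content_length or 0) <= 1024 ** 2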
RECORD_REQUEST_BODY_LENGTH_LTE = 1024 ** 2 + RECORD_RESPONSE_BODY_LENGTH_LTE = 1024**2 + RECORD_REQUEST_BODY_LENGTH_LTE = 1024**2 RECORD_REQUEST_BODY_TYPES = [ - RequestType.JSON, RequestType.XML, RequestType.APP_XML, RequestType.HTML, - RequestType.FORM_DATA, RequestType.FORM_URLENCODED, RequestType.PLAIN + RequestType.JSON, + RequestType.XML, + RequestType.APP_XML, + RequestType.HTML, + RequestType.FORM_DATA, + RequestType.FORM_URLENCODED, + RequestType.PLAIN, ] - DEFAULT_HOST = '127.0.0.1' + DEFAULT_HOST = "127.0.0.1" HANDLED_METHODS = ["DELETE", "HEAD", "GET", "OPTIONS", "PATCH", "POST", "PUT"] def __init__(self, config): super().__init__(config=config) - self.app = self.config._application if isinstance(self.config._application, self.application_cls) \ + self.app = ( + self.config._application + if isinstance(self.config._application, self.application_cls) else self.application_cls(debug=not self.config.production) + ) self._ready = False self._mounts = {} - def adapt(self, api: 'API', route: str, asynchronous: bool = None): + def adapt(self, api: "API", route: str, asynchronous: bool = None): if asynchronous is None: asynchronous = self.default_asynchronous self.add_api(self.app, api, asynchronous=asynchronous, route=route) @@ -54,13 +62,14 @@ def adapt(self, api: 'API', route: str, asynchronous: bool = None): def mount(self, app, route: str): if not self.is_asgi(app): from starlette.middleware.wsgi import WSGIMiddleware + # todo: fix deprecated app = WSGIMiddleware(app) self.app.mount(route, app) self._mounts[route] = app def load_route(self, request: StarletteRequest): - path = request.path_params.get('path') or request.url.path + path = request.path_params.get("path") or request.url.path return super().load_route(path) @property @@ -68,9 +77,9 @@ def backend_views_empty(self) -> bool: if self._mounts: return False for val in self.app.routes: - f = getattr(val, 'endpoint', None) + f = getattr(val, "endpoint", None) if f: - wrapped = getattr(f, '__wrapped__', None) + wrapped = getattr(f, "__wrapped__", None) if wrapped and isinstance(wrapped, type) and issubclass(wrapped, API): pass else: @@ -87,16 +96,16 @@ def production(self) -> bool: def setup_middlewares(self): if self.middlewares: from starlette.middleware.base import BaseHTTPMiddleware + self.app.add_middleware( - BaseHTTPMiddleware, # noqa - dispatch=self.get_middleware_func() + BaseHTTPMiddleware, dispatch=self.get_middleware_func() # noqa ) @classmethod async def get_response_body(cls, starlette_response: _StreamingResponse) -> bytes: response_body = [chunk async for chunk in starlette_response.body_iterator] starlette_response.body_iterator = iterate_in_threadpool(iter(response_body)) - return b''.join(response_body) + return b"".join(response_body) def get_middleware_func(self): async def utilmeta_middleware(starlette_request: StarletteRequest, call_next): @@ -116,38 +125,41 @@ async def utilmeta_middleware(starlette_request: StarletteRequest, call_next): if response is None: if request.adaptor.request_method.lower() in HAS_BODY_METHODS: - if request.content_type in self.RECORD_REQUEST_BODY_TYPES and ( - request.content_length or 0) <= self.RECORD_RESPONSE_BODY_LENGTH_LTE: + if ( + request.content_type in self.RECORD_REQUEST_BODY_TYPES + and (request.content_length or 0) + <= self.RECORD_RESPONSE_BODY_LENGTH_LTE + ): request.adaptor.body = await starlette_request.body() # read the body here any way, the request will cache it # and you cannot read it after response is generated _current_request.set(request) - 
starlette_response: Optional[_StreamingResponse] = await call_next(starlette_request) + starlette_response: Optional[_StreamingResponse] = await call_next( + starlette_request + ) _current_request.set(None) - response = request.adaptor.get_context('response') + response = request.adaptor.get_context("response") # response = _current_response.get(None) # _current_response.set(None) if not isinstance(response, Response): # from native starlette api - adaptor = self.response_adaptor_cls( - starlette_response - ) - if starlette_response.status_code >= self.RECORD_RESPONSE_BODY_STATUS_GTE: - if (adaptor.content_length or 0) <= self.RECORD_RESPONSE_BODY_LENGTH_LTE: + adaptor = self.response_adaptor_cls(starlette_response) + if ( + starlette_response.status_code + >= self.RECORD_RESPONSE_BODY_STATUS_GTE + ): + if ( + adaptor.content_length or 0 + ) <= self.RECORD_RESPONSE_BODY_LENGTH_LTE: body = await self.get_response_body(starlette_response) starlette_response.body = body # set body - response = Response( - response=adaptor, - request=request - ) + response = Response(response=adaptor, request=request) else: if not response.adaptor: - response.adaptor = self.response_adaptor_cls( - starlette_response - ) + response.adaptor = self.response_adaptor_cls(starlette_response) response_updated = False for middleware in self.middlewares: @@ -174,25 +186,28 @@ def setup(self): self.app, self.resolve(), asynchronous=self.asynchronous, - default=self.config.auto_created + default=self.config.auto_created, ) self.setup_middlewares() if self.asynchronous: - @self.app.on_event('startup') + + @self.app.on_event("startup") async def on_startup(): await self.config.startup() - @self.app.on_event('shutdown') + @self.app.on_event("shutdown") async def on_shutdown(): await self.config.shutdown() + else: - @self.app.on_event('startup') + + @self.app.on_event("startup") def on_startup(): self.config.startup() - @self.app.on_event('shutdown') + @self.app.on_event("shutdown") def on_shutdown(): self.config.shutdown() @@ -201,19 +216,28 @@ def on_shutdown(): def add_wsgi(self): pass - def add_api(self, app: Starlette, utilmeta_api_class, route: str = '', - asynchronous: bool = False, default: bool = False): + def add_api( + self, + app: Starlette, + utilmeta_api_class, + route: str = "", + asynchronous: bool = False, + default: bool = False, + ): """ Mount a API class make sure it is called after all your fastapi route is set """ from utilmeta.core.api.base import API - if not isinstance(utilmeta_api_class, type) or not issubclass(utilmeta_api_class, API): - raise TypeError(f'Invalid api class: {utilmeta_api_class}') - if route and route.strip('/'): - route = '/' + route.strip('/') + '/' + + if not isinstance(utilmeta_api_class, type) or not issubclass( + utilmeta_api_class, API + ): + raise TypeError(f"Invalid api class: {utilmeta_api_class}") + if route and route.strip("/"): + route = "/" + route.strip("/") + "/" else: - route = '/' + route = "/" # utilmeta_api_class: Type[API] if asynchronous: @@ -228,17 +252,18 @@ async def f(request: StarletteRequest, _default: bool = False): else: req.adaptor.route = path req.adaptor.request = request - resp = await utilmeta_api_class( - req - )() + resp = await utilmeta_api_class(req)() except Exception as e: if _default: if isinstance(e, exceptions.NotFound) and e.path: raise - resp = getattr(utilmeta_api_class, 'response', Response)(error=e, request=req) + resp = getattr(utilmeta_api_class, "response", Response)( + error=e, request=req + ) if req: 
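
# NOTE: in both the async and sync handlers here, uncaught exceptions are
# rendered through getattr(utilmeta_api_class, "response", Response), so an
# API class can control its own error envelope by declaring a `response`
# attribute. A minimal sketch (WrapResponse / RootAPI are illustrative names,
# not part of this patch):
#
#     from utilmeta.core import api, response
#
#     class WrapResponse(response.Response):
#         result_key = "result"  # wrap payloads under "result"
#         message_key = "msg"    # error messages land under "msg"
#
#     class RootAPI(api.API):
#         response = WrapResponse  # picked up by the adaptor's getattr() fallback
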
req.adaptor.update_context(response=resp) return self.response_adaptor_cls.reconstruct(resp) + else: # @app.route('%s/{path:path}' % route, methods=cls.HANDLED_METHODS) def f(request: StarletteRequest, _default: bool = False): @@ -251,17 +276,18 @@ def f(request: StarletteRequest, _default: bool = False): else: req.adaptor.route = path req.adaptor.request = request - resp = utilmeta_api_class( - req - )() + resp = utilmeta_api_class(req)() except Exception as e: if _default: if isinstance(e, exceptions.NotFound) and e.path: raise - resp = getattr(utilmeta_api_class, 'response', Response)(error=e, request=req) + resp = getattr(utilmeta_api_class, "response", Response)( + error=e, request=req + ) if req: req.adaptor.update_context(response=resp) return self.response_adaptor_cls.reconstruct(resp) + f.__wrapped__ = utilmeta_api_class if default: @@ -269,6 +295,7 @@ def f(request: StarletteRequest, _default: bool = False): async def default_route(scope, receive, send): from starlette.requests import Request + request = Request(scope, receive=receive, send=send) try: response = f(request, True) @@ -278,12 +305,11 @@ async def default_route(scope, receive, send): # if the root router cannot analyze, we fall back to the original default return await original_default(scope, receive, send) await response(scope, receive, send) + app.router.default = default_route else: app.add_route( - path='%s{path:path}' % route, - route=f, - methods=self.HANDLED_METHODS + path="%s{path:path}" % route, route=f, methods=self.HANDLED_METHODS ) def application(self): @@ -296,11 +322,13 @@ def run(self, **kwargs): pass else: from utilmeta.utils import check_requirement - check_requirement('uvicorn', install_when_require=True) + + check_requirement("uvicorn", install_when_require=True) import uvicorn + uvicorn.run( self.app, host=self.config.host or self.DEFAULT_HOST, port=self.config.port, - **kwargs + **kwargs, ) diff --git a/utilmeta/core/server/backends/tornado.py b/utilmeta/core/server/backends/tornado.py index 1b3631e..9776ba8 100644 --- a/utilmeta/core/server/backends/tornado.py +++ b/utilmeta/core/server/backends/tornado.py @@ -18,38 +18,44 @@ class TornadoServerAdaptor(ServerAdaptor): def __init__(self, config): super().__init__(config) - self.app = self.config._application if isinstance(self.config._application, self.application_cls) else None + self.app = ( + self.config._application + if isinstance(self.config._application, self.application_cls) + else None + ) self._ready = False @property def production(self) -> bool: - return not self.app.settings.get('debug') + return not self.app.settings.get("debug") - def adapt(self, api: 'API', route: str, asynchronous: bool = None): + def adapt(self, api: "API", route: str, asynchronous: bool = None): if asynchronous is None: asynchronous = self.default_asynchronous - func = self.get_request_handler(api, asynchronous=asynchronous, append_slash=True) - path = rf'/{route.strip("/")}(\/.*)?' if route.strip('/') else '(.*)' - self.app.add_handlers( - '.*', [ - (path, func) - ] + func = self.get_request_handler( + api, asynchronous=asynchronous, append_slash=True ) + path = rf'/{route.strip("/")}(\/.*)?' 
if route.strip("/") else "(.*)" + self.app.add_handlers(".*", [(path, func)]) def load_route(self, path: str): - return (path or '').strip('/') + return (path or "").strip("/") - def get_request_handler(self, utilmeta_api_class, asynchronous: bool = False, append_slash: bool = False): + def get_request_handler( + self, utilmeta_api_class, asynchronous: bool = False, append_slash: bool = False + ): request_adaptor_cls = self.request_adaptor_cls service = self if append_slash: decorator = tornado.web.addslash else: + def decorator(f): return f if asynchronous: + class Handler(RequestHandler): @decorator async def get(self, *args, **kwargs): @@ -97,7 +103,9 @@ async def handle(self, path: str): if not isinstance(response, Response): response = Response(response=response, request=request) except Exception as e: - response = getattr(utilmeta_api_class, 'response', Response)(error=e, request=request) + response = getattr(utilmeta_api_class, "response", Response)( + error=e, request=request + ) for middleware in service.middlewares: _response = middleware.process_response(response) @@ -111,7 +119,9 @@ async def handle(self, path: str): return body = response.prepare_body() self.write(body) + else: + class Handler(RequestHandler): @decorator def get(self, *args, **kwargs): @@ -159,7 +169,9 @@ def handle(self, path: str): if not isinstance(response, Response): response = Response(response=response, request=request) except Exception as e: - response = getattr(utilmeta_api_class, 'response', Response)(error=e, request=request) + response = getattr(utilmeta_api_class, "response", Response)( + error=e, request=request + ) for middleware in service.middlewares: _response = middleware.process_response(response) or response @@ -179,10 +191,7 @@ def handle(self, path: str): @property def request_handler(self): - return self.get_request_handler( - self.resolve(), - asynchronous=self.asynchronous - ) + return self.get_request_handler(self.resolve(), asynchronous=self.asynchronous) def application(self): return self.setup() @@ -206,20 +215,14 @@ def setup(self): root_api = self.resolve() if self.config.root_url: - url_pattern = rf'/{self.config.root_url}(\/.*)?' + url_pattern = rf"/{self.config.root_url}(\/.*)?" 
else: - url_pattern = '/' + root_api._get_route_pattern().lstrip('^') + url_pattern = "/" + root_api._get_route_pattern().lstrip("^") if self.app: - self.app.add_handlers( - '.*', [ - (url_pattern, self.request_handler) - ] - ) + self.app.add_handlers(".*", [(url_pattern, self.request_handler)]) return self.app - self.app = self.application_cls([ - (url_pattern, self.request_handler) - ]) + self.app = self.application_cls([(url_pattern, self.request_handler)]) self._ready = True return self.app diff --git a/utilmeta/core/server/backends/werkzeug.py b/utilmeta/core/server/backends/werkzeug.py index 098b03f..8173fab 100644 --- a/utilmeta/core/server/backends/werkzeug.py +++ b/utilmeta/core/server/backends/werkzeug.py @@ -38,7 +38,7 @@ def __init__(self, config): def application_cls(self): class _Application(Application): def dispatch_request(self, request): - return Response('Hello World!') + return Response("Hello World!") def wsgi_app(self, environ, start_response): request = Request(environ) @@ -47,6 +47,7 @@ def wsgi_app(self, environ, start_response): def __call__(self, environ, start_response): return self.wsgi_app(environ, start_response) + return _Application def application(self): @@ -56,8 +57,9 @@ def application(self): @property def root_route(self): if not self.config.root_url: - return '' - return '/' + self.config.root_url.strip('/') + return "" + return "/" + self.config.root_url.strip("/") + # # def setup(self): # if self._ready: @@ -117,5 +119,3 @@ def root_route(self): # except Exception as e: # resp = getattr(utilmeta_api_class, 'response', Response)(error=e) # return cls.response_adaptor_cls.reconstruct(resp) - - diff --git a/utilmeta/core/server/service.py b/utilmeta/core/server/service.py index b993407..5cce7ca 100644 --- a/utilmeta/core/server/service.py +++ b/utilmeta/core/server/service.py @@ -3,17 +3,32 @@ import sys import os import re -from utilmeta.utils import (import_obj, awaitable, search_file, ignore_errors, LOCAL_IP, requires, path_merge, get_ip, - cached_property, get_origin, load_ini, read_from, write_to, localhost) +from utilmeta.utils import ( + import_obj, + awaitable, + search_file, + ignore_errors, + LOCAL_IP, + requires, + path_merge, + get_ip, + cached_property, + get_origin, + load_ini, + read_from, + write_to, + localhost, +) from utilmeta.conf.base import Config import inspect from utilmeta.core.api import API from pathlib import Path from ipaddress import ip_address + # if TYPE_CHECKING: # from utilmeta.core.api.specs.base import BaseAPISpec -T = TypeVar('T') +T = TypeVar("T") class UtilMeta: @@ -21,7 +36,8 @@ class UtilMeta: def __init__( self, - module_name: Optional[str], *, + module_name: Optional[str], + *, backend, name: str = None, title: str = None, @@ -29,7 +45,7 @@ def __init__( production: bool = None, host: str = None, port: int = None, - scheme: str = 'http', + scheme: str = "http", origin: str = None, version: Union[str, tuple] = None, # application=None, @@ -39,12 +55,12 @@ def __init__( asynchronous: bool = None, auto_reload: bool = None, api=None, - route: str = '', + route: str = "", ): """ - ! THERE MUST BE NO IMPORT BEFORE THE CONFIG IS ASSIGNED PROPERLY ! - if there is, the utils will use the incorrect initial settings Config and cause the - runtime error (hard to find) + ! THERE MUST BE NO IMPORT BEFORE THE CONFIG IS ASSIGNED PROPERLY ! 
+ if there is, the utils will use the incorrect initial settings Config and cause the + runtime error (hard to find) """ # if not name.replace('-', '_').isidentifier(): @@ -55,13 +71,16 @@ def __init__( self.meta_path = None self.project_dir = Path(os.getcwd()) self.meta_config = {} - self.root_url = str(route or '').strip('/') + self.root_url = str(route or "").strip("/") if self.root_url: from urllib.parse import urlparse + if urlparse(self.root_url).scheme: - raise ValueError(f'UtilMeta service route: {repr(route)} must be a relative url, you can specify ' - f'the absolute url origin by parameter') + raise ValueError( + f"UtilMeta service route: {repr(route)} must be a relative url, you can specify " + f"the absolute url origin by parameter" + ) # self.root_url = str(root_url).strip('/') self.production = production @@ -85,7 +104,9 @@ def __init__( if host_addr: self.host_addr = ip_address(host_addr) except ValueError as e: - raise ValueError(f'UtilMeta service: invalid host: {repr(host)}, must be a valid IP address') from e + raise ValueError( + f"UtilMeta service: invalid host: {repr(host)}, must be a valid IP address" + ) from e self.port = port self.scheme = scheme @@ -108,13 +129,16 @@ def __init__( self.load_meta() import utilmeta + try: - srv: 'UtilMeta' = utilmeta.service + srv: "UtilMeta" = utilmeta.service except AttributeError: utilmeta.service = self else: if srv.name != self.name: - raise ValueError(f'Conflict service: {repr(self.name)}, {srv.name} in same process') + raise ValueError( + f"Conflict service: {repr(self.name)}, {srv.name} in same process" + ) utilmeta.service = self self.backend = None @@ -122,17 +146,19 @@ def __init__( self.backend_version = None from utilmeta.core.server.backends.base import ServerAdaptor + self.adaptor: Optional[ServerAdaptor] = None self.set_backend(backend) self._pool = None @property def module(self): - return sys.modules.get(self.module_name or '__main__') + return sys.modules.get(self.module_name or "__main__") @property def preference(self): from utilmeta.conf.preference import Preference + return self.get_config(Preference) or Preference.get() @property @@ -149,42 +175,54 @@ def root_api(self, api): try: api.__mount__(sub_api, route=route) except ValueError as e: - warnings.warn(f'utilmeta.service: mount {sub_api} to service failed with error: {e}') + warnings.warn( + f"utilmeta.service: mount {sub_api} to service failed with error: {e}" + ) self._unmounted_apis = {} self._root_api = api elif isinstance(api, str): self._root_api_ref = api elif api: - raise TypeError(f'Invalid root API for UtilMeta service: {api}, should be a API class' - f' inheriting utilmeta.core.api.API or a string reference to that class') + raise TypeError( + f"Invalid root API for UtilMeta service: {api}, should be a API class" + f" inheriting utilmeta.core.api.API or a string reference to that class" + ) def load_meta(self): - self.meta_path = search_file('utilmeta.ini') or search_file('meta.ini') + self.meta_path = search_file("utilmeta.ini") or search_file("meta.ini") if self.meta_path: self.project_dir = Path(os.path.dirname(self.meta_path)) try: config = load_ini(read_from(self.meta_path), parse_key=True) except Exception as e: - warnings.warn(f'load ini file: {self.meta_path} failed with error: {e}') + warnings.warn(f"load ini file: {self.meta_path} failed with error: {e}") else: - self.meta_config = config.get('utilmeta') or config.get('service') or {} + self.meta_config = config.get("utilmeta") or config.get("service") or {} if not 
isinstance(self.meta_config, dict): self.meta_config = {} - self.name = self.name or str(self.meta_config.get('name', '')).strip() - self.pid_file = self.meta_config.get('pidfile') or self.meta_config.get('pid') + self.name = self.name or str(self.meta_config.get("name", "")).strip() + self.pid_file = self.meta_config.get("pidfile") or self.meta_config.get( + "pid" + ) if self.pid_file: if not os.path.isabs(self.pid_file): self.pid_file = path_merge(str(self.project_dir), self.pid_file) - self.name = self.name or (os.path.basename(self.project_dir) if self.project_dir else None) + self.name = self.name or ( + os.path.basename(self.project_dir) if self.project_dir else None + ) if not self.name: - raise ValueError(f'UtilMeta service name not specified, you can set name using' - f' UtilMeta(name="your-project-name")') + raise ValueError( + f"UtilMeta service name not specified, you can set name using" + f' UtilMeta(name="your-project-name")' + ) - if not re.fullmatch(r'[A-Za-z0-9_-]+', self.name): - raise ValueError(f'UtilMeta service name: {repr(self.name)} can only contains alphanumeric characters, ' - 'underscore "_" and hyphen "-"') + if not re.fullmatch(r"[A-Za-z0-9_-]+", self.name): + raise ValueError( + f"UtilMeta service name: {repr(self.name)} can only contains alphanumeric characters, " + 'underscore "_" and hyphen "-"' + ) @property def pid(self) -> Optional[int]: @@ -194,7 +232,7 @@ def pid(self) -> Optional[int]: try: return int(read_from(self.pid_file).strip()) except Exception as e: - warnings.warn(f'read PID failed: {e}') + warnings.warn(f"read PID failed: {e}") return None def set_asynchronous(self, asynchronous: bool): @@ -210,6 +248,7 @@ def set_asynchronous(self, asynchronous: bool): from utilmeta.core.orm.databases.config import DatabaseConnections from utilmeta.core.cache.config import CacheConnections + dbs = self.get_config(DatabaseConnections) if dbs: for alias, database in dbs.databases.items(): @@ -236,25 +275,28 @@ def set_backend(self, backend): elif isinstance(backend, type) and issubclass(backend, ServerAdaptor): self.adaptor = backend(self) backend = backend.backend - backend_name = getattr(backend, '__name__', str(backend)) + backend_name = getattr(backend, "__name__", str(backend)) elif inspect.ismodule(backend): - backend_name = getattr(backend, '__name__', str(backend)) + backend_name = getattr(backend, "__name__", str(backend)) else: # maybe an application - module = getattr(backend, '__module__', None) + module = getattr(backend, "__module__", None) if module and callable(backend): # application application = backend - backend_name = str(module).split('.')[0] + backend_name = str(module).split(".")[0] backend = import_obj(backend_name) else: - raise TypeError(f'Invalid service backend: {repr(backend)}, ' - f'must be a supported module or application') + raise TypeError( + f"Invalid service backend: {repr(backend)}, " + f"must be a supported module or application" + ) if backend: - backend_version = getattr(backend, '__version__', None) + backend_version = getattr(backend, "__version__", None) if backend_version is None: from importlib.metadata import version + backend_version = version(backend_name) self.backend = backend @@ -269,7 +311,9 @@ def set_backend(self, backend): if not isinstance(self._application, self.adaptor.application_cls): self._application = None - warnings.warn(f'Replacing server backend from [{self.adaptor.backend}] to [{self.backend_name}]') + warnings.warn( + f"Replacing server backend from [{self.adaptor.backend}] to 
[{self.backend_name}]" + ) # if not self.adaptor: self.adaptor = ServerAdaptor.dispatch(self) @@ -281,13 +325,17 @@ def set_backend(self, backend): if application and self.adaptor.application_cls: if not isinstance(application, self.adaptor.application_cls): - raise ValueError(f'Invalid application for {repr(self.backend_name)}: {application}') + raise ValueError( + f"Invalid application for {repr(self.backend_name)}: {application}" + ) def __repr__(self): - return f'UtilMeta({repr(self.module_name)}, ' \ - f'name={repr(self.name)}, ' \ - f'backend={self.backend}, ' \ - f'version={self.version}, background={self.background})' + return ( + f"UtilMeta({repr(self.module_name)}, " + f"name={repr(self.name)}, " + f"backend={self.backend}, " + f"version={self.version}, background={self.background})" + ) def __str__(self): return self.__repr__() @@ -295,29 +343,34 @@ def __str__(self): @property def version_str(self): if isinstance(self.version, str): - return self.version or '0.1.0' + return self.version or "0.1.0" if not isinstance(self.version, tuple): - return '0.1.0' + return "0.1.0" parts = [] for i, v in enumerate(self.version): parts.append(str(v)) if i < len(self.version) - 1: if isinstance(self.version[i + 1], int): - parts.append('.') + parts.append(".") else: - parts.append('-') - return ''.join(parts) + parts.append("-") + return "".join(parts) def register_command(self, command_cls, name: str = None): from utilmeta.bin.base import BaseCommand + if not issubclass(command_cls, BaseCommand): - raise TypeError(f'UtilMeta: Invalid command class: {command_cls} to register, ' - f'must be subclass of BaseCommand') + raise TypeError( + f"UtilMeta: Invalid command class: {command_cls} to register, " + f"must be subclass of BaseCommand" + ) if name: if name in self.commands: if self.commands[name] != command_cls: - raise ValueError(f'UtilMeta: conflict command' - f' [{repr(name)}]: {command_cls}, {self.commands[name]}') + raise ValueError( + f"UtilMeta: conflict command" + f" [{repr(name)}]: {command_cls}, {self.commands[name]}" + ) return self.commands[name] = command_cls else: @@ -342,6 +395,7 @@ def get_config(self, config_class: Type[T]) -> Optional[T]: def get_client(self, live: bool = False, backend=None, **kwargs): from utilmeta.core.cli.base import Client + return Client(service=self, internal=not live, backend=backend, **kwargs) def setup(self): @@ -367,9 +421,11 @@ def startup(self): if isinstance(config, Config): r = config.on_startup(self) if inspect.isawaitable(r): - raise ValueError(f'detect awaitable config setup: {config}, you should use async ' - f'backend such as starlette / sanic / tornado') - for func in self.events.get('startup', []): + raise ValueError( + f"detect awaitable config setup: {config}, you should use async " + f"backend such as starlette / sanic / tornado" + ) + for func in self.events.get("startup", []): func() @awaitable(startup) @@ -379,7 +435,7 @@ async def startup(self): r = config.on_startup(self) if inspect.isawaitable(r): await r - for func in self.events.get('startup', []): + for func in self.events.get("startup", []): r = func() if inspect.isawaitable(r): await r @@ -388,7 +444,7 @@ def shutdown(self): for cls, config in self.configs.items(): if isinstance(config, Config): config.on_shutdown(self) - for func in self.events.get('shutdown', []): + for func in self.events.get("shutdown", []): func() @awaitable(shutdown) @@ -398,23 +454,25 @@ async def shutdown(self): r = config.on_shutdown(self) if inspect.isawaitable(r): await r - for func in 
self.events.get('shutdown', []): + for func in self.events.get("shutdown", []): r = func() if inspect.isawaitable(r): await r def on_startup(self, f): if callable(f): - self.events.setdefault('startup', []).append(f) + self.events.setdefault("startup", []).append(f) def on_shutdown(self, f): if callable(f): - self.events.setdefault('shutdown', []).append(f) + self.events.setdefault("shutdown", []).append(f) - def mount(self, api=None, route: str = ''): + def mount(self, api=None, route: str = ""): if not api: + def deco(_api): return self.mount(_api, route=route) + return deco elif isinstance(api, str): pass @@ -423,25 +481,29 @@ def deco(_api): else: # try to mount a wsgi/asgi app if not route: - raise ValueError('Mounting applications required not-empty route') + raise ValueError("Mounting applications required not-empty route") if not self.adaptor: - raise ValueError('UtilMeta: backend is required to mount applications') + raise ValueError("UtilMeta: backend is required to mount applications") self.adaptor.mount(api, route=route) return if self._root_api: - if getattr(self.root_api, '__ref__', str(self.root_api)) != getattr(api, '__ref__', str(api)): - raise ValueError(f'UtilMeta: root api conflicted: {api}, {self.root_api}, ' - f'you can only mount a service once') + if getattr(self.root_api, "__ref__", str(self.root_api)) != getattr( + api, "__ref__", str(api) + ): + raise ValueError( + f"UtilMeta: root api conflicted: {api}, {self.root_api}, " + f"you can only mount a service once" + ) return self.root_api = api - self.root_url = str(route).strip('/') + self.root_url = str(route).strip("/") def mount_to_api(self, api, route: str, eager: bool = False): if not inspect.isclass(api) and issubclass(api, API): - raise TypeError(f'Invalid API: {api}') + raise TypeError(f"Invalid API: {api}") - route = str(route).strip('/') + route = str(route).strip("/") if not eager and not self._root_api: # if not eagerly mount @@ -475,7 +537,7 @@ def resolve(self) -> Type[API]: return self._root_api if self._root_api_ref: ref = self._root_api_ref - if '.' not in ref: + if "." 
not in ref: # in current module root_api = getattr(self.module, ref) else: @@ -487,22 +549,37 @@ def resolve(self) -> Type[API]: # some ext API like OperationsAPI might be mounted to class RootAPI(API): pass + self.root_api = RootAPI return RootAPI - raise ValueError('utilmeta.service: RootAPI not mounted') + raise ValueError("utilmeta.service: RootAPI not mounted") return self._root_api - @ignore_errors # just ignore some unicode error happened in win + @ignore_errors # just ignore some unicode error happened in win def print_info(self): from utilmeta import __version__ from utilmeta.bin.constant import BLUE, GREEN, DOT - print(BLUE % '|', f'UtilMeta v{__version__} starting service [%s]' % (BLUE % self.name)) - print(BLUE % '|', ' version:', self.version_str) - print(BLUE % '|', ' stage:', (BLUE % f'{DOT} production') if self.production else (GREEN % f'{DOT} debug')) - print(BLUE % '|', ' backend:', f'{self.backend_name} ({self.backend_version})', - (BLUE % f'| asynchronous') if self.asynchronous else '') - print(BLUE % '|', ' base url:', f'{self.base_url}') - print('') + + print( + BLUE % "|", + f"UtilMeta v{__version__} starting service [%s]" % (BLUE % self.name), + ) + print(BLUE % "|", " version:", self.version_str) + print( + BLUE % "|", + " stage:", + (BLUE % f"{DOT} production") + if self.production + else (GREEN % f"{DOT} debug"), + ) + print( + BLUE % "|", + " backend:", + f"{self.backend_name} ({self.backend_version})", + (BLUE % f"| asynchronous") if self.asynchronous else "", + ) + print(BLUE % "|", " base url:", f"{self.base_url}") + print("") def resolve_port(self): if self.port: @@ -510,6 +587,7 @@ def resolve_port(self): host = self.host or LOCAL_IP import socket + if self.adaptor and self.adaptor.DEFAULT_PORT: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: if s.connect_ex((host, self.adaptor.DEFAULT_PORT)) != 0: @@ -525,7 +603,7 @@ def resolve_port(self): def run(self, **kwargs): if not self.adaptor: - raise NotImplementedError('UtilMeta: service backend not specified') + raise NotImplementedError("UtilMeta: service backend not specified") self.resolve_port() self.print_info() self.setup() @@ -540,11 +618,11 @@ def write_pid(self): try: write_to(self.pid_file, str(os.getpid())) except Exception as e: - warnings.warn(f'write PID to {self.pid_file} failed: {e}') + warnings.warn(f"write PID to {self.pid_file} failed: {e}") def application(self): if not self.adaptor: - raise NotImplementedError('UtilMeta: service backend not specified') + raise NotImplementedError("UtilMeta: service backend not specified") self.setup() app = self.adaptor.application() self._application = app @@ -555,14 +633,14 @@ def get_origin(self, no_localhost: bool = False, force_ip: bool = False): if no_localhost and localhost(host) or force_ip: host = self.ip port = self.port or (self.adaptor.DEFAULT_PORT if self.adaptor else None) - if port == 80 and self.scheme == 'http': + if port == 80 and self.scheme == "http": port = None - elif port == 443 and self.scheme == 'https': + elif port == 443 and self.scheme == "https": port = None if port: if self.host_addr and self.host_addr.version == 6: - host = f'[{host}]' - host += f':{port}' + host = f"[{host}]" + host += f":{port}" return f'{self.scheme or "http"}://{host}' @property @@ -580,7 +658,7 @@ def origin(self): @property def base_url(self): if self.root_url: - return self.origin + '/' + self.root_url + return self.origin + "/" + self.root_url return self.origin @property @@ -590,6 +668,7 @@ def auto_created(self): @property def 
pool(self): from utilmeta.conf.pool import ThreadPool + pool = self.get_config(ThreadPool) if not pool: pool = ThreadPool() @@ -599,4 +678,5 @@ def pool(self): @cached_property def ip(self): from utilmeta.utils import get_server_ip + return get_server_ip() diff --git a/utilmeta/core/websocket/__init__.py b/utilmeta/core/websocket/__init__.py index c3eb433..0d0b76e 100644 --- a/utilmeta/core/websocket/__init__.py +++ b/utilmeta/core/websocket/__init__.py @@ -1,4 +1,3 @@ from .base import Websocket from .request import WebsocketRequest from .properties import ClientEvent, ServerEvent - diff --git a/utilmeta/ops/__init__.py b/utilmeta/ops/__init__.py index 80cd063..e2eafe8 100644 --- a/utilmeta/ops/__init__.py +++ b/utilmeta/ops/__init__.py @@ -1,4 +1,4 @@ -__spec_version__ = '0.4.0' -__website__ = 'https://ops.utilmeta.com' +__spec_version__ = "0.4.0" +__website__ = "https://ops.utilmeta.com" from .config import Operations diff --git a/utilmeta/ops/aggregation.py b/utilmeta/ops/aggregation.py index d5b90ea..53f25b0 100644 --- a/utilmeta/ops/aggregation.py +++ b/utilmeta/ops/aggregation.py @@ -13,7 +13,7 @@ def user_comp_hash(user_id): compress=True, case_insensitive=False, consistent=True, - mod=2 ** 24 + mod=2**24, ) @@ -38,11 +38,7 @@ def get_agent_dist(logs: models.QuerySet, exists: bool = False): num = qs.exists() if exists else qs.count() if num: device_dist[device] = num - return dict( - os_dist=os_dist, - browser_dist=browser_dist, - device_dist=device_dist - ) + return dict(os_dist=os_dist, browser_dist=browser_dist, device_dist=device_dist) except (DatabaseError, FieldError): # raise FieldError when json key lookup is not supported return {} @@ -50,27 +46,27 @@ def get_agent_dist(logs: models.QuerySet, exists: bool = False): def aggregate_endpoint_logs(service: str, to_time: datetime, layer: int = 0): from .log import _endpoints_map + result = {} for ident, endpoint in _endpoints_map.items(): if not endpoint.remote_id: continue data = aggregate_logs( - service=service, - to_time=to_time, - layer=layer, - endpoint_ident=ident + service=service, to_time=to_time, layer=layer, endpoint_ident=ident ) if data: result[endpoint.remote_id] = data return result -def aggregate_logs(service: str, - to_time: datetime, - layer: int = 0, - include_users: Union[bool, int] = False, - include_ips: Union[bool, int] = False, - endpoint_ident: str = None): +def aggregate_logs( + service: str, + to_time: datetime, + layer: int = 0, + include_users: Union[bool, int] = False, + include_ips: Union[bool, int] = False, + endpoint_ident: str = None, +): timespan = [timedelta(hours=1), timedelta(days=1)][layer] start = to_time - timespan @@ -79,11 +75,8 @@ def aggregate_logs(service: str, # gte ~ lt only apply to layer0, in upper layer it's gt ~ lte from .models import ServiceLog, WorkerMonitor - query = dict( - service=service, - time__gte=start, - time__lt=current - ) + + query = dict(service=service, time__gte=start, time__lt=current) if endpoint_ident: query.update(endpoint_ident=endpoint_ident) @@ -91,59 +84,67 @@ def aggregate_logs(service: str, if endpoint_ident or layer: total_data = service_logs.aggregate( - requests=models.Count('id'), - avg_time=models.Avg('duration'), + requests=models.Count("id"), + avg_time=models.Avg("duration"), ) - total_requests = requests = total_data.get('requests') or 0 - avg_time = total_data.get('avg_time') or 0 + total_requests = requests = total_data.get("requests") or 0 + avg_time = total_data.get("avg_time") or 0 else: worker_qs = WorkerMonitor.objects.filter( - 
worker__instance__service=service,
-            time__gte=start,
-            time__lt=current
+            worker__instance__service=service, time__gte=start, time__lt=current
         )
         requests = service_logs.count()
-        total_requests = worker_qs.aggregate(v=models.Sum('requests'))['v'] or requests
+        total_requests = worker_qs.aggregate(v=models.Sum("requests"))["v"] or requests
         if total_requests:
-            avg_time = worker_qs.aggregate(
-                v=models.Sum(models.F('avg_time') * models.F('requests'),
-                             output_field=models.DecimalField()) / total_requests)['v'] or 0
+            avg_time = (
+                worker_qs.aggregate(
+                    v=models.Sum(
+                        models.F("avg_time") * models.F("requests"),
+                        output_field=models.DecimalField(),
+                    )
+                    / total_requests
+                )["v"]
+                or 0
+            )
         else:
             avg_time = 0

     if not total_requests:
         return

-    errors = service_logs.filter(models.Q(level=LogLevel.ERROR) | models.Q(level__iexact='ERROR')).count()
+    errors = service_logs.filter(
+        models.Q(level=LogLevel.ERROR) | models.Q(level__iexact="ERROR")
+    ).count()

     dict_values = {}
     aggregate_info = []

     if not layer:
-        aggregate_info.extend([
-            ('status', 'statuses', None),
-            ('level', 'levels', None),
-        ])
+        aggregate_info.extend(
+            [
+                ("status", "statuses", None),
+                ("level", "levels", None),
+            ]
+        )
     if include_users:
-        aggregate_info.append([
-            ('user_id', 'user_dist', include_users)
-        ])
+        aggregate_info.append(("user_id", "user_dist", include_users))
     if include_ips:
-        aggregate_info.append([
-            ('ip', 'ip_dist', include_ips)
-        ])
+        aggregate_info.append(("ip", "ip_dist", include_ips))

     for name, field, limit in aggregate_info:
-        qs = service_logs.exclude(**{name: None}).values(
-            name).annotate(count=models.Count('id'))
+        qs = (
+            service_logs.exclude(**{name: None})
+            .values(name)
+            .annotate(count=models.Count("id"))
+        )
         if isinstance(limit, int):
-            qs = qs.order_by('-count')[:limit]
-        dicts = {val[name]: val['count'] for val in qs}
-        if name == 'user_id':
+            qs = qs.order_by("-count")[:limit]
+        dicts = {val[name]: val["count"] for val in qs}
+        if name == "user_id":
             dicts = {user_comp_hash(k): v for k, v in dicts.items()}
         dict_values[field] = dicts

-    service_logs_duration = service_logs.order_by('-duration')
+    service_logs_duration = service_logs.order_by("-duration")
     mean_time = service_logs_duration[requests // 2].duration if requests else 0
     p95_time = service_logs_duration[requests // 20].duration if requests else 0
     p99_time = service_logs_duration[requests // 100].duration if requests else 0
@@ -156,24 +157,24 @@ def aggregate_logs(service: str,
             p99_time=p99_time,
             p999_time=p999_time,
             **dict_values,
-            **replace_null(service_logs.aggregate(
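
# NOTE: two aggregation tricks above are worth spelling out. The worker branch
# reconstructs a request-weighted mean: sum(avg_time_i * requests_i) /
# sum(requests_i), so workers reporting (10ms over 300 requests) and (30ms over
# 100 requests) combine to (10*300 + 30*100) / 400 = 15ms, not the naive 20ms.
# The percentile block then relies on the "-duration" (descending) ordering:
# with n = requests, index n // 2 is the median (stored as mean_time), n // 20
# the ~95th percentile and n // 100 the ~99th. The same idea in plain Python,
# assuming a list of sampled durations:
#
#     durations = sorted(durations_ms, reverse=True)  # durations_ms: sample data
#     n = len(durations)
#     median, p95, p99 = durations[n // 2], durations[n // 20], durations[n // 100]
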
in_traffic=models.Sum('in_traffic'), - out_traffic=models.Sum('out_traffic'), - time_stddev=models.StdDev('duration'), - uv=models.Count('user_id', distinct=True), - ip=models.Count('ip', distinct=True), - )), + **replace_null( + service_logs.aggregate( + in_traffic=models.Sum("in_traffic"), + out_traffic=models.Sum("out_traffic"), + time_stddev=models.StdDev("duration"), + uv=models.Count("user_id", distinct=True), + ip=models.Count("ip", distinct=True), + ) + ), ) diff --git a/utilmeta/ops/api/__init__.py b/utilmeta/ops/api/__init__.py index c03bf53..15da182 100644 --- a/utilmeta/ops/api/__init__.py +++ b/utilmeta/ops/api/__init__.py @@ -14,22 +14,24 @@ from .log import LogAPI from .servers import ServersAPI from .token import TokenAPI -from .utils import opsRequire, WrappedResponse, config, supervisor_var, \ - SupervisorObject, resources_var, access_token_var +from .utils import ( + opsRequire, + WrappedResponse, + config, + supervisor_var, + SupervisorObject, + resources_var, + access_token_var, +) from ..log import request_logger, Logger -NO_CACHES = ['no-cache', 'no-store', 'max-age=0'] +NO_CACHES = ["no-cache", "no-store", "max-age=0"] @api.CORS( - allow_origin='*', - allow_headers=[ - 'authorization', - 'cache-control', - 'x-utilmeta-node-id', - 'x-node-id' - ], - cors_max_age=3600 * 6 + allow_origin="*", + allow_headers=["authorization", "cache-control", "x-utilmeta-node-id", "x-node-id"], + cors_max_age=3600 * 6, ) class OperationsAPI(api.API): __external__ = True @@ -53,11 +55,11 @@ class OperationsAPI(api.API): # openapi = OpenAPI(service)() @api.get - @opsRequire('api.view') + @opsRequire("api.view") def openapi(self): - cache_control = self.request.headers.get('Cache-Control') + cache_control = self.request.headers.get("Cache-Control") if cache_control and any(h in cache_control for h in NO_CACHES): - openapi = config.load_openapi(no_store='no-store' in cache_control) + openapi = config.load_openapi(no_store="no-store" in cache_control) else: openapi = config.openapi return response.Response(openapi) @@ -82,12 +84,13 @@ def post(self, data: SupervisorData = request.Body): # this is critical # if the POST /api/ops redirect to GET /api/ops by 301 # the supervisor will not notice the difference by the result data if this field is not filled - **self.get() + **self.get(), ) def get(self): try: - from utilmeta import service # noqa + from utilmeta import service # noqa + name = service.name except ImportError: # raise exceptions.ServerError('service not initialized') @@ -99,13 +102,15 @@ def get(self): ) @adapt_async(close_conn=config.db_alias) - @opsRequire('service.config') + @opsRequire("service.config") def patch(self, data: SupervisorPatch = request.Body): supervisor: SupervisorObject = supervisor_var.getter(self.request) if not supervisor or not supervisor.id: - raise exceptions.NotFound('Supervisor not found', state='supervisor_not_found') + raise exceptions.NotFound( + "Supervisor not found", state="supervisor_not_found" + ) if supervisor.node_id != data.node_id: - raise exceptions.BadRequest('Inconsistent supervisor node_id') + raise exceptions.BadRequest("Inconsistent supervisor node_id") data.id = supervisor.id # backup_urls # base_url @@ -120,70 +125,75 @@ def patch(self, data: SupervisorPatch = request.Body): # task_settings: dict # aggregate_settings: dict data.save() - return dict( - node_id=data.node_id, - **self.get() - ) + return dict(node_id=data.node_id, **self.get()) @adapt_async(close_conn=config.db_alias) - @opsRequire('service.delete') + 
@opsRequire("service.delete") def delete(self): supervisor: SupervisorObject = supervisor_var.getter(self.request) if supervisor: if supervisor.init_key: # this supervisor is not marked as delete - raise exceptions.BadRequest('Supervisor not marked as deleted', state='delete_failed') + raise exceptions.BadRequest( + "Supervisor not marked as deleted", state="delete_failed" + ) if supervisor.node_id: from utilmeta import service from utilmeta.ops import models + for model in models.supervisor_related_models: try: model.objects.filter( node_id=supervisor.node_id, - ).update( - node_id=None, - service=service.name - ) + ).update(node_id=None, service=service.name) except EmptyResultSet: continue if config.node_id: from utilmeta.bin.utils import update_meta_ini_file - update_meta_ini_file(node=None) # clear local node_id + + update_meta_ini_file(node=None) # clear local node_id Supervisor.objects.filter(pk=supervisor.id).delete() return 1 - raise exceptions.NotFound('Supervisor not found', state='supervisor_not_found') + raise exceptions.NotFound("Supervisor not found", state="supervisor_not_found") - @api.before('*', excludes=(get, post)) - def handle_token(self, node_id: str = request.HeaderParam( - 'X-UtilMeta-Node-ID', - alias_from=['x-node-id'], - default=None - )): + @api.before("*", excludes=(get, post)) + def handle_token( + self, + node_id: str = request.HeaderParam( + "X-UtilMeta-Node-ID", alias_from=["x-node-id"], default=None + ), + ): type, token = self.request.authorization - node_id = node_id or self.request.query.get('node') + node_id = node_id or self.request.query.get("node") if not token: if not config.local_disabled and config.is_local: from utilmeta import service - if str(self.request.ip_address) == '127.0.0.1': + + if str(self.request.ip_address) == "127.0.0.1": # LOCAL -> LOCAL MANAGE try: - supervisor = SupervisorObject.init(Supervisor.objects.filter( - node_id=node_id, - disabled=False, - local=True, - ops_api=config.ops_api, - )) + supervisor = SupervisorObject.init( + Supervisor.objects.filter( + node_id=node_id, + disabled=False, + local=True, + ops_api=config.ops_api, + ) + ) supervisor_var.setter(self.request, supervisor) except orm.EmptyQueryset: - supervisor_var.setter(self.request, SupervisorObject( - id=None, - service=service.name, - node_id=None, - disabled=False, - ident=None, - local=True, - ops_api=config.ops_api, - )) + supervisor_var.setter( + self.request, + SupervisorObject( + id=None, + service=service.name, + node_id=None, + disabled=False, + ident=None, + local=True, + ops_api=config.ops_api, + ), + ) pass # raise exceptions.Unauthorized var.scopes.setter(self.request, config.local_scope) @@ -191,7 +201,7 @@ def handle_token(self, node_id: str = request.HeaderParam( raise exceptions.Unauthorized # node can also be included in the query params to avoid additional headers if not node_id: - raise exceptions.BadRequest('Node ID required', state='node_required') + raise exceptions.BadRequest("Node ID required", state="node_required") validated = False for supervisor in SupervisorObject.serialize( Supervisor.objects.filter( @@ -199,56 +209,61 @@ def handle_token(self, node_id: str = request.HeaderParam( # we don't use service name as identifier # that might not be synced disabled=False, - public_key__isnull=False + public_key__isnull=False, ) ): try: token_data = decode_token(token, public_key=supervisor.public_key) except ValueError: - raise exceptions.BadRequest('Invalid token format', state='token_expired') + raise exceptions.BadRequest( + "Invalid 
token format", state="token_expired" + ) if not token_data: continue - token_node_id = token_data.get('nid') + token_node_id = token_data.get("nid") if token_node_id != node_id: - raise exceptions.Conflict(f'Invalid node id') - issuer = token_data.get('iss') or '' + raise exceptions.Conflict(f"Invalid node id") + issuer = token_data.get("iss") or "" if not str(supervisor.base_url).startswith(issuer): - raise exceptions.Conflict(f'Invalid token issuer: {repr(issuer)}') - audience = token_data.get('aud') or '' + raise exceptions.Conflict(f"Invalid token issuer: {repr(issuer)}") + audience = token_data.get("aud") or "" if not config.ops_api.startswith(audience): # todo: log, but not force to reject pass - expires = token_data.get('exp') + expires = token_data.get("exp") if not expires: - raise exceptions.UnprocessableEntity('Invalid token: no expires') + raise exceptions.UnprocessableEntity("Invalid token: no expires") if self.request.time.timestamp() > expires: - raise exceptions.BadRequest('Invalid token: expired', state='token_expired') + raise exceptions.BadRequest( + "Invalid token: expired", state="token_expired" + ) # SCOPE ---------------------------- - scope = token_data.get('scope') or '' - scopes = scope.split(' ') if ' ' in scope else scope.split(',') + scope = token_data.get("scope") or "" + scopes = scope.split(" ") if " " in scope else scope.split(",") scope_names = [] resources = [] for name in scopes: - if ':' in name: - name, resource = name.split(':') + if ":" in name: + name, resource = name.split(":") resources.append(resource) scope_names.append(name) var.scopes.setter(self.request, scope_names) resources_var.setter(self.request, resources) # ------------------------------------- - token_id = token_data.get('jti') or '' + token_id = token_data.get("jti") or "" if not token_id: - raise exceptions.BadRequest('Invalid token: id required', state='token_expired') + raise exceptions.BadRequest( + "Invalid token: id required", state="token_expired" + ) try: token_obj = AccessTokenSchema.init( AccessToken.objects.filter( - token_id=token_id, - issuer_id=supervisor.id + token_id=token_id, issuer_id=supervisor.id ) ) except orm.EmptyQueryset: @@ -258,7 +273,9 @@ def handle_token(self, node_id: str = request.HeaderParam( if token_obj.revoked: # force revoked # e.g. 
the subject permissions has changed after the token issued - raise exceptions.BadRequest('Invalid token: revoked', state='token_expired') + raise exceptions.BadRequest( + "Invalid token: revoked", state="token_expired" + ) token_obj.last_activity = self.request.time token_obj.used_times += 1 token_obj.save() @@ -267,17 +284,19 @@ def handle_token(self, node_id: str = request.HeaderParam( token_obj = AccessTokenSchema( token_id=token_id, issuer_id=supervisor.id, - issued_at=datetime.fromtimestamp(token_data.get('iat')), + issued_at=datetime.fromtimestamp(token_data.get("iat")), expiry_time=datetime.fromtimestamp(expires), - subject=token_data.get('sub'), + subject=token_data.get("sub"), last_activity=self.request.time, used_times=1, ip=str(self.request.ip_address), - scope=scopes + scope=scopes, ) token_obj.save() except IntegrityError: - raise exceptions.BadRequest('Invalid token: id duplicated', state='token_expired') + raise exceptions.BadRequest( + "Invalid token: id duplicated", state="token_expired" + ) # set context vars # scope @@ -288,11 +307,13 @@ def handle_token(self, node_id: str = request.HeaderParam( break if not validated: - raise exceptions.BadRequest('Supervisor not found', state='supervisor_not_found') + raise exceptions.BadRequest( + "Supervisor not found", state="supervisor_not_found" + ) - @api.handle('*') + @api.handle("*") def handle_errors(self, e: Error): if isinstance(e.exception, DatabaseError): # do not expose the state of database error - e.exc = exceptions.ServerError('server error') + e.exc = exceptions.ServerError("server error") return self.response(request=self.request, error=e) diff --git a/utilmeta/ops/api/data.py b/utilmeta/ops/api/data.py index 81e49f9..b17ea46 100644 --- a/utilmeta/ops/api/data.py +++ b/utilmeta/ops/api/data.py @@ -11,7 +11,7 @@ class QuerySchema(utype.Schema): # id_list: list = None query: dict = {} - orders: List[str] = ['pk'] + orders: List[str] = ["pk"] rows: int = utype.Field(default=10, le=100, ge=1) page: int = utype.Field(default=1, ge=1) fields: list = [] @@ -40,15 +40,15 @@ class DataAPI(api.API): # using: str = request.QueryParam(default=None) def get_model(self): - if '.' not in self.model: + if "." 
not in self.model: return None # security check tables = self.get_tables() for table in tables: - if table.get('ref') == self.model: + if table.get("ref") == self.model: if table.model: return table.model - raise exceptions.BadRequest(f'Invalid model: {self.model}') + raise exceptions.BadRequest(f"Invalid model: {self.model}") # deprecate the import usage as it maybe dangerous def __init__(self, *args, **kwargs): @@ -65,7 +65,7 @@ def adaptor(self): try: self._adaptor = ModelAdaptor.dispatch(self.model_class) except NotImplementedError: - raise exceptions.BadRequest(f'Invalid model: {self.model}') + raise exceptions.BadRequest(f"Invalid model: {self.model}") return self._adaptor def parse_result(self, data, max_length: Optional[int] = None): @@ -75,7 +75,7 @@ def parse_result(self, data, max_length: Optional[int] = None): return data elif isinstance(data, dict): for k in list(data.keys()): - if k == 'pk': + if k == "pk": continue field = self.adaptor.get_field(k) if config.is_secret(k) and not field.related_model: @@ -87,40 +87,45 @@ def parse_result(self, data, max_length: Optional[int] = None): return data return reduce_value(data, max_length=max_length) - @api.get('tables') - @opsRequire('data.view') + @api.get("tables") + @opsRequire("data.view") def get_tables(self) -> List[TableSchema]: global _tables if _tables is not None: return _tables from ..resources import ResourcesManager + _tables = ResourcesManager().get_tables(with_model=True) return _tables # scope: data.view:[TABLE_IDENT] - @api.post('query') - @opsRequire('data.query') + @api.post("query") + @opsRequire("data.query") @adapt_async(close_conn=True) # close all connections def query_data(self, query: QuerySchema = request.Body): try: unsliced_qs = self.adaptor.get_queryset(**query.query) count = unsliced_qs.count() - qs = unsliced_qs.order_by(*query.orders)[(query.page - 1) * query.rows: query.page * query.rows] + qs = unsliced_qs.order_by(*query.orders)[ + (query.page - 1) * query.rows : query.page * query.rows + ] fields = query.fields if not fields: - fields = ['pk'] + [f.column_name for f in self.adaptor.get_fields( - many=False, no_inherit=True) if f.column_name] + fields = ["pk"] + [ + f.column_name + for f in self.adaptor.get_fields(many=False, no_inherit=True) + if f.column_name + ] values = self.adaptor.values(qs, *fields) except self.adaptor.field_errors as e: raise exceptions.BadRequest(str(e)) from e return self.response( - self.parse_result(values, max_length=query.max_length), - count=count + self.parse_result(values, max_length=query.max_length), count=count ) - @api.post('create') - @opsRequire('data.create') + @api.post("create") + @opsRequire("data.create") @adapt_async(close_conn=True) # close all connections def create_data(self, data: CreateDataSchema = request.Body): @@ -133,31 +138,30 @@ def create_data(self, data: CreateDataSchema = request.Body): values = self.adaptor.values(qs, *data.return_fields) return self.parse_result(values, max_length=data.return_max_length) - @api.post('update') - @opsRequire('data.update') + @api.post("update") + @opsRequire("data.update") @adapt_async(close_conn=True) # close all connections def update_data(self, data: UpdateDataSchema = request.Body): for val in data.data: - pk = pop(val, 'pk') + pk = pop(val, "pk") if pk: self.adaptor.update(val, pk=pk) - def delete_data(self, - id: str = request.BodyParam - # query: dict = request.BodyParam, - # limit: Optional[int] = request.BodyParam(None) - ): + def delete_data( + self, + id: str = request.BodyParam + # query: dict 
= request.BodyParam,
        # limit: Optional[int] = request.BodyParam(None)
    ):
         # qs = self.adaptor.get_queryset(**query)
         # if limit is not None:
         #     qs = qs.order_by('pk')[:limit]
         return self.adaptor.delete(pk=id)

-    @api.post('delete')
-    @opsRequire('data.delete')
+    @api.post("delete")
+    @opsRequire("data.delete")
     @awaitable(delete_data)
-    async def delete_data(self,
-                          id: str = request.BodyParam
-                          ):
+    async def delete_data(self, id: str = request.BodyParam):
         # apply for async CASCADE
         return await self.adaptor.adelete(pk=id)
diff --git a/utilmeta/ops/api/log.py b/utilmeta/ops/api/log.py
index 00b62d6..3cada79 100644
--- a/utilmeta/ops/api/log.py
+++ b/utilmeta/ops/api/log.py
@@ -8,20 +8,14 @@
 from django.db import models
 from utype.types import *

-AGGREGATION_FIELDS = [
-    'method',
-    'level',
-    'status',
-    'request_type'
-    'response_type'
-]
+AGGREGATION_FIELDS = ["method", "level", "status", "request_type", "response_type"]


 class LogAPI(api.API):
     supervisor: SupervisorObject = supervisor_var
     response = WrappedResponse

-    @opsRequire('log.view')
+    @opsRequire("log.view")
     def get(self, id: int) -> ServiceLogSchema:
         try:
             return ServiceLogSchema.init(id)
@@ -32,12 +26,12 @@ class LogQuery(orm.Query[ServiceLog]):
         __distinct__ = False
         offset: int = orm.Offset(default=None)
         page: int = orm.Page()
-        rows: int = orm.Limit(default=20, le=100, alias_from=['limit'])
+        rows: int = orm.Limit(default=20, le=100, alias_from=["limit"])
         endpoint_ident: str = orm.Filter()
         endpoint_like: str = orm.Filter(
-            query=lambda v: models.Q(endpoint_ident__icontains=v) |
-                            models.Q(endpoint_ref__icontains=v) |
-                            models.Q(path__icontains=v)
+            query=lambda v: models.Q(endpoint_ident__icontains=v)
+            | models.Q(endpoint_ref__icontains=v)
+            | models.Q(path__icontains=v)
         )
         method: str = orm.Filter(query=lambda v: models.Q(method__iexact=v))
         level: str = orm.Filter(query=lambda v: models.Q(level__iexact=v))
@@ -47,78 +41,85 @@ class LogQuery(orm.Query[ServiceLog]):
         status: int = orm.Filter()
         status_gte: int = orm.Filter(query=lambda v: models.Q(status__gte=v))
         status_lte: int = orm.Filter(query=lambda v: models.Q(status__lte=v))
-        time_gte: datetime = orm.Filter(query=lambda v: models.Q(time__gte=v), alias_from=['time>='])
-        time_lte: datetime = orm.Filter(query=lambda v: models.Q(time__lte=v), alias_from=['time<='])
+        time_gte: datetime = orm.Filter(
+            query=lambda v: models.Q(time__gte=v), alias_from=["time>="]
+        )
+        time_lte: datetime = orm.Filter(
+            query=lambda v: models.Q(time__lte=v), alias_from=["time<="]
+        )

-        admin: bool = orm.Filter(query=lambda v: models.Q(access_token__isnull=not v), default=False)
+        admin: bool = orm.Filter(
+            query=lambda v: models.Q(access_token__isnull=not v), default=False
+        )
         request_type: str = orm.Filter()
         response_type: str = orm.Filter()
         start: int = orm.Filter(query=lambda v: models.Q(time__gte=v))
         end: int = orm.Filter(query=lambda v: models.Q(time__lte=v))
-        order: str = orm.OrderBy({
-            ServiceLog.time: orm.Order(),
-            ServiceLog.duration: orm.Order(),
-            ServiceLog.in_traffic: orm.Order(),
-            ServiceLog.out_traffic: orm.Order(),
-            ServiceLog.length: orm.Order(),
-        }, default='-time')
+        order: str = orm.OrderBy(
+            {
+                ServiceLog.time: orm.Order(),
+                ServiceLog.duration: orm.Order(),
+                ServiceLog.in_traffic: orm.Order(),
+                ServiceLog.out_traffic: orm.Order(),
+                ServiceLog.length: orm.Order(),
+            },
+            default="-time",
+        )

     @property
     def log_q(self):
         if self.supervisor.node_id:
-            q =
diff --git a/utilmeta/ops/api/log.py b/utilmeta/ops/api/log.py
index 00b62d6..3cada79 100644
--- a/utilmeta/ops/api/log.py
+++ b/utilmeta/ops/api/log.py
@@ -8,20 +8,14 @@
 from django.db import models
 from utype.types import *

-AGGREGATION_FIELDS = [
-    'method',
-    'level',
-    'status',
-    'request_type'
-    'response_type'
-]
+AGGREGATION_FIELDS = ["method", "level", "status", "request_type", "response_type"]


 class LogAPI(api.API):
     supervisor: SupervisorObject = supervisor_var
     response = WrappedResponse

-    @opsRequire('log.view')
+    @opsRequire("log.view")
     def get(self, id: int) -> ServiceLogSchema:
         try:
             return ServiceLogSchema.init(id)
@@ -32,12 +26,12 @@ class LogQuery(orm.Query[ServiceLog]):
         __distinct__ = False
         offset: int = orm.Offset(default=None)
         page: int = orm.Page()
-        rows: int = orm.Limit(default=20, le=100, alias_from=['limit'])
+        rows: int = orm.Limit(default=20, le=100, alias_from=["limit"])
         endpoint_ident: str = orm.Filter()
         endpoint_like: str = orm.Filter(
-            query=lambda v: models.Q(endpoint_ident__icontains=v) |
-                            models.Q(endpoint_ref__icontains=v) |
-                            models.Q(path__icontains=v)
+            query=lambda v: models.Q(endpoint_ident__icontains=v)
+            | models.Q(endpoint_ref__icontains=v)
+            | models.Q(path__icontains=v)
         )
         method: str = orm.Filter(query=lambda v: models.Q(method__iexact=v))
         level: str = orm.Filter(query=lambda v: models.Q(level__iexact=v))
@@ -47,78 +41,85 @@ class LogQuery(orm.Query[ServiceLog]):
         status: int = orm.Filter()
         status_gte: int = orm.Filter(query=lambda v: models.Q(status__gte=v))
         status_lte: int = orm.Filter(query=lambda v: models.Q(status__lte=v))
-        time_gte: datetime = orm.Filter(query=lambda v: models.Q(time__gte=v), alias_from=['time>='])
-        time_lte: datetime = orm.Filter(query=lambda v: models.Q(time__lte=v), alias_from=['time<='])
+        time_gte: datetime = orm.Filter(
+            query=lambda v: models.Q(time__gte=v), alias_from=["time>="]
+        )
+        time_lte: datetime = orm.Filter(
+            query=lambda v: models.Q(time__lte=v), alias_from=["time<="]
+        )

-        admin: bool = orm.Filter(query=lambda v: models.Q(access_token__isnull=not v), default=False)
+        admin: bool = orm.Filter(
+            query=lambda v: models.Q(access_token__isnull=not v), default=False
+        )
         request_type: str = orm.Filter()
         response_type: str = orm.Filter()
         start: int = orm.Filter(query=lambda v: models.Q(time__gte=v))
         end: int = orm.Filter(query=lambda v: models.Q(time__lte=v))

-        order: str = orm.OrderBy({
-            ServiceLog.time: orm.Order(),
-            ServiceLog.duration: orm.Order(),
-            ServiceLog.in_traffic: orm.Order(),
-            ServiceLog.out_traffic: orm.Order(),
-            ServiceLog.length: orm.Order(),
-        }, default='-time')
+        order: str = orm.OrderBy(
+            {
+                ServiceLog.time: orm.Order(),
+                ServiceLog.duration: orm.Order(),
+                ServiceLog.in_traffic: orm.Order(),
+                ServiceLog.out_traffic: orm.Order(),
+                ServiceLog.length: orm.Order(),
+            },
+            default="-time",
+        )

     @property
     def log_q(self):
         if self.supervisor.node_id:
-            q = models.Q(node_id=self.supervisor.node_id) | models.Q(service=self.supervisor.service)
+            q = models.Q(node_id=self.supervisor.node_id) | models.Q(
+                service=self.supervisor.service
+            )
         else:
             q = models.Q(service=self.supervisor.service)
         return q

-    @opsRequire('log.view')
+    @opsRequire("log.view")
     @api.get
     @adapt_async(close_conn=config.db_alias)
     def service(self, query: LogQuery):
         base_qs = ServiceLog.objects.filter(self.log_q)
-        logs = ServiceLogBase.serialize(
-            query.get_queryset(base_qs)
-        )
+        logs = ServiceLogBase.serialize(query.get_queryset(base_qs))
         if config.log.hide_ip_address or config.log.hide_user_id:
             for log in logs:
                 if config.log.hide_ip_address:
-                    log.ip = '*.*.*.*' if log.ip else ''
+                    log.ip = "*.*.*.*" if log.ip else ""
                 if config.log.hide_user_id:
-                    log.user_id = '***' if log.user_id else None
-        return self.response(
-            result=logs,
-            count=query.count(base_qs)
-        )
+                    log.user_id = "***" if log.user_id else None
+        return self.response(result=logs, count=query.count(base_qs))

-    @opsRequire('log.view')
-    @api.get('service/values')
+    @opsRequire("log.view")
+    @api.get("service/values")
     @adapt_async(close_conn=config.db_alias)
     def service_log_values(self, query: LogQuery):
         base_qs = ServiceLog.objects.filter(self.log_q)
-        qs = query.get_queryset(
-            base_qs
-        )
+        qs = query.get_queryset(base_qs)
         result = {}
         for field in AGGREGATION_FIELDS:
             mp = {}
-            for val in qs.exclude(**{field: None}).values(field).annotate(
-                    count=models.Count('id')).order_by('-count'):
-                mp[val[field]] = val['count']
+            for val in (
+                qs.exclude(**{field: None})
+                .values(field)
+                .annotate(count=models.Count("id"))
+                .order_by("-count")
+            ):
+                mp[val[field]] = val["count"]
             result[field] = mp
         return result

-    @opsRequire('log.delete')
+    @opsRequire("log.delete")
     @adapt_async(close_conn=config.db_alias)
     def delete(self, query: LogQuery):
         qs = query.get_queryset(
             ServiceLog.objects.filter(
-                service=self.supervisor.service,
-                node_id=self.supervisor.node_id
+                service=self.supervisor.service, node_id=self.supervisor.node_id
             )
         )
         qs.delete()

-    @opsRequire('log.view')
+    @opsRequire("log.view")
     @api.get
     @adapt_async(close_conn=config.db_alias)
     def realtime(
@@ -130,51 +131,71 @@ def realtime(
         users: int = 0,
         endpoints: int = 0,
     ):
-        logs = ServiceLog.objects.filter(self.log_q, time__gte=self.request.time - timedelta(seconds=within))
+        logs = ServiceLog.objects.filter(
+            self.log_q, time__gte=self.request.time - timedelta(seconds=within)
+        )
         if apis:
             logs = logs.filter(endpoint__ident__in=apis)
         aggregate_info = []
         if ips:
-            aggregate_info.append(('ip', 'ip_dist', ips))
+            aggregate_info.append(("ip", "ip_dist", ips))
         if users:
-            aggregate_info.append(('user_id', 'users', users))
+            aggregate_info.append(("user_id", "users", users))
         if endpoints:
-            aggregate_info.append(('endpoint__ident', 'endpoints', endpoints))
+            aggregate_info.append(("endpoint__ident", "endpoints", endpoints))
         dict_values = {}
         for name, field, max_num in aggregate_info:
-            dict_values[field] = {val[name]: val['count'] for val in
-                                  logs.exclude(**{name: None}).values(name).annotate(
-                                      count=models.Count('id')).order_by('-count')[:max_num]}
+            dict_values[field] = {
+                val[name]: val["count"]
+                for val in logs.exclude(**{name: None})
+                .values(name)
+                .annotate(count=models.Count("id"))
+                .order_by("-count")[:max_num]
+            }
         try:
             from django.db.models.functions.datetime import ExtractMinute
-            time_dist = convert_data_frame(list(logs.values(
-                min=ExtractMinute(self.request.time - models.F('time'))).annotate(
-                uv=models.Count('user_id', distinct=True),
-                ip=models.Count('ip', distinct=True)
-            ).order_by('-min').filter(min__lte=int(within / 60))), keys=['min', 'uv', 'ip'])
+
+            time_dist = convert_data_frame(
+                list(
+                    logs.values(min=ExtractMinute(self.request.time - models.F("time")))
+                    .annotate(
+                        uv=models.Count("user_id", distinct=True),
+                        ip=models.Count("ip", distinct=True),
+                    )
+                    .order_by("-min")
+                    .filter(min__lte=int(within / 60))
+                ),
+                keys=["min", "uv", "ip"],
+            )
         except ValueError:
             # sqlite does not support extract minute
             # ValueError: Extract requires native DurationField database support.
             from django.db.models import CharField
             from django.db.models.functions import Trunc
-            values = list(logs.annotate(min=Trunc('time', 'minute')).values('min').annotate(
-                uv=models.Count('user_id', distinct=True),
-                ip=models.Count('ip', distinct=True)
-            ).order_by('min'))
+
+            values = list(
+                logs.annotate(min=Trunc("time", "minute"))
+                .values("min")
+                .annotate(
+                    uv=models.Count("user_id", distinct=True),
+                    ip=models.Count("ip", distinct=True),
+                )
+                .order_by("min")
+            )
             for val in values:
-                val['min'] = int((self.request.time - val['min']).total_seconds() / 60)
-            time_dist = convert_data_frame(values, keys=['min', 'uv', 'ip'])
+                val["min"] = int((self.request.time - val["min"]).total_seconds() / 60)
+            time_dist = convert_data_frame(values, keys=["min", "uv", "ip"])
         return dict(
             time_dist=time_dist,
             **dict_values,
             **logs.aggregate(
-                avg_time=models.Avg('duration'),
-                requests=models.Count('id'),
-                errors=models.Count('id', filter=models.Q(level='ERROR')),
-                uv=models.Count('user_id', distinct=True),
-                ip=models.Count('ip', distinct=True),
+                avg_time=models.Avg("duration"),
+                requests=models.Count("id"),
+                errors=models.Count("id", filter=models.Q(level="ERROR")),
+                uv=models.Count("user_id", distinct=True),
+                ip=models.Count("ip", distinct=True),
             )
         )
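
Note: the `AGGREGATION_FIELDS` rewrite above is also where a missing comma matters. Both the old multi-line list and a literal one-line reflow rely on Python's implicit string concatenation, which silently folds `'request_type' 'response_type'` into the single key `"request_typeresponse_type"`, so neither field would ever aggregate; the comma restores them as separate entries. A quick demonstration of the pitfall:

    # adjacent string literals concatenate before the list is built
    assert ["status", "request_type" "response_type"] == [
        "status",
        "request_typeresponse_type",
    ]
    assert len(["status", "request_type" "response_type"]) == 2  # not 3
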
"out_traffic", + "requests", + "rps", + "errors", # error requests made from current service to target instance - 'avg_time' + "avg_time", ] -server_metrics_keys = [*system_metrics_keys, 'load_avg_1', 'load_avg_5', 'load_avg_15'] -worker_metrics_keys = [*system_metrics_keys, *service_metrics_keys, 'threads'] -instance_metrics_keys = [*system_metrics_keys, *service_metrics_keys, 'threads', 'avg_workers'] -database_metrics_keys = ['used_space', 'server_used_space', - 'active_connections', 'current_connections', - 'server_connections', 'new_transactions'] -cache_metrics_key = ['cpu_percent', 'memory_percent', 'used_memory', - 'file_descriptors', 'open_files', 'current_connections', 'total_connections', 'qps'] -sum_keys = ['in_traffic', 'out_traffic', 'requests'] +server_metrics_keys = [*system_metrics_keys, "load_avg_1", "load_avg_5", "load_avg_15"] +worker_metrics_keys = [*system_metrics_keys, *service_metrics_keys, "threads"] +instance_metrics_keys = [ + *system_metrics_keys, + *service_metrics_keys, + "threads", + "avg_workers", +] +database_metrics_keys = [ + "used_space", + "server_used_space", + "active_connections", + "current_connections", + "server_connections", + "new_transactions", +] +cache_metrics_key = [ + "cpu_percent", + "memory_percent", + "used_memory", + "file_descriptors", + "open_files", + "current_connections", + "total_connections", + "qps", +] +sum_keys = ["in_traffic", "out_traffic", "requests"] class ResourceData(orm.Schema[Resource]): @@ -66,51 +103,53 @@ class ResourceData(orm.Schema[Resource]): class ServerResource(ResourceData): @property def ip(self) -> Optional[str]: - return self.data.get('ip') + return self.data.get("ip") @property def cpu_num(self) -> Optional[int]: - return self.data.get('cpu_num') + return self.data.get("cpu_num") @property def memory_total(self) -> Optional[int]: - return self.data.get('memory_total') + return self.data.get("memory_total") @property def disk_total(self) -> Optional[int]: - return self.data.get('disk_total') + return self.data.get("disk_total") @property def system(self) -> Optional[str]: - return self.data.get('system') + return self.data.get("system") @property def hostname(self) -> Optional[str]: - return self.data.get('hostname') + return self.data.get("hostname") @property def platform(self) -> Optional[dict]: - return self.data.get('platform') + return self.data.get("platform") class DatabaseResource(ResourceData): - connections: List[DatabaseConnectionSchema] = orm.Field('database_connections', default_factory=list) + connections: List[DatabaseConnectionSchema] = orm.Field( + "database_connections", default_factory=list + ) @property def connected(self) -> bool: - return self.data.get('connected') or False + return self.data.get("connected") or False @property def max_server_connections(self) -> int: - return self.data.get('max_server_connections') or 0 + return self.data.get("max_server_connections") or 0 @property def used_space(self) -> int: - return self.data.get('used_space') or 0 + return self.data.get("used_space") or 0 @property def transactions(self) -> int: - return self.data.get('transactions') or 0 + return self.data.get("transactions") or 0 @cached_property # @utype.Field(no_output=True) @@ -137,11 +176,11 @@ def port(self): class CacheResource(ResourceData): @property def connected(self) -> bool: - return self.data.get('connected') or False + return self.data.get("connected") or False @property def pid(self) -> Optional[int]: - return self.data.get('pid') + return self.data.get("pid") @cached_property 
# @utype.Field(no_output=True) @@ -164,42 +203,42 @@ def port(self): class InstanceResource(ResourceData): @property def backend(self) -> Optional[str]: - return self.data.get('backend') + return self.data.get("backend") @property def backend_version(self) -> Optional[str]: - return self.data.get('backend_version') + return self.data.get("backend_version") @property def version(self) -> Optional[str]: - return self.data.get('version') + return self.data.get("version") @property def spec_version(self) -> Optional[str]: - return self.data.get('spec_version') + return self.data.get("spec_version") @property def asynchronous(self) -> Optional[bool]: - return self.data.get('asynchronous') + return self.data.get("asynchronous") @property def production(self) -> Optional[bool]: - return self.data.get('production') + return self.data.get("production") @property def language(self) -> Optional[str]: - return self.data.get('language') + return self.data.get("language") @property def language_version(self) -> Optional[str]: - return self.data.get('language_version') + return self.data.get("language_version") @property def utilmeta_version(self) -> Optional[str]: - return self.data.get('utilmeta_version') + return self.data.get("utilmeta_version") -@opsRequire('metrics.view') +@opsRequire("metrics.view") class ServersAPI(api.API): supervisor: SupervisorObject = supervisor_var response = WrappedResponse @@ -212,15 +251,17 @@ class BaseQuery(orm.Query): # time > cursor start: datetime = orm.Filter(query=lambda v: models.Q(time__gte=v)) end: datetime = orm.Filter(query=lambda v: models.Q(time__lte=v)) - within_hours: int = orm.Filter(query=lambda v: models.Q( - time__gte=time_now() - timedelta(hours=v))) - within_days: int = orm.Filter(query=lambda v: models.Q( - time__gte=time_now() - timedelta(days=v))) + within_hours: int = orm.Filter( + query=lambda v: models.Q(time__gte=time_now() - timedelta(hours=v)) + ) + within_days: int = orm.Filter( + query=lambda v: models.Q(time__gte=time_now() - timedelta(days=v)) + ) limit: int = utype.Field(default=1000, le=1000, ge=0) class ServerMonitorQuery(BaseQuery[ServerMonitor]): # server_id: str = orm.Filter('server.remote_id') - server_id: str = orm.Filter(required=True, alias_from=['server']) + server_id: str = orm.Filter(required=True, alias_from=["server"]) layer: int = orm.Filter(default=0) def get_resources(self, type: str, id: str = None, **query): @@ -231,8 +272,9 @@ def get_resources(self, type: str, id: str = None, **query): if id: q &= models.Q(remote_id=id) else: - if type != 'server': + if type != "server": from utilmeta import service + q = models.Q(service=service.name) else: q = models.Q() @@ -249,33 +291,35 @@ def get_resources(self, type: str, id: str = None, **query): def get(self) -> List[ServerResource]: return ServerResource.serialize( self.get_resources( - type='server', + type="server", ) ) @api.get @adapt_async(close_conn=config.db_alias) def metrics(self, query: ServerMonitorQuery): - server = self.get_resources( - type='server', - id=query.server_id - ).first() + server = self.get_resources(type="server", id=query.server_id).first() if not server: - raise exceptions.NotFound('server not found') + raise exceptions.NotFound("server not found") query.server_id = server.pk return self.get_metrics_result( qs=query.get_queryset(), metrics_cls=ServerMonitorSchema, limit=query.limit, sample_interval=query.sample_interval, - metrics_keys=server_metrics_keys + metrics_keys=server_metrics_keys, ) @classmethod - def get_metrics_result(cls, qs, 
metrics_cls, limit: int, - metrics_keys: List[str], - sample_interval: int = None): - order = '-time' + def get_metrics_result( + cls, + qs, + metrics_cls, + limit: int, + metrics_keys: List[str], + sample_interval: int = None, + ): + order = "-time" trunc_func = None result = None @@ -283,29 +327,44 @@ def process_value(number): return round(number, 2) if isinstance(number, (float, Decimal)) else number if sample_interval: - trunc_func = { - 60: TruncMinute, - 3600: TruncHour, - 3600 * 24: TruncDate - }.get(sample_interval) + trunc_func = {60: TruncMinute, 3600: TruncHour, 3600 * 24: TruncDate}.get( + sample_interval + ) if trunc_func: - order = '-t' - qs = qs.annotate(t=trunc_func('time')).values('t').annotate( - **{'__' + key: models.Sum(key) if key in sum_keys else models.Avg(key) for key in metrics_keys} + order = "-t" + qs = ( + qs.annotate(t=trunc_func("time")) + .values("t") + .annotate( + **{ + "__" + key: models.Sum(key) + if key in sum_keys + else models.Avg(key) + for key in metrics_keys + } + ) ) else: # qs = qs.annotate(t=models) cursor = None val = {} result = [] - for value in list(qs.order_by(order).values('time', *metrics_keys)): - ts = int(value['time'].timestamp()) + for value in list(qs.order_by(order).values("time", *metrics_keys)): + ts = int(value["time"].timestamp()) _ts = ts - ts % sample_interval if _ts != cursor: if val and ts: - result.append({'time': cursor, **{k: process_value( - sum(v) if k in sum_keys else sum(v) / len(v) - ) for k, v in val.items()}}) + result.append( + { + "time": cursor, + **{ + k: process_value( + sum(v) if k in sum_keys else sum(v) / len(v) + ) + for k, v in val.items() + }, + } + ) val = {} if len(result) >= limit: break @@ -319,9 +378,15 @@ def process_value(number): if trunc_func: result = [] for value in list(qs): - result.append({'time': value['t'], - **{key.lstrip('__'): process_value(val) - for key, val in value.items()}}) + result.append( + { + "time": value["t"], + **{ + key.lstrip("__"): process_value(val) + for key, val in value.items() + }, + } + ) else: result = [] else: @@ -342,19 +407,16 @@ class WorkerQuery(orm.Query[Worker]): @api.get @adapt_async(close_conn=config.db_alias) def workers(self, query: WorkerQuery): - instance = self.get_resources( - type='instance', - id=query.instance_id - ).first() + instance = self.get_resources(type="instance", id=query.instance_id).first() if not instance: - raise exceptions.NotFound('instance not found') + raise exceptions.NotFound("instance not found") query.instance_id = instance.pk return WorkerSchema.serialize(query) class WorkerMonitorQuery(BaseQuery[WorkerMonitor]): worker_id: int = orm.Filter(required=True) - @api.get('worker/metrics') + @api.get("worker/metrics") @adapt_async(close_conn=config.db_alias) def worker_metrics(self, query: WorkerMonitorQuery): return self.get_metrics_result( @@ -362,7 +424,7 @@ def worker_metrics(self, query: WorkerMonitorQuery): metrics_cls=WorkerMonitorSchema, limit=query.limit, sample_interval=query.sample_interval, - metrics_keys=worker_metrics_keys + metrics_keys=worker_metrics_keys, ) @api.get @@ -370,22 +432,19 @@ def worker_metrics(self, query: WorkerMonitorQuery): def instances(self) -> List[InstanceResource]: return InstanceResource.serialize( self.get_resources( - type='instance', + type="instance", ) ) class InstanceMonitorQuery(BaseQuery[InstanceMonitor]): instance_id: str = orm.Filter(required=True) - @api.get('instance/metrics') + @api.get("instance/metrics") @adapt_async(close_conn=config.db_alias) def instance_metrics(self, 
query: InstanceMonitorQuery) -> dict: - instance = self.get_resources( - type='instance', - id=query.instance_id - ).first() + instance = self.get_resources(type="instance", id=query.instance_id).first() if not instance: - raise exceptions.NotFound('instance not found') + raise exceptions.NotFound("instance not found") query.instance_id = instance.pk # return convert_data_frame(InstanceMonitorSchema.serialize( # query @@ -395,7 +454,7 @@ def instance_metrics(self, query: InstanceMonitorQuery) -> dict: metrics_cls=InstanceMonitorSchema, limit=query.limit, sample_interval=query.sample_interval, - metrics_keys=instance_metrics_keys + metrics_keys=instance_metrics_keys, ) @api.get @@ -405,10 +464,7 @@ def databases(self) -> List[DatabaseResource]: if not db_config: return [] return DatabaseResource.serialize( - self.get_resources( - type='database', - ident__in=list(db_config.databases) - ) + self.get_resources(type="database", ident__in=list(db_config.databases)) ) @api.get @@ -418,10 +474,7 @@ def caches(self) -> List[CacheResource]: if not cache_config: return [] return CacheResource.serialize( - self.get_resources( - type='cache', - ident__in=list(cache_config.caches) - ) + self.get_resources(type="cache", ident__in=list(cache_config.caches)) ) class DatabaseMonitorQuery(BaseQuery[DatabaseMonitor]): @@ -430,15 +483,12 @@ class DatabaseMonitorQuery(BaseQuery[DatabaseMonitor]): class CacheMonitorQuery(BaseQuery[CacheMonitor]): cache_id: str = orm.Filter(required=True) - @api.get('database/metrics') + @api.get("database/metrics") @adapt_async(close_conn=config.db_alias) def database_metrics(self, query: DatabaseMonitorQuery) -> dict: - db = self.get_resources( - type='database', - id=query.database_id - ).first() + db = self.get_resources(type="database", id=query.database_id).first() if not db: - raise exceptions.NotFound('database not found') + raise exceptions.NotFound("database not found") query.database_id = db.pk # return convert_data_frame(DatabaseMonitorSchema.serialize( # query @@ -451,15 +501,12 @@ def database_metrics(self, query: DatabaseMonitorQuery) -> dict: metrics_keys=database_metrics_keys, ) - @api.get('cache/metrics') + @api.get("cache/metrics") @adapt_async(close_conn=config.db_alias) def cache_metrics(self, query: CacheMonitorQuery) -> dict: - cache = self.get_resources( - type='cache', - id=query.cache_id - ).first() + cache = self.get_resources(type="cache", id=query.cache_id).first() if not cache: - raise exceptions.NotFound('cache not found') + raise exceptions.NotFound("cache not found") query.cache_id = cache.pk # return convert_data_frame(CacheMonitorSchema.serialize( # query diff --git a/utilmeta/ops/api/token.py b/utilmeta/ops/api/token.py index 827b7da..54df516 100644 --- a/utilmeta/ops/api/token.py +++ b/utilmeta/ops/api/token.py @@ -45,40 +45,42 @@ def scope(self) -> List[str]: return request.var.scopes.getter(self.request) or [] @api.get - @opsRequire('token.view') + @opsRequire("token.view") @adapt_async(close_conn=config.db_alias) def get(self, query: AccessTokenQuery) -> List[AccessTokenSchema]: if not self.supervisor.id or not self.supervisor.node_id: - raise exceptions.NotFound('Supervisor not found', state='supervisor_not_found') + raise exceptions.NotFound( + "Supervisor not found", state="supervisor_not_found" + ) query.issuer_id = self.supervisor.id - return AccessTokenSchema.serialize( - query - ) + return AccessTokenSchema.serialize(query) @api.post - @opsRequire('token.revoke') + @opsRequire("token.revoke") @adapt_async(close_conn=config.db_alias) 
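
Note: when `sample_interval` does not map onto a `Trunc*` function, `get_metrics_result` above falls back to flooring each row's timestamp onto a fixed bucket with `ts - ts % sample_interval`, summing the keys in `sum_keys` and averaging everything else. A minimal standalone sketch of that bucketing (`bucket` is a hypothetical helper; epoch-second timestamps assumed):

    from collections import defaultdict

    def bucket(points, interval, sum_keys=("requests",)):
        # points: iterable of (ts, {metric: value}); floor each ts onto its bucket start
        grouped = defaultdict(lambda: defaultdict(list))
        for ts, metrics in points:
            for key, value in metrics.items():
                grouped[ts - ts % interval][key].append(value)
        return {
            t: {k: sum(v) if k in sum_keys else sum(v) / len(v) for k, v in vals.items()}
            for t, vals in grouped.items()
        }

    # 61s and 119s both floor to the 60s bucket; "requests" is in sum_keys, so it is summed
    assert bucket([(61, {"requests": 2}), (119, {"requests": 3})], 60) == {
        60: {"requests": 5}
    }
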
diff --git a/utilmeta/ops/api/token.py b/utilmeta/ops/api/token.py
index 827b7da..54df516 100644
--- a/utilmeta/ops/api/token.py
+++ b/utilmeta/ops/api/token.py
@@ -45,40 +45,42 @@ def scope(self) -> List[str]:
         return request.var.scopes.getter(self.request) or []

     @api.get
-    @opsRequire('token.view')
+    @opsRequire("token.view")
     @adapt_async(close_conn=config.db_alias)
     def get(self, query: AccessTokenQuery) -> List[AccessTokenSchema]:
         if not self.supervisor.id or not self.supervisor.node_id:
-            raise exceptions.NotFound('Supervisor not found', state='supervisor_not_found')
+            raise exceptions.NotFound(
+                "Supervisor not found", state="supervisor_not_found"
+            )
         query.issuer_id = self.supervisor.id
-        return AccessTokenSchema.serialize(
-            query
-        )
+        return AccessTokenSchema.serialize(query)

     @api.post
-    @opsRequire('token.revoke')
+    @opsRequire("token.revoke")
     @adapt_async(close_conn=config.db_alias)
     # this token will be generated and send directly from supervisor
     def revoke(self, id_list: List[str] = request.Body) -> int:
         if not self.supervisor.id or not self.supervisor.node_id:
-            raise exceptions.NotFound('Supervisor not found', state='supervisor_not_found')
-        exists = list(AccessToken.objects.filter(
-            token_id__in=id_list,
-            issuer_id=self.supervisor.id
-        ).values_list('token_id', flat=True))
+            raise exceptions.NotFound(
+                "Supervisor not found", state="supervisor_not_found"
+            )
+        exists = list(
+            AccessToken.objects.filter(
+                token_id__in=id_list, issuer_id=self.supervisor.id
+            ).values_list("token_id", flat=True)
+        )
         for token_id in set(id_list).difference(exists):
             AccessToken.objects.create(
                 token_id=token_id,
                 issuer_id=self.supervisor.id,
                 expiry_time=self.request.time + timedelta(days=1),
-                revoked=True
+                revoked=True,
             )
         if exists:
             AccessToken.objects.filter(
-                token_id__in=id_list,
-                issuer_id=self.supervisor.id
+                token_id__in=id_list, issuer_id=self.supervisor.id
             ).update(revoked=True)
         return len(exists)
diff --git a/utilmeta/ops/api/utils.py b/utilmeta/ops/api/utils.py
index a7cd6de..cfd5c79 100644
--- a/utilmeta/ops/api/utils.py
+++ b/utilmeta/ops/api/utils.py
@@ -21,22 +21,24 @@ class SupervisorObject(orm.Schema[Supervisor]):
 #     excludes = var.RequestContextVar('_excludes', cached=True)
 #     params = var.RequestContextVar('_params', cached=True)

-supervisor_var = var.RequestContextVar('_ops.supervisor', cached=True)
-access_token_var = var.RequestContextVar('_ops.access_token', cached=True)
-resources_var = var.RequestContextVar('_scopes.resource', cached=True, default=list)
+supervisor_var = var.RequestContextVar("_ops.supervisor", cached=True)
+access_token_var = var.RequestContextVar("_ops.access_token", cached=True)
+resources_var = var.RequestContextVar("_scopes.resource", cached=True, default=list)

 config = Operations.config()


 class WrappedResponse(response.Response):
-    result_key = 'result'
-    message_key = 'msg'
-    state_key = 'state'
-    count_key = 'count'
+    result_key = "result"
+    message_key = "msg"
+    state_key = "state"
+    count_key = "count"


 class opsRequire(auth.Require):
     def validate_scopes(self, req: request.Request):
         if config.disabled_scope and config.disabled_scope.intersection(self.scopes):
-            raise exceptions.PermissionDenied(f'Operation: {self.scopes} denied by config')
+            raise exceptions.PermissionDenied(
+                f"Operation: {self.scopes} denied by config"
+            )
         return super().validate_scopes(req)
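
Note: the `revoke` endpoint above is idempotent across unknown ids: token ids never seen for this issuer are pre-created as revoked placeholder rows, existing ones are revoked with a single bulk update, and only the pre-existing ids are counted in the return value. The set arithmetic it relies on is just:

    id_list = ["t1", "t2", "t3"]
    exists = ["t2"]  # ids already stored for this issuer
    # unseen ids get inserted as revoked; existing ones get .update(revoked=True)
    assert set(id_list).difference(exists) == {"t1", "t3"}
    assert len(exists) == 1  # the endpoint returns the count of pre-existing tokens
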
diff --git a/utilmeta/ops/client.py b/utilmeta/ops/client.py
index 5413bb0..22b8924 100644
--- a/utilmeta/ops/client.py
+++ b/utilmeta/ops/client.py
@@ -6,20 +6,31 @@
 from utype.types import *
 from .key import encrypt_data
-from .schema import (NodeMetadata, SupervisorBasic, ServiceInfoSchema, SupervisorInfoSchema, \
-    SupervisorData, ResourcesSchema, ResourcesData, NodeInfoSchema, InstanceResourceSchema,
-    SupervisorPatchSchema, OpenAPISchema, TableSchema)
+from .schema import (
+    NodeMetadata,
+    SupervisorBasic,
+    ServiceInfoSchema,
+    SupervisorInfoSchema,
+    SupervisorData,
+    ResourcesSchema,
+    ResourcesData,
+    NodeInfoSchema,
+    InstanceResourceSchema,
+    SupervisorPatchSchema,
+    OpenAPISchema,
+    TableSchema,
+)


 class SupervisorResponse(response.Response):
-    result_key = 'result'
-    message_key = 'msg'
-    state_key = 'state'
-    count_key = 'count'
+    result_key = "result"
+    message_key = "msg"
+    state_key = "state"
+    count_key = "count"


 class SupervisorListResponse(SupervisorResponse):
-    name = 'list'
+    name = "list"
     result: List[SupervisorBasic]

@@ -28,17 +39,17 @@ class OpenAPIResponse(response.Response):

 class InstanceResponse(SupervisorResponse):
-    name = 'instance'
+    name = "instance"
     result: List[InstanceResourceSchema]


 class TableResponse(SupervisorResponse):
-    name = 'table'
+    name = "table"
     result: List[TableSchema]


 class SupervisorInfoResponse(SupervisorResponse):
-    name = 'info'
+    name = "info"
     result: SupervisorInfoSchema

     def validate(self):
@@ -52,7 +63,7 @@ def validate(self):


 class NodeInfoResponse(SupervisorResponse):
-    name = 'add_node'
+    name = "add_node"
     result: NodeInfoSchema

     def validate(self):
@@ -63,7 +74,7 @@ def validate(self):


 class ServiceInfoResponse(SupervisorResponse):
-    name = 'info'
+    name = "info"
     result: ServiceInfoSchema

     def validate(self):
@@ -74,7 +85,7 @@ def validate(self):


 class SupervisorResourcesResponse(SupervisorResponse):
-    name = 'resources'
+    name = "resources"
     result: ResourcesData

@@ -84,79 +95,101 @@ class ReportResult(utype.Schema):


 class SupervisorNodeResponse(SupervisorResponse):
-    name = 'add_node'
+    name = "add_node"
     result: Optional[SupervisorData] = None


 class SupervisorReportResponse(SupervisorResponse):
-    name = 'report'
+    name = "report"
     result: ReportResult


 class SupervisorBatchReportResponse(SupervisorResponse):
-    name = 'batch_report'
+    name = "batch_report"
     result: List[dict]

+
 # class AddNodeResponse(SupervisorResponse):
 #     name = 'info'
 #     result: InfoSchema


 class SupervisorClient(Client):
-    @api.post('/')
-    def add_node(self, data: NodeMetadata = request.Body) -> Union[SupervisorNodeResponse, SupervisorResponse]: pass
+    @api.post("/")
+    def add_node(
+        self, data: NodeMetadata = request.Body
+    ) -> Union[SupervisorNodeResponse, SupervisorResponse]:
+        pass

-    @api.post('/')
-    async def async_add_node(self, data: NodeMetadata = request.Body) \
-            -> Union[SupervisorNodeResponse, SupervisorResponse]: pass
+    @api.post("/")
+    async def async_add_node(
+        self, data: NodeMetadata = request.Body
+    ) -> Union[SupervisorNodeResponse, SupervisorResponse]:
+        pass

-    @api.delete('/')
-    def delete_node(self) -> SupervisorResponse: pass
+    @api.delete("/")
+    def delete_node(self) -> SupervisorResponse:
+        pass

-    @api.post('/resources')
-    def upload_resources(self, data: ResourcesSchema = request.Body) \
-            -> Union[SupervisorResourcesResponse, SupervisorResponse]: pass
+    @api.post("/resources")
+    def upload_resources(
+        self, data: ResourcesSchema = request.Body
+    ) -> Union[SupervisorResourcesResponse, SupervisorResponse]:
+        pass

-    @api.post('/resources')
-    async def async_upload_resources(self, data: ResourcesSchema = request.Body) \
-            -> Union[SupervisorResourcesResponse, SupervisorResponse]: pass
+    @api.post("/resources")
+    async def async_upload_resources(
+        self, data: ResourcesSchema = request.Body
+    ) -> Union[SupervisorResourcesResponse, SupervisorResponse]:
+        pass

-    @api.get('/list')
-    def get_supervisors(self) -> Union[SupervisorListResponse, SupervisorResponse]: pass
+    @api.get("/list")
+    def get_supervisors(self) -> Union[SupervisorListResponse, SupervisorResponse]:
+        pass

-    @api.get('/list')
-    async def async_get_supervisors(self) -> Union[SupervisorListResponse, SupervisorResponse]: pass
+    @api.get("/list")
+    async def async_get_supervisors(
+        self,
+    ) -> Union[SupervisorListResponse, SupervisorResponse]:
+        pass

-    @api.get('/')
-    def get_info(self) -> Union[SupervisorInfoResponse, SupervisorResponse]: pass
+    @api.get("/")
+    def get_info(self) -> Union[SupervisorInfoResponse, SupervisorResponse]:
+        pass

-    @api.get('/')
-    async def async_get_info(self) -> Union[SupervisorInfoResponse, SupervisorResponse]: pass
+    @api.get("/")
+    async def async_get_info(self) -> Union[SupervisorInfoResponse, SupervisorResponse]:
+        pass

-    @api.post('/report')
-    def report_analytics(self, data: dict = request.Body) -> Union[SupervisorReportResponse, SupervisorResponse]:
+    @api.post("/report")
+    def report_analytics(
+        self, data: dict = request.Body
+    ) -> Union[SupervisorReportResponse, SupervisorResponse]:
         pass

-    @api.post('/report')
-    async def async_report_analytics(self, data: dict = request.Body)\
-            -> Union[SupervisorReportResponse, SupervisorResponse]:
+    @api.post("/report")
+    async def async_report_analytics(
+        self, data: dict = request.Body
+    ) -> Union[SupervisorReportResponse, SupervisorResponse]:
         pass

-    @api.post('/report/batch')
-    def batch_report_analytics(self, data: list = request.Body) -> (
-            Union)[SupervisorBatchReportResponse, SupervisorResponse]:
+    @api.post("/report/batch")
+    def batch_report_analytics(
+        self, data: list = request.Body
+    ) -> Union[SupervisorBatchReportResponse, SupervisorResponse]:
         pass

-    @api.post('/report/batch')
-    async def async_batch_report_analytics(self, data: list = request.Body) -> (
-            Union)[SupervisorBatchReportResponse, SupervisorResponse]:
+    @api.post("/report/batch")
+    async def async_batch_report_analytics(
+        self, data: list = request.Body
+    ) -> Union[SupervisorBatchReportResponse, SupervisorResponse]:
         pass

-    @api.post('/alert')
+    @api.post("/alert")
     def alert_incident(self):
         pass

-    @api.post('/alert')
+    @api.post("/alert")
     async def async_alert_incident(self):
         pass
@@ -167,46 +200,45 @@ async def async_alert_incident(self):
     #         return r.data
     #     return []

-    def __init__(self,
-                 access_key: str = None,
-                 cluster_key: str = None,
-                 cluster_id: str = None,
-                 node_id: str = None,
-                 service_id: str = None,
-                 node_key: str = None,
-                 **kwargs):
+    def __init__(
+        self,
+        access_key: str = None,
+        cluster_key: str = None,
+        cluster_id: str = None,
+        node_id: str = None,
+        service_id: str = None,
+        node_key: str = None,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         headers = {}
         if access_key:
             # only required in ADD_NODE operation
-            headers.update({
-                'X-Access-Key': access_key,
-            })
+            headers.update(
+                {
+                    "X-Access-Key": access_key,
+                }
+            )
         if cluster_key:
-            headers.update({
-                'X-Cluster-Key': cluster_key
-            })
+            headers.update({"X-Cluster-Key": cluster_key})
         if cluster_id:
-            headers.update({
-                'X-Cluster-Id': cluster_id
-            })
+            headers.update({"X-Cluster-Id": cluster_id})
         if node_id:
-            headers.update({
-                'X-Node-ID': node_id
-            })
+            headers.update({"X-Node-ID": node_id})
             from .models import Supervisor
+
             supervisor: Supervisor = Supervisor.objects.filter(
                 node_id=node_id,
             ).first()
             if not supervisor:
-                raise ValueError(f'Supervisor for node ID [{node_id}] not exists')
+                raise ValueError(f"Supervisor for node ID [{node_id}] not exists")
             if not node_key:
                 if supervisor.disabled:
-                    raise ValueError('supervisor is disabled')
+                    raise ValueError("supervisor is disabled")
                 if supervisor.public_key:
                     node_key = supervisor.public_key
@@ -214,21 +246,16 @@ def __init__(self,
                 self._base_url = supervisor.base_url

         if node_key:
-            headers.update({
-                'X-Node-Key': node_key
-            })
+            headers.update({"X-Node-Key": node_key})
         if service_id:
-            headers.update({
-                'X-Service-ID': service_id
-            })
+            headers.update({"X-Service-ID": service_id})
         from .config import Operations
+
         config = Operations.config()
         if config:
             if config.proxy and config.proxy.forward:
-                self.update_base_headers({
-                    'x-utilmeta-proxy-type': 'forward'
-                })
+                self.update_base_headers({"x-utilmeta-proxy-type": "forward"})
                 self._base_url = config.proxy.proxy_url
             else:
                 config.check_supervisor(self._base_url)
@@ -253,75 +280,94 @@ def process_request(self, req: request.Request):
         try:
             encrypted = encrypt_data(req.body, public_key=pub_key)
         except Exception as e:
-            raise ValueError(f'Invalid Operations access key, encode body failed with error: {e}')
+            raise ValueError(
+                f"Invalid Operations access key, encode body failed with error: {e}"
+            )
         req.body = encrypted  # set request body
         return req


 class OperationsClient(Client):
-    def __init__(
-            self,
-            token: str = None,
-            node_id: str = None,
-            **kwargs
-    ):
+    def __init__(self, token: str = None, node_id: str = None, **kwargs):
         super().__init__(**kwargs)
         self.token = token
         self.node_id = node_id
         if self.token:
-            self.update_base_headers({
-                'authorization': f'Bearer {self.token}'
-            })
+            self.update_base_headers({"authorization": f"Bearer {self.token}"})
         if self.node_id:
-            self.update_base_headers({
-                'x-node-id': self.node_id
-            })
+            self.update_base_headers({"x-node-id": self.node_id})

-    @api.post('/')
-    def add_supervisor(self, data: SupervisorData = request.Body) -> NodeInfoResponse: pass
+    @api.post("/")
+    def add_supervisor(self, data: SupervisorData = request.Body) -> NodeInfoResponse:
+        pass

-    @api.post('/')
-    async def async_add_supervisor(self, data: SupervisorData = request.Body) -> NodeInfoResponse: pass
+    @api.post("/")
+    async def async_add_supervisor(
+        self, data: SupervisorData = request.Body
+    ) -> NodeInfoResponse:
+        pass

-    @api.patch('/')
-    def update_supervisor(self, data: SupervisorPatchSchema = request.Body) -> NodeInfoResponse: pass
+    @api.patch("/")
+    def update_supervisor(
+        self, data: SupervisorPatchSchema = request.Body
+    ) -> NodeInfoResponse:
+        pass

-    @api.patch('/')
-    async def async_update_supervisor(self, data: SupervisorPatchSchema = request.Body) -> NodeInfoResponse: pass
+    @api.patch("/")
+    async def async_update_supervisor(
+        self, data: SupervisorPatchSchema = request.Body
+    ) -> NodeInfoResponse:
+        pass

-    @api.post('/token/revoke')
-    def revoke_token(self, id_list: List[str] = request.Body) -> SupervisorResponse[int]: pass
+    @api.post("/token/revoke")
+    def revoke_token(
+        self, id_list: List[str] = request.Body
+    ) -> SupervisorResponse[int]:
+        pass

-    @api.post('/token/revoke')
-    async def async_revoke_token(self, id_list: List[str] = request.Body) -> SupervisorResponse[int]: pass
+    @api.post("/token/revoke")
+    async def async_revoke_token(
+        self, id_list: List[str] = request.Body
+    ) -> SupervisorResponse[int]:
+        pass

-    @api.delete('/')
-    def delete_supervisor(self) -> SupervisorResponse: pass
+    @api.delete("/")
+    def delete_supervisor(self) -> SupervisorResponse:
+        pass

-    @api.delete('/')
-    async def async_delete_supervisor(self) -> SupervisorResponse: pass
+    @api.delete("/")
+    async def async_delete_supervisor(self) -> SupervisorResponse:
+        pass

-    @api.get('/openapi')
-    def get_openapi(self) -> OpenAPIResponse: pass
+    @api.get("/openapi")
+    def get_openapi(self) -> OpenAPIResponse:
+        pass

-    @api.get('/openapi')
-    async def async_get_openapi(self) -> OpenAPIResponse: pass
+    @api.get("/openapi")
+    async def async_get_openapi(self) -> OpenAPIResponse:
+        pass

-    @api.get('/data/tables')
-    def get_tables(self) -> TableResponse: pass
+    @api.get("/data/tables")
+    def get_tables(self) -> TableResponse:
+        pass

-    @api.get('/data/tables')
-    async def async_get_tables(self) -> TableResponse: pass
+    @api.get("/data/tables")
+    async def async_get_tables(self) -> TableResponse:
+        pass

-    @api.get('/servers/instances')
-    def get_instances(self) -> InstanceResponse: pass
+    @api.get("/servers/instances")
+    def get_instances(self) -> InstanceResponse:
+        pass

-    @api.get('/servers/instances')
-    async def async_get_instances(self) -> InstanceResponse: pass
+    @api.get("/servers/instances")
+    async def async_get_instances(self) -> InstanceResponse:
+        pass

-    @api.get('/')
-    def get_info(self) -> Union[ServiceInfoResponse, SupervisorResponse]: pass
+    @api.get("/")
+    def get_info(self) -> Union[ServiceInfoResponse, SupervisorResponse]:
+        pass

-    @api.get('/')
-    async def async_get_info(self) -> Union[ServiceInfoResponse, SupervisorResponse]: pass
+    @api.get("/")
+    async def async_get_info(self) -> Union[ServiceInfoResponse, SupervisorResponse]:
+        pass
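
Note: every route on `SupervisorClient` and `OperationsClient` above is declared twice, as a blocking method and an `async_`-prefixed coroutine bound to the same path, so callers pick the variant matching their execution context. A usage sketch (the local URL is illustrative; `base_url` and `fail_silently` are `Client` options used the same way in cmd.py below):

    client = OperationsClient(base_url="http://127.0.0.1:8000/ops", fail_silently=True)
    info = client.get_info()  # sync variant
    # info = await client.async_get_info()  # async variant of the same route
    if isinstance(info, ServiceInfoResponse) and info.validate():
        ...  # service is live; a plain SupervisorResponse means the call fell through
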
diff --git a/utilmeta/ops/cmd.py b/utilmeta/ops/cmd.py
index 4c104f5..410aea1 100644
--- a/utilmeta/ops/cmd.py
+++ b/utilmeta/ops/cmd.py
@@ -17,10 +17,14 @@ def try_to_connect(timeout: int = 5):
             return
         if not config.is_local:
             webbrowser.open_new_tab(__website__)
-            print(RED % f'connection key required to connect non-local service, please login to '
-                        f'{__website__} and generate one')
+            print(
+                RED
+                % f"connection key required to connect non-local service, please login to "
+                f"{__website__} and generate one"
+            )
             return
     from utilmeta.ops.client import OperationsClient, ServiceInfoResponse
+
     t = time.time()
     live = False
     while True:
@@ -33,26 +37,31 @@ def try_to_connect(timeout: int = 5):
         else:
             break
     if not live:
-        print(RED % 'meta connect: service not live or OperationsAPI not mounted, '
-                    f'please check your OperationsAPI: {config.ops_api} is accessible before connect')
+        print(
+            RED % "meta connect: service not live or OperationsAPI not mounted, "
+            f"please check your OperationsAPI: {config.ops_api} is accessible before connect"
+        )
         return
-    local_manage_url = f'{__website__}/localhost?local_node={config.ops_api}'
-    print(f'OperationsAPI connected at {local_manage_url}')
+    local_manage_url = f"{__website__}/localhost?local_node={config.ops_api}"
+    print(f"OperationsAPI connected at {local_manage_url}")
     webbrowser.open_new_tab(local_manage_url)


 class OperationsCommand(BaseServiceCommand):
-    name = 'ops'
+    name = "ops"

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.config = self.service.get_config(Operations)
         if not self.config:
-            print(RED % f'meta {self.arg_name}: Operations config not integrated to application, '
-                        'please follow the document at https://docs.utilmeta.com/py/en/guide/ops/')
+            print(
+                RED
+                % f"meta {self.arg_name}: Operations config not integrated to application, "
+                "please follow the document at https://docs.utilmeta.com/py/en/guide/ops/"
+            )
             exit(1)
         # self.settings.setup(self.service)
-        self.service.setup()    # setup here
+        self.service.setup()  # setup here

     @command
     def migrate_ops(self):
@@ -60,9 +69,12 @@ def migrate_ops(self):
         Migrate all required tables for UtilMeta Operations to the database
         """
         from django.core.management import execute_from_command_line
-        execute_from_command_line(['manage.py', 'migrate', 'ops', f'--database={self.config.db_alias}'])
+
+        execute_from_command_line(
+            ["manage.py", "migrate", "ops", f"--database={self.config.db_alias}"]
+        )
         # 2. migrate for main database
-        execute_from_command_line(['manage.py', 'migrate', 'ops'])
+        execute_from_command_line(["manage.py", "migrate", "ops"])

     # @command
     # def check_connect(self):
@@ -81,12 +93,13 @@ def migrate_ops(self):
     #           f'is available at {BLUE % self.config.ops_api}')

     @command
-    def connect(self,
-                to: str = None,
-                key: str = Arg('--key', default=None),
-                force: bool = Arg('-f', default=None),
-                service: str = Arg('--service', default=None)
-                ):
+    def connect(
+        self,
+        to: str = None,
+        key: str = Arg("--key", default=None),
+        force: bool = Arg("-f", default=None),
+        service: str = Arg("--service", default=None),
+    ):
         """
         Connect your API service to UtilMeta platform to manage
         """
@@ -95,71 +108,91 @@ def connect(self,
         # check if service is live
         from .client import OperationsClient, ServiceInfoResponse
-        info = OperationsClient(base_url=self.config.ops_api, fail_silently=True).get_info()
+
+        info = OperationsClient(
+            base_url=self.config.ops_api, fail_silently=True
+        ).get_info()
         live = isinstance(info, ServiceInfoResponse) and info.validate()
         failed = info.is_aborted or info.status >= 500
         if not live:
             if failed:
-                print(RED % 'meta connect: service is down, '
-                            f'please check your OperationsAPI: {self.config.ops_api} is accessible before connect')
+                print(
+                    RED % "meta connect: service is down, "
+                    f"please check your OperationsAPI: {self.config.ops_api} is accessible before connect"
+                )
             else:
-                print(YELLOW % 'meta connect: OperationsAPI not mounted (or service not restarted), '
-                               f'please check your OperationsAPI: {self.config.ops_api} is accessible before connect')
-                print('If you have integrated Operations config, please restart your service and retry, '
-                      f'or add {BLUE % "-f"} to force this connect')
+                print(
+                    YELLOW
+                    % "meta connect: OperationsAPI not mounted (or service not restarted), "
+                    f"please check your OperationsAPI: {self.config.ops_api} is accessible before connect"
+                )
+                print(
+                    "If you have integrated Operations config, please restart your service and retry, "
+                    f'or add {BLUE % "-f"} to force this connect'
+                )
             if not force:
                 exit(1)
         if not key:
             # check if it is localhost
             if self.config.is_local:
-                local_manage_url = f'{__website__}/localhost?local_node={self.config.ops_api}'
-                print(f'OperationsAPI connected at {local_manage_url}')
+                local_manage_url = (
+                    f"{__website__}/localhost?local_node={self.config.ops_api}"
+                )
+                print(f"OperationsAPI connected at {local_manage_url}")
                 webbrowser.open_new_tab(local_manage_url)
                 exit(0)
         if not self.config.is_local:
             if not self.config.proxy and self.config.proxy_required:
-                print(YELLOW % f'meta connect: it seems that you are using a private base_url: {self.config.base_url} '
-                               f'without setting '
-                               'a proxy in Operations, this service will be unable to access in the platform')
+                print(
+                    YELLOW
+                    % f"meta connect: it seems that you are using a private base_url: {self.config.base_url} "
+                    f"without setting "
+                    "a proxy in Operations, this service will be unable to access in the platform"
+                )
             if not self.config.is_secure:
-                print(YELLOW % f'meta connect: you are trying to connect an insecure node:'
-                               f' {self.config.ops_api} (with HTTP protocol), '
-                               'we strongly recommend using HTTPS protocol instead')
+                print(
+                    YELLOW
+                    % f"meta connect: you are trying to connect an insecure node:"
+                    f" {self.config.ops_api} (with HTTP protocol), "
+                    "we strongly recommend using HTTPS protocol instead"
+                )
         if not key:
             webbrowser.open_new_tab(__website__)
             if live:
-                print(f'meta connect: OperationsAPI of [{self.service.name}] '
-                      f'is available at {BLUE % self.config.ops_api}')
-            print(RED % f'meta connect: --key is required to connect non-local service, please login to '
-                        f'{__website__} and generate one')
+                print(
+                    f"meta connect: OperationsAPI of [{self.service.name}] "
+                    f"is available at {BLUE % self.config.ops_api}"
+                )
+            print(
+                RED
+                % f"meta connect: --key is required to connect non-local service, please login to "
+                f"{__website__} and generate one"
+            )
             exit(1)
         from .connect import connect_supervisor
-        connect_supervisor(
-            key=key,
-            base_url=to,
-            service_id=service
-        )
+
+        connect_supervisor(key=key, base_url=to, service_id=service)

     @command
-    def delete_supervisor(self, node: str = Arg(required=True), key: str = Arg('--key', required=True)):
+    def delete_supervisor(
+        self, node: str = Arg(required=True), key: str = Arg("--key", required=True)
+    ):
         """
         Connect your API service to UtilMeta platform to manage
         """
         # self.migrate_ops()
         # before connect
         from .connect import delete_supervisor
-        delete_supervisor(
-            key=key,
-            node_id=node
-        )
+
+        delete_supervisor(key=key, node_id=node)

     @command
-    def sync(self, force: bool = Arg('-f', default=False)):
+    def sync(self, force: bool = Arg("-f", default=False)):
         """
         Sync APIs and resources to supervisor
         """
@@ -173,147 +206,222 @@ def stats(self):
         """
         self.config.migrate(with_default=True)
         from .log import setup_locals
+
         setup_locals(self.config)
         from .client import OperationsClient, ServiceInfoResponse
-        info = OperationsClient(base_url=self.config.ops_api, fail_silently=True).get_info()
+
+        info = OperationsClient(
+            base_url=self.config.ops_api, fail_silently=True
+        ).get_info()
         live = isinstance(info, ServiceInfoResponse) and info.validate()
         from utilmeta.utils import readable_size
         import utilmeta
         from . import __website__
         from .log import _instance, _databases, _caches, _supervisor
-        stage_str = 'production' if self.service.production else 'debug'
-        status_str = (GREEN % f'{DOT} live') if live else (RED % f'{DOT} down')
-        supervisor_str = BLUE % f'{_supervisor.url}' if _supervisor else \
-            f'not connected (connect at {__website__})'
-        print(BANNER % '{:<60}'.format(f'UtilMeta v{utilmeta.__version__} Operations Stats'))
-        print(f' Service Name: {self.service.name}', f'({self.service.title})' if self.service.title else '')
-        print(f' Service Version: {self.service.version_str}')
-        print(f' Service Status: {status_str} ({stage_str})')
-        print(f' Service Backend:', f'{self.service.backend_name} ({self.service.backend_version})',
-              (BLUE % f'| asynchronous') if self.service.asynchronous else '')
-        print(f' Service Base URL: {self.config.base_url}')
-        print(f' OperationsAPI URL: {self.config.ops_api}')
-        print(f' UtilMeta Node: {supervisor_str}')
-
-        print(BANNER % '{:<60}'.format('Service Instance Stats'))
+
+        stage_str = "production" if self.service.production else "debug"
+        status_str = (GREEN % f"{DOT} live") if live else (RED % f"{DOT} down")
+        supervisor_str = (
+            BLUE % f"{_supervisor.url}"
+            if _supervisor
+            else f"not connected (connect at {__website__})"
+        )
+        print(
+            BANNER
+            % "{:<60}".format(f"UtilMeta v{utilmeta.__version__} Operations Stats")
+        )
+        print(
+            f" Service Name: {self.service.name}",
+            f"({self.service.title})" if self.service.title else "",
+        )
+        print(f" Service Version: {self.service.version_str}")
+        print(f" Service Status: {status_str} ({stage_str})")
+        print(
+            f" Service Backend:",
+            f"{self.service.backend_name} ({self.service.backend_version})",
+            (BLUE % f"| asynchronous") if self.service.asynchronous else "",
+        )
        print(f" Service Base URL: {self.config.base_url}")
        print(f" OperationsAPI URL: {self.config.ops_api}")
        print(f" UtilMeta Node: {supervisor_str}")

        print(BANNER % "{:<60}".format("Service Instance Stats"))
         from .models import InstanceMonitor, Worker, DatabaseMonitor, CacheMonitor
-        from .query import InstanceMonitorSchema, DatabaseMonitorSchema, CacheMonitorSchema, WorkerSchema
+        from .query import (
+            InstanceMonitorSchema,
+            DatabaseMonitorSchema,
+            CacheMonitorSchema,
+            WorkerSchema,
+        )
         from utilmeta.core import orm, cache
+
         latest_monitor = None
         workers = []
         if _instance:
             try:
                 latest_monitor = InstanceMonitorSchema.init(
                     InstanceMonitor.objects.filter(
-                        instance=_instance,
-                        layer=0
-                    ).order_by('-time')
+                        instance=_instance, layer=0
+                    ).order_by("-time")
                 )
             except orm.EmptyQueryset:
                 pass
             workers = WorkerSchema.serialize(
-                Worker.objects.filter(instance=_instance, connected=True).order_by('-requests')
+                Worker.objects.filter(instance=_instance, connected=True).order_by(
+                    "-requests"
+                )
             )
         if latest_monitor:
             record_ago = int(time.time() - latest_monitor.time)
-            print(f' Stats Cycle: {latest_monitor.interval} seconds (recorded {record_ago}s ago)')
-            print(f' Requests: {latest_monitor.requests} ({latest_monitor.rps} per second)')
+            print(
+                f" Stats Cycle: {latest_monitor.interval} seconds (recorded {record_ago}s ago)"
+            )
+            print(
+                f" Requests: {latest_monitor.requests} ({latest_monitor.rps} per second)"
+            )
             if latest_monitor.errors:
-                print(RED % f' Errors: {latest_monitor.errors}')
-            print(f' Avg Time: {round(latest_monitor.avg_time, 1)} ms')
-            print(f' Traffic: {readable_size(latest_monitor.in_traffic)} In / '
-                  f'{readable_size(latest_monitor.out_traffic)} Out')
-            print(f' Used Memory: {readable_size(latest_monitor.used_memory)} ({latest_monitor.memory_percent}%)')
-            print(f' CPU: {latest_monitor.cpu_percent}%')
-            print(f' Net conn: {latest_monitor.total_net_connections} '
-                  f'({latest_monitor.active_net_connections} active)')
+                print(RED % f" Errors: {latest_monitor.errors}")
+            print(f" Avg Time: {round(latest_monitor.avg_time, 1)} ms")
+            print(
+                f" Traffic: {readable_size(latest_monitor.in_traffic)} In / "
+                f"{readable_size(latest_monitor.out_traffic)} Out"
+            )
+            print(
+                f" Used Memory: {readable_size(latest_monitor.used_memory)} ({latest_monitor.memory_percent}%)"
+            )
+            print(f" CPU: {latest_monitor.cpu_percent}%")
+            print(
+                f" Net conn: {latest_monitor.total_net_connections} "
+                f"({latest_monitor.active_net_connections} active)"
+            )
         if workers:
-            print(BANNER % '{:<60}'.format('Service Instance Workers'))
-            fields = ('PID', 'Status', 'Threads', 'Requests', 'Avg Time', 'Traffic', 'CPU', 'Memory')
+            print(BANNER % "{:<60}".format("Service Instance Workers"))
+            fields = (
+                "PID",
+                "Status",
+                "Threads",
+                "Requests",
+                "Avg Time",
+                "Traffic",
+                "CPU",
+                "Memory",
+            )
             form = "{:<10}{:<15}{:<10}{:<25}{:<15}{:<25}{:<8}{:<10}"
             print(form.format(*fields))
-            print('-' * 60)
+            print("-" * 60)
             for worker in workers:
-                print(form.format(worker.pid,
-                                  f'{DOT} {worker.status}',
-                                  worker.threads,
-                                  f'{worker.requests} ({worker.rps} per second)',
-                                  f'{worker.avg_time} ms',
-                                  f'{readable_size(worker.in_traffic)} In / {readable_size(worker.out_traffic)} Out',
-                                  f'{worker.cpu_percent}%',
-                                  f'{readable_size(worker.used_memory)} ({worker.memory_percent}%)'
-                                  ))
+                print(
+                    form.format(
+                        worker.pid,
+                        f"{DOT} {worker.status}",
+                        worker.threads,
+                        f"{worker.requests} ({worker.rps} per second)",
+                        f"{worker.avg_time} ms",
+                        f"{readable_size(worker.in_traffic)} In / {readable_size(worker.out_traffic)} Out",
+                        f"{worker.cpu_percent}%",
+                        f"{readable_size(worker.used_memory)} ({worker.memory_percent}%)",
+                    )
+                )
         if _databases:
             from utilmeta.core.orm import DatabaseConnections
+
             db_config = DatabaseConnections.config()
             if db_config:
-                print(BANNER % '{:<60}'.format('Service Instance Databases'))
-                fields = ('Alias', 'Engine', 'Name', 'Connections', 'Space', 'Location')
+                print(BANNER % "{:<60}".format("Service Instance Databases"))
+                fields = ("Alias", "Engine", "Name", "Connections", "Space", "Location")
                 form = "{:<15}{:<15}{:<15}{:<25}{:<15}{:<50}"
                 print(form.format(*fields))
-                print('-' * 60)
+                print("-" * 60)
                 for alias, database in _databases.items():
                     db = db_config.get(alias)
                     if not db:
                         continue
-                    conn_str = ''
-                    space_str = ''
-                    max_connections = database.data.get('max_server_connections')
+                    conn_str = ""
+                    space_str = ""
+                    max_connections = database.data.get("max_server_connections")
                     try:
                         latest_monitor = DatabaseMonitorSchema.init(
                             DatabaseMonitor.objects.filter(
-                                database=database,
-                                layer=0
-                            ).order_by('-time')
+                                database=database, layer=0
+                            ).order_by("-time")
                         )
                     except orm.EmptyQueryset:
                         pass
                     else:
-                        conn_str = (f'{latest_monitor.current_connections} '
-                                    f'({latest_monitor.active_connections} active)')
+                        conn_str = (
+                            f"{latest_monitor.current_connections} "
+                            f"({latest_monitor.active_connections} active)"
+                        )
                         if max_connections:
-                            conn_str += f' / {max_connections}'
-                        space_str = f'{readable_size(latest_monitor.used_space)}'
-                    print(form.format(alias, db.type or '-',
-                                      db.database_name or '-', conn_str, space_str,
-                                      db.location or '-'))
+                            conn_str += f" / {max_connections}"
+                        space_str = f"{readable_size(latest_monitor.used_space)}"
+                    print(
+                        form.format(
+                            alias,
+                            db.type or "-",
+                            db.database_name or "-",
+                            conn_str,
+                            space_str,
+                            db.location or "-",
+                        )
+                    )
         if _caches:
             cache_config = cache.CacheConnections.config()
             if cache_config:
-                print(BANNER % '{:<60}'.format('Service Instance Caches'))
-                fields = ('Alias', 'Engine', 'PID', 'Connections', 'Memory', 'CPU', 'Location')
+                print(BANNER % "{:<60}".format("Service Instance Caches"))
+                fields = (
+                    "Alias",
+                    "Engine",
+                    "PID",
+                    "Connections",
+                    "Memory",
+                    "CPU",
+                    "Location",
+                )
                 form = "{:<15}{:<15}{:<15}{:<25}{:<15}{:<15}{:<30}"
                 print(form.format(*fields))
-                print('-' * 60)
+                print("-" * 60)
                 for alias, cache_obj in _caches.items():
                     cache = cache_config.get(alias)
                     if not cache:
                         continue
-                    pid = cache_obj.data.get('pid') or '--'
-                    mem_str = ''
-                    conn_str = ''
-                    cpu_str = ''
-                    loc_str = f'{cache.host}:{cache.port}' if (cache.host and cache.port) else ''
+                    pid = cache_obj.data.get("pid") or "--"
+                    mem_str = ""
+                    conn_str = ""
+                    cpu_str = ""
+                    loc_str = (
+                        f"{cache.host}:{cache.port}"
+                        if (cache.host and cache.port)
+                        else ""
+                    )
                     try:
                         latest_monitor = CacheMonitorSchema.init(
                             CacheMonitor.objects.filter(
-                                cache=cache_obj,
-                                layer=0
-                            ).order_by('-time')
+                                cache=cache_obj, layer=0
+                            ).order_by("-time")
                        )
                     except orm.EmptyQueryset:
                         pass
                     else:
-                        conn_str = (f'{latest_monitor.current_connections} '
-                                    f'({latest_monitor.total_connections} total)')
-                        mem_str = f'{readable_size(latest_monitor.used_memory)}'
+                        conn_str = (
+                            f"{latest_monitor.current_connections} "
+                            f"({latest_monitor.total_connections} total)"
+                        )
+                        mem_str = f"{readable_size(latest_monitor.used_memory)}"
                         if latest_monitor.memory_percent:
-                            mem_str += f' ({latest_monitor.memory_percent}%)'
-                        cpu_str = f'{latest_monitor.cpu_percent}%' if latest_monitor.cpu_percent is not None else '-'
+                            mem_str += f" ({latest_monitor.memory_percent}%)"
+                        cpu_str = (
+                            f"{latest_monitor.cpu_percent}%"
+                            if latest_monitor.cpu_percent is not None
+                            else "-"
+                        )
-                    print(form.format(alias, cache.type, pid, conn_str, mem_str, cpu_str, loc_str))
+                    print(
+                        form.format(
+                            alias, cache.type, pid, conn_str, mem_str, cpu_str, loc_str
+                        )
+                    )
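
Note: the `stats` tables above are rendered with plain `str.format` column templates rather than a table library: each `{:<N}` spec left-aligns one column to a fixed width, and the header, separator, and rows all reuse the same `form` string. A reduced example of the pattern (column widths are illustrative):

    form = "{:<10}{:<15}{:<10}"
    print(form.format("PID", "Status", "Threads"))
    print("-" * 35)
    print(form.format(1234, "live", 8))
    # PID       Status         Threads
    # -----------------------------------
    # 1234      live           8
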
persist_level: int = WARN, + persist_duration_limit: Optional[int] = 5, + exclude_methods: list = ( + HTTPMethod.OPTIONS, + HTTPMethod.CONNECT, + HTTPMethod.TRACE, + HTTPMethod.HEAD, + ), + exclude_status: list = (), + exclude_request_headers: List[str] = (), + exclude_response_headers: List[str] = (), + # if these headers show up, exclude + default_volatile: bool = True, + volatile_maintain: timedelta = timedelta(days=7), + hide_ip_address: bool = False, + hide_user_id: bool = False, + # maintain: Optional[timedelta] = None, + # default + # - debug: info + # - production: WARN ): - exclude_methods = [m.upper() for m in exclude_methods] if exclude_methods else [] + exclude_methods = ( + [m.upper() for m in exclude_methods] if exclude_methods else [] + ) super().__init__(locals()) class Proxy(Config): @@ -112,58 +128,59 @@ class Proxy(Config): forward: bool = False def __init__( - self, - base_url: str, - forward: bool = False, + self, + base_url: str, + forward: bool = False, ): super().__init__(locals()) @property def proxy_url(self): - return url_join(self.base_url, 'proxy') - - def __init__(self, - route: str, - database: Union[str, Database], - base_url: Optional[str] = None, - # replace service.base_url - disabled_scope: List[str] = (), - secret_names: List[str] = DEFAULT_SECRET_NAMES, - trusted_hosts: List[str] = (), - # trusted_packages: List[str] = (), - default_timeout: int = 30, - secure_only: bool = True, - # local_disabled: bool = False, - logger_cls=None, - max_backlog: int = 100, - # will trigger a log save if the log hits this limit - worker_cycle: Union[int, float, timedelta] = timedelta(seconds=30), - worker_task_cls=None, - resources_manager_cls=None, - # every worker cycle, a worker will do - # - save the logs - # - save the worker monitor - # - the main (with min pid) worker will do the monitor tasks - openapi=None, # openapi paths - monitor: Monitor = Monitor(), - log: Log = Log(), - report_disabled: bool = False, - task_error_log: str = None, - max_retention_time: Union[int, float, timedelta] = timedelta(days=90), - local_scope: List[str] = ('*',), - eager_migrate: bool = False, - eager_mount: bool = False, - # new in v2.6.5 +--------- - # token: str = None - # proxy_url: str = None, - # proxy_forward_requests: bool = None, - proxy: Proxy = None, - ): + return url_join(self.base_url, "proxy") + + def __init__( + self, + route: str, + database: Union[str, Database], + base_url: Optional[str] = None, + # replace service.base_url + disabled_scope: List[str] = (), + secret_names: List[str] = DEFAULT_SECRET_NAMES, + trusted_hosts: List[str] = (), + # trusted_packages: List[str] = (), + default_timeout: int = 30, + secure_only: bool = True, + # local_disabled: bool = False, + logger_cls=None, + max_backlog: int = 100, + # will trigger a log save if the log hits this limit + worker_cycle: Union[int, float, timedelta] = timedelta(seconds=30), + worker_task_cls=None, + resources_manager_cls=None, + # every worker cycle, a worker will do + # - save the logs + # - save the worker monitor + # - the main (with min pid) worker will do the monitor tasks + openapi=None, # openapi paths + monitor: Monitor = Monitor(), + log: Log = Log(), + report_disabled: bool = False, + task_error_log: str = None, + max_retention_time: Union[int, float, timedelta] = timedelta(days=90), + local_scope: List[str] = ("*",), + eager_migrate: bool = False, + eager_mount: bool = False, + # new in v2.6.5 +--------- + # token: str = None + # proxy_url: str = None, + # proxy_forward_requests: bool = None, + 
proxy: Proxy = None,
+    ):
         super().__init__(locals())

         self.route = route
         self.database = database if isinstance(database, Database) else None
-        self.db_alias = database if isinstance(database, str) else '__ops'
+        self.db_alias = database if isinstance(database, str) else "__ops"

         self.disabled_scope = set(disabled_scope)
         self.secret_names = [k.lower() for k in secret_names]
@@ -191,15 +208,19 @@ def __init__(self,
         if base_url:
             parsed = urlsplit(base_url)
             if not parsed.scheme:
-                raise ValueError(f'Operations base_url should be an absolute url, got {base_url}')
+                raise ValueError(
+                    f"Operations base_url should be an absolute url, got {base_url}"
+                )
             self._base_url = self.parse_base_url(base_url)

         if self.HOST not in self.trusted_hosts:
             self.trusted_hosts.append(self.HOST)
         if not isinstance(monitor, self.Monitor):
-            raise TypeError(f'Operations monitor config must be a Monitor instance, got {monitor}')
+            raise TypeError(
+                f"Operations monitor config must be a Monitor instance, got {monitor}"
+            )
         if not isinstance(log, self.Log):
-            raise TypeError(f'Operations log config must be a Log instance, got {log}')
+            raise TypeError(f"Operations log config must be a Log instance, got {log}")
         self.monitor = monitor
         self.log = log
         self.logger_cls_string = logger_cls
@@ -213,24 +234,25 @@ def __init__(self,
         self._mounted = False
         # ------------------
         if proxy and not isinstance(proxy, self.Proxy):
-            raise TypeError(f'Operations proxy config must be a Proxy instance, got {proxy}')
+            raise TypeError(
+                f"Operations proxy config must be a Proxy instance, got {proxy}"
+            )
         self.proxy = proxy

     @classmethod
     def parse_base_url(cls, url: str):
         if not url:
             return url
-        if '$IP' in url:
-            url = url.replace('$IP', get_server_ip())
+        if "$IP" in url:
+            url = url.replace("$IP", get_server_ip())
         return url

     def load_openapi(self, no_store: bool = False):
         from utilmeta import service
         from utilmeta.core.api.specs.openapi import OpenAPI
+
         openapi = OpenAPI(
-            service,
-            external_docs=self.external_openapi,
-            base_url=self.base_url
+            service, external_docs=self.external_openapi, base_url=self.base_url
         )()
         if not no_store:
             self._openapi = openapi
@@ -250,7 +272,7 @@ def is_local(self):

     @property
     def is_secure(self):
-        return urlsplit(self.ops_api).scheme == 'https'
+        return urlsplit(self.ops_api).scheme == "https"

     @property
     def proxy_required(self):
@@ -260,6 +282,7 @@ def proxy_required(self):
             return False
         try:
             from ipaddress import ip_address
+
             hostname = urlsplit(self.base_url).hostname
             ip = get_ip(hostname)
             return ip_address(ip or self.host).is_private
@@ -279,42 +302,54 @@ def node_id(self):
     @cached_property
     def logger_cls(self):
         from utilmeta.ops.log import Logger
+
         if not self.logger_cls_string:
             return Logger
         cls = import_obj(self.logger_cls_string)
         if not issubclass(cls, Logger):
-            raise TypeError(f'Operations.logger_cls must inherit utilmeta.ops.log.Logger, got {cls}')
+            raise TypeError(
+                f"Operations.logger_cls must inherit utilmeta.ops.log.Logger, got {cls}"
+            )
         return cls

     @cached_property
     def resources_manager_cls(self):
         from utilmeta.ops.resources import ResourcesManager
+
         if not self.resources_manager_cls_string:
             return ResourcesManager
         cls = import_obj(self.resources_manager_cls_string)
         if not issubclass(cls, ResourcesManager):
-            raise TypeError(f'Operations.logger_cls must inherit utilmeta.ops.log.Logger, got {cls}')
+            raise TypeError(
+                f"Operations.resources_manager_cls must inherit "
+                f"utilmeta.ops.resources.ResourcesManager, got {cls}"
+            )
         return cls

     @cached_property
     def worker_task_cls(self):
         from utilmeta.ops.task import
OperationWorkerTask + if not self.worker_task_cls_string: return OperationWorkerTask cls = import_obj(self.worker_task_cls_string) if not issubclass(cls, OperationWorkerTask): - raise TypeError(f'Operations.worker_task_cls must inherit ' - f'utilmeta.ops.task.OperationWorkerTask, got {cls}') + raise TypeError( + f"Operations.worker_task_cls must inherit " + f"utilmeta.ops.task.OperationWorkerTask, got {cls}" + ) return cls @classmethod def get_secret_key(cls, service: UtilMeta): - seed = f'{service.module_name}:{service.name}:' \ - f'{service.backend_name}:{service.backend_version}:{service.base_url}:{__version__}:{sys.version}' + seed = ( + f"{service.module_name}:{service.name}:" + f"{service.backend_name}:{service.backend_version}:{service.base_url}:{__version__}:{sys.version}" + ) return hashlib.md5(seed.encode()).hexdigest() def hook(self, service: UtilMeta): from .cmd import OperationsCommand + service.register_command(OperationsCommand) def setup(self, service: UtilMeta): @@ -323,11 +358,11 @@ def setup(self, service: UtilMeta): # --- add log middleware if service.adaptor: - service.adaptor.add_middleware( - self.logger_cls.middleware_cls(self) - ) + service.adaptor.add_middleware(self.logger_cls.middleware_cls(self)) else: - raise NotImplementedError('Operations setup error: service backend not specified') + raise NotImplementedError( + "Operations setup error: service backend not specified" + ) # from django.core.exceptions import ImproperlyConfigured # django_settings = None @@ -345,13 +380,14 @@ def setup(self, service: UtilMeta): # print('SETTINGS CONFIGURED') from utilmeta.core.server.backends.django.settings import DjangoSettings + django_config = service.get_config(DjangoSettings) db_routers = [] - if self.db_alias != 'default': + if self.db_alias != "default": db_router = self.get_database_router() setattr(service.module, self.ROUTER_NAME, db_router) - db_routers.append(f'{service.module_name}.{self.ROUTER_NAME}') + db_routers.append(f"{service.module_name}.{self.ROUTER_NAME}") if django_config: if self.REF not in django_config.apps: @@ -368,7 +404,7 @@ def setup(self, service: UtilMeta): apps=[self.REF], database_routers=tuple(db_routers), secret_key=self.get_secret_key(service), - append_slash=True + append_slash=True, ) service.use(django_config) @@ -377,20 +413,20 @@ def setup(self, service: UtilMeta): if dbs_config: if self.database: dbs_config.add_database( - service=service, - alias=self.db_alias, - database=self.database + service=service, alias=self.db_alias, database=self.database ) else: self.database = dbs_config.databases.get(self.db_alias) if not self.database: - raise ValueError(f'Operations config: database required, got invalid {repr(self.db_alias)}') + raise ValueError( + f"Operations config: database required, got invalid {repr(self.db_alias)}" + ) else: if not self.database: - raise ValueError(f'Operations config: database required, got invalid {repr(self.db_alias)}') - service.use(DatabaseConnections({ - self.db_alias: self.database - })) + raise ValueError( + f"Operations config: database required, got invalid {repr(self.db_alias)}" + ) + service.use(DatabaseConnections({self.db_alias: self.database})) # setup here, before importing APIs django_config.setup(service) @@ -404,8 +440,11 @@ def setup(self, service: UtilMeta): parsed = urlsplit(self.route) if not parsed.scheme: from utilmeta.ops.api import OperationsAPI + # route instead of URL - service.mount_to_api(OperationsAPI, route=self.route, eager=self.eager_mount) + service.mount_to_api( + 
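# Sketch: a plausible shape for the router installed under ROUTER_NAME above.
# This is a hypothetical reconstruction (get_database_router() is not shown in
# this hunk), pinning the 'ops' app to the ops database alias:
class _OperationsDatabaseRouter:
    ops_alias = "__ops"  # default db_alias from __init__ above

    def db_for_read(self, model, **hints):
        return self.ops_alias if model._meta.app_label == "ops" else None

    def db_for_write(self, model, **hints):
        return self.ops_alias if model._meta.app_label == "ops" else None

    def allow_migrate(self, db, app_label, model_name=None, **hints):
        return (db == self.ops_alias) if app_label == "ops" else None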
OperationsAPI, route=self.route, eager=self.eager_mount + ) self._mounted = True # try: # root_api = service.resolve() @@ -424,7 +463,9 @@ def setup(self, service: UtilMeta): # pass if service.meta_config: - node_id = service.meta_config.get('node') or service.meta_config.get('node-id') + node_id = service.meta_config.get("node") or service.meta_config.get( + "node-id" + ) if node_id: self._node_id = node_id @@ -445,7 +486,7 @@ def on_startup(self, service: UtilMeta): self.load_openapi() if self._task: - print('Operations task already started, ignoring...') + print("Operations task already started, ignoring...") return if self.eager_migrate: @@ -459,8 +500,10 @@ def on_startup(self, service: UtilMeta): else: self.migrate() - print(f'UtilMeta OperationsAPI loaded at {ops_api}, ' - f'connect your APIs at {__website__}') + print( + f"UtilMeta OperationsAPI loaded at {ops_api}, " + f"connect your APIs at {__website__}" + ) # from .log import setup_locals # threading.Thread(target=setup_locals, args=(self,)).start() # task @@ -501,31 +544,38 @@ def allow_migrate(db, app_label, model_name=None, **hints): def migrate(self, with_default: bool = False): from utilmeta.core.orm.backends.django.database import DjangoDatabaseAdaptor + DjangoDatabaseAdaptor(self.database).check() import warnings from django.db.migrations.executor import MigrationExecutor from django.db import connections + ops_conn = connections[self.db_alias] executor = MigrationExecutor(ops_conn) - migrate_apps = ['ops', 'contenttypes'] + migrate_apps = ["ops", "contenttypes"] try: targets = [ - key for key in executor.loader.graph.leaf_nodes() if key[0] in migrate_apps + key + for key in executor.loader.graph.leaf_nodes() + if key[0] in migrate_apps ] plan = executor.migration_plan(targets) if not plan: return executor.migrate(targets, plan) except Exception as e: - warnings.warn(f'migrate operation models failed with error: {e}') + warnings.warn(f"migrate operation models failed with error: {e}") if with_default: from django.db import connection + # ---------- if connection != ops_conn: try: executor = MigrationExecutor(connection) targets = [ - key for key in executor.loader.graph.leaf_nodes() if key[0] in migrate_apps + key + for key in executor.loader.graph.leaf_nodes() + if key[0] in migrate_apps ] plan = executor.migration_plan(targets) if not plan: @@ -533,7 +583,9 @@ def migrate(self, with_default: bool = False): executor.migrate(targets, plan) except Exception as e: # ignore migration in default db - warnings.warn(f'migrate operation models to default database failed: {e}') + warnings.warn( + f"migrate operation models to default database failed: {e}" + ) @property def ops_api(self): @@ -554,7 +606,7 @@ def ops_api(self): @property def host(self): - ip = get_server_ip(private_only=bool(self.proxy)) or '127.0.0.1' + ip = get_server_ip(private_only=bool(self.proxy)) or "127.0.0.1" try: from utilmeta import service except ImportError: @@ -582,13 +634,14 @@ def port(self): @property def address(self): from ipaddress import ip_address + addr = ip_address(self.host) port = self.port host = self.host if port: if addr.version == 6: - host = f'[{host}]' - return f'{host}:{port}' + host = f"[{host}]" + return f"{host}:{port}" return host @property @@ -606,7 +659,7 @@ def base_url(self): @property def proxy_origin(self): - return 'http://' + self.address + return "http://" + self.address @property def proxy_ops_api(self): @@ -646,50 +699,56 @@ def proxy_base_url(self): def check_supervisor(self, base_url: str): parsed = 
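# Sketch: migrate() above drives Django migrations programmatically instead of
# shelling out to manage.py. The same MigrationExecutor pattern in isolation
# (requires configured Django settings; the alias is hypothetical):
from django.db import connections
from django.db.migrations.executor import MigrationExecutor

def migrate_ops(alias="__ops", apps=("ops", "contenttypes")):
    executor = MigrationExecutor(connections[alias])
    targets = [k for k in executor.loader.graph.leaf_nodes() if k[0] in apps]
    plan = executor.migration_plan(targets)
    if plan:  # no-op when everything is already applied
        executor.migrate(targets, plan)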
urlsplit(base_url) if self.secure_only: - if parsed.scheme not in ['https', 'wss']: - raise ValueError(f'utilmeta.ops.Operations: Insecure supervisor: {base_url}, ' - f'HTTPS is required, or you need to turn secure_only=False') + if parsed.scheme not in ["https", "wss"]: + raise ValueError( + f"utilmeta.ops.Operations: Insecure supervisor: {base_url}, " + f"HTTPS is required, or you need to turn secure_only=False" + ) host = str(parsed.hostname) for trusted in self.trusted_hosts: - if host == trusted or host.endswith(f'.{trusted}'): + if host == trusted or host.endswith(f".{trusted}"): return True - raise ValueError(f'utilmeta.ops.Operations: Untrusted supervisor host: {parsed.hostname}, ' - f'if you trust this host, ' - f'you need to add it to the [trusted_hosts] param of Operations config') + raise ValueError( + f"utilmeta.ops.Operations: Untrusted supervisor host: {parsed.hostname}, " + f"if you trust this host, " + f"you need to add it to the [trusted_hosts] param of Operations config" + ) @classmethod def get_backend_name(cls, backend): - name = str(getattr(backend, 'name', '')) + name = str(getattr(backend, "name", "")) if name: return name - name = str(getattr(backend, '__name__', '')) + name = str(getattr(backend, "__name__", "")) if not name: - ref_name = str(backend).lstrip('<').rstrip('>').strip() - if ' ' in ref_name: - ref_name = ref_name.split(' ')[0] - if '.' in ref_name: - ref_name = ref_name.split('.')[0] + ref_name = str(backend).lstrip("<").rstrip(">").strip() + if " " in ref_name: + ref_name = ref_name.split(" ")[0] + if "." in ref_name: + ref_name = ref_name.split(".")[0] name = ref_name or str(backend) - return name + '_service' + return name + "_service" @classmethod def get_service_name(cls, backend): from utilmeta.utils import search_file, load_ini, read_from - meta_path = search_file('utilmeta.ini') or search_file('meta.ini') + + meta_path = search_file("utilmeta.ini") or search_file("meta.ini") name = None if meta_path: try: config = load_ini(read_from(meta_path), parse_key=True) except Exception as e: import warnings - warnings.warn(f'load ini file: {meta_path} failed with error: {e}') + + warnings.warn(f"load ini file: {meta_path} failed with error: {e}") else: - meta_config = config.get('utilmeta') or config.get('service') or {} + meta_config = config.get("utilmeta") or config.get("service") or {} if not isinstance(meta_config, dict): meta_config = {} - name = str(meta_config.get('name', '')).strip() + name = str(meta_config.get("name", "")).strip() if not name: - name = str(getattr(backend, 'name', '')) + name = str(getattr(backend, "name", "")) if name: return name if meta_path: @@ -705,11 +764,16 @@ def integrate(self, backend, module=None, name: str = None): origin = get_origin(self.route) elif not self._base_url: if self.proxy: - eg = ('eg: Operations(base_url="http://$IP:8080/api"), \n you are using a cluster proxy,' - ' $IP will be your current server ip address') + eg = ( + 'eg: Operations(base_url="http://$IP:8080/api"), \n you are using a cluster proxy,' + " $IP will be your current server ip address" + ) else: eg = 'eg: Operations(base_url="https://api.example.com/api")' - raise ValueError('Integrate utilmeta.ops.Operations requires to set a base_url of your API service, ' + eg) + raise ValueError( + "Integrate utilmeta.ops.Operations requires to set a base_url of your API service, " + + eg + ) else: url_parsed = urlsplit(self._base_url) # if url_parsed.path: @@ -718,6 +782,7 @@ def integrate(self, backend, module=None, name: str = None): root_url = 
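# Sketch: the trusted-host rule in check_supervisor above is exact match or
# dot-separated subdomain -- the leading dot prevents suffix spoofing:
def is_trusted(host: str, trusted_hosts) -> bool:
    return any(host == t or host.endswith("." + t) for t in trusted_hosts)

# is_trusted("api.utilmeta.com", ["utilmeta.com"])  -> True
# is_trusted("evil-utilmeta.com", ["utilmeta.com"]) -> False (no dot boundary)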
url_parsed.path from utilmeta import UtilMeta + try: from utilmeta import service except ImportError: @@ -726,7 +791,7 @@ def integrate(self, backend, module=None, name: str = None): backend=backend, name=name or self.get_service_name(backend), origin=origin, - route=root_url + route=root_url, ) service._auto_created = True else: @@ -734,7 +799,9 @@ def integrate(self, backend, module=None, name: str = None): if module: service.module_name = module else: - raise ValueError(f'Operations.integrate second param should pass __name__, got {module}') + raise ValueError( + f"Operations.integrate second param should pass __name__, got {module}" + ) service.use(self) service.setup() @@ -742,18 +809,22 @@ def integrate(self, backend, module=None, name: str = None): if service.adaptor: if not self._mounted: from .api import OperationsAPI + service.mount_to_api(OperationsAPI, route=route, eager=self.eager_mount) self._mounted = True # service.adaptor.adapt(OperationsAPI, route=parsed.path) service.adaptor.setup() else: - raise NotImplementedError('Operations integrate error: service backend not specified') + raise NotImplementedError( + "Operations integrate error: service backend not specified" + ) if service.module: # ATTRIBUTE FINDER - setattr(service.module, 'utilmeta', service) + setattr(service.module, "utilmeta", service) import utilmeta + if not utilmeta._cmd_env: # trigger start self.on_startup(service) @@ -783,7 +854,7 @@ def get_django_ninja_openapi(cls, *ninja_apis, **path_ninja_apis): from ninja.openapi.schema import get_schema from ninja import NinjaAPI - def generator_func(service: 'UtilMeta'): + def generator_func(service: "UtilMeta"): config = service.get_config(cls) docs = [] for app in ninja_apis: @@ -792,12 +863,16 @@ def generator_func(service: 'UtilMeta'): elif isinstance(app, dict): path_ninja_apis.update(app) else: - raise TypeError(f'Invalid application: {app} for django ninja. NinjaAPI() instance expected') + raise TypeError( + f"Invalid application: {app} for django ninja. NinjaAPI() instance expected" + ) for path, ninja_api in path_ninja_apis.items(): if isinstance(ninja_api, NinjaAPI): doc = get_schema(ninja_api) - servers = doc.get('servers', []) - doc['servers'] = [{'url': url_join(config.base_url, path)}] + servers + servers = doc.get("servers", []) + doc["servers"] = [ + {"url": url_join(config.base_url, path)} + ] + servers docs.append(doc) return docs diff --git a/utilmeta/ops/connect.py b/utilmeta/ops/connect.py index c4507fe..7c49935 100644 --- a/utilmeta/ops/connect.py +++ b/utilmeta/ops/connect.py @@ -11,21 +11,21 @@ # 1. 
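# Sketch: calling integrate() per the signature above, with `ops` built as in
# the earlier sketch; `app` is a hypothetical existing WSGI/ASGI application:
ops.integrate(app, __name__)  # auto-creates a UtilMeta service around `app`,
                              # mounts OperationsAPI, then triggers on_startup
                              # (outside of a `meta` command environment)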
meta connect --token= # 2, -TRUSTED_HOST = 'utilmeta.com' -DEFAULT_SUPERVISOR = 'https://api.utilmeta.com/spv' -CLIENT_NAME = f'utilmeta-py-{__version__}' +TRUSTED_HOST = "utilmeta.com" +DEFAULT_SUPERVISOR = "https://api.utilmeta.com/spv" +CLIENT_NAME = f"utilmeta-py-{__version__}" default_supervisor = SupervisorClient( base_url=DEFAULT_SUPERVISOR, - base_headers={ - 'User-Agent': CLIENT_NAME - }, - fail_silently=True + base_headers={"User-Agent": CLIENT_NAME}, + fail_silently=True, ) # can only get the basic info -def auto_select_supervisor(*supervisors: SupervisorBasic, timeout: int = 5, times: int = 2) -> Optional[str]: +def auto_select_supervisor( + *supervisors: SupervisorBasic, timeout: int = 5, times: int = 2 +) -> Optional[str]: if not supervisors: return None if len(supervisors) == 1: @@ -49,7 +49,9 @@ def fetch_supervisor_info(base_url: str): if not url_map: return None - ts_pairs = [(url, sum(durations) / len(durations)) for url, durations in url_map.items()] + ts_pairs = [ + (url, sum(durations) / len(durations)) for url, durations in url_map.items() + ] ts_pairs.sort(key=lambda v: v[1]) return ts_pairs[0][0] @@ -62,17 +64,18 @@ def save_supervisor(data: SupervisorData) -> Supervisor: # ops_api=data.ops_api ).first() if not obj or obj.disabled: - raise exceptions.NotFound('Supervisor not found or disabled') + raise exceptions.NotFound("Supervisor not found or disabled") if obj.node_id: # already created return obj if obj.base_url != data.base_url: - raise exceptions.Conflict('Supervisor base_url conflicted') + raise exceptions.Conflict("Supervisor base_url conflicted") if Supervisor.objects.filter( - base_url=data.base_url, - node_id=data.node_id + base_url=data.base_url, node_id=data.node_id ).exists(): - raise exceptions.Conflict(f'Supervisor[{data.node_id}] at {data.base_url} already exists') + raise exceptions.Conflict( + f"Supervisor[{data.node_id}] at {data.base_url} already exists" + ) Supervisor.objects.filter(id=obj.pk).update( ident=data.ident, node_id=data.node_id, @@ -84,12 +87,12 @@ def save_supervisor(data: SupervisorData) -> Supervisor: connected=True, url=data.url, local=data.local, - init_key=None, # empty init_key, as it is no longer useful and maybe a potential leak source + init_key=None, # empty init_key, as it is no longer useful and maybe a potential leak source ) - return Supervisor.objects.filter(id=obj.pk).first() # refresh + return Supervisor.objects.filter(id=obj.pk).first() # refresh else: # from api calling - raise exceptions.BadRequest('Missing init_key for supervisor to be saved') + raise exceptions.BadRequest("Missing init_key for supervisor to be saved") def update_service_supervisor(service: str, node_id: str): @@ -97,38 +100,34 @@ def update_service_supervisor(service: str, node_id: str): return from utilmeta.ops import models from django.core.exceptions import EmptyResultSet + for model in models.supervisor_related_models: try: - model.objects.filter( - service=service, - node_id=None - ).update(node_id=node_id) + model.objects.filter(service=service, node_id=None).update(node_id=node_id) except EmptyResultSet: continue def connect_supervisor( - key: str, - base_url: str = None, - service_id: str = None, - cluster_id: str = None + key: str, base_url: str = None, service_id: str = None, cluster_id: str = None ): from utilmeta import service + ops_config = Operations.config() if not ops_config: - raise TypeError('Operations not configured') + raise TypeError("Operations not configured") if not key: - raise ValueError('Access key required to 
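# Sketch: auto_select_supervisor above probes each candidate URL `times` times
# and keeps the one with the lowest mean duration; the selection step alone:
def pick_fastest(url_map: dict) -> str:
    # url_map: {url: [duration_0, duration_1, ...]} as collected by the probes
    return min(url_map, key=lambda u: sum(url_map[u]) / len(url_map[u]))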
connect to supervisor') + raise ValueError("Access key required to connect to supervisor") - if not key.startswith('{') or not key.endswith('}'): + if not key.startswith("{") or not key.endswith("}"): # BASE64 key = base64.decodebytes(key.encode()).decode() if not base_url: # get action url based on the latency # fire 2 request for each supervisor at the same time, choose the more reliable one - print('connecting: auto-selecting supervisor...') + print("connecting: auto-selecting supervisor...") list_resp = default_supervisor.get_supervisors() if list_resp.success: base_url = auto_select_supervisor(*list_resp.result) @@ -136,15 +135,17 @@ def connect_supervisor( ops_config.check_supervisor(base_url) if not base_url: - raise ValueError('No supervisor selected, operation failed') + raise ValueError("No supervisor selected, operation failed") if service.production: if ops_config.is_local: - raise ValueError(f'Invalid production service operations location: {ops_config.ops_api}, ' - f'please use UtilMeta(origin="https://YOUR-PUBLIC-HOST") ' - f'to specify your public accessible service origin') + raise ValueError( + f"Invalid production service operations location: {ops_config.ops_api}, " + f'please use UtilMeta(origin="https://YOUR-PUBLIC-HOST") ' + f"to specify your public accessible service origin" + ) - print(f'connect supervisor at: {base_url}') + print(f"connect supervisor at: {base_url}") # with orm.Atomic(ops_config.db_alias): # --- PLACEHOLDER @@ -163,17 +164,23 @@ def connect_supervisor( if supervisor_obj: if supervisor_obj.local: - print(f'local supervisor already exists as [{supervisor_obj.node_id}], visit it in {supervisor_obj.url}') + print( + f"local supervisor already exists as [{supervisor_obj.node_id}], visit it in {supervisor_obj.url}" + ) return if supervisor_obj.public_key: - print(f'supervisor already exists as [{supervisor_obj.node_id}],' - f' visit it in {supervisor_obj.url}') + print( + f"supervisor already exists as [{supervisor_obj.node_id}]," + f" visit it in {supervisor_obj.url}" + ) if supervisor_obj.node_id and not ops_config.node_id: # lost sync, resync here from utilmeta.bin.utils import update_meta_ini_file + update_meta_ini_file(node=supervisor_obj.node_id) from .resources import ResourcesManager + ResourcesManager(service).sync_resources(supervisor_obj) # sync resources if supervisor already exists # maybe last connect failed to sync @@ -183,14 +190,14 @@ def connect_supervisor( if supervisor_obj.init_key != key: supervisor_obj.init_key = key - supervisor_obj.save(update_fields=['init_key']) + supervisor_obj.save(update_fields=["init_key"]) if not supervisor_obj: supervisor_obj = Supervisor.objects.create( service=service.name, base_url=base_url, - init_key=key, # for double-check - ops_api=ops_api + init_key=key, # for double-check + ops_api=ops_api, ) # without node_id @@ -204,42 +211,51 @@ def connect_supervisor( cluster_key=key if cluster_id else None, fail_silently=True, service_id=service_id, - cluster_id=cluster_id + cluster_id=cluster_id, ) as cli: - resp = cli.add_node( - data=resources.get_metadata() - ) + resp = cli.add_node(data=resources.get_metadata()) if not resp.success: - raise ValueError(f'connect to supervisor failed with error: {resp.message}') + raise ValueError( + f"connect to supervisor failed with error: {resp.message}" + ) if resp.result: # supervisor is returned (cannot access) supervisor_obj = save_supervisor(resp.result) - if not supervisor_obj.node_id or supervisor_obj.node_id != resp.result.node_id: - raise 
ValueError(f'supervisor failed to create: inconsistent node id: '
-                                     f'{supervisor_obj.node_id}, {resp.result.node_id}')
+                if (
+                    not supervisor_obj.node_id
+                    or supervisor_obj.node_id != resp.result.node_id
+                ):
+                    raise ValueError(
+                        f"supervisor failed to create: inconsistent node id: "
+                        f"{supervisor_obj.node_id}, {resp.result.node_id}"
+                    )
             else:
                 # supervisor already updated in POST OperationsAPI/
-                supervisor_obj: Supervisor = Supervisor.objects.get(pk=supervisor_obj.pk)
+                supervisor_obj: Supervisor = Supervisor.objects.get(
+                    pk=supervisor_obj.pk
+                )
                 # update after

             if not supervisor_obj.node_id:
-                raise ValueError('supervisor failed to create')
+                raise ValueError("supervisor failed to create")

             update_service_supervisor(
-                service=supervisor_obj.service,
-                node_id=supervisor_obj.node_id
+                service=supervisor_obj.service, node_id=supervisor_obj.node_id
             )
             if not supervisor_obj.local:
                 if not supervisor_obj.public_key:
-                    raise ValueError('supervisor failed to create: no public key')
+                    raise ValueError("supervisor failed to create: no public key")
             else:
                 if not localhost(ops_api):
-                    raise ValueError(f'supervisor failed to create: invalid local supervisor for {ops_api}')
+                    raise ValueError(
+                        f"supervisor failed to create: invalid local supervisor for {ops_api}"
+                    )

             url = supervisor_obj.url
-            print(f'supervisor[{supervisor_obj.node_id}] created')
+            print(f"supervisor[{supervisor_obj.node_id}] created")
             from utilmeta.bin.utils import update_meta_ini_file
+
             update_meta_ini_file(node=supervisor_obj.node_id)
             # update meta.ini
     except Exception as e:
@@ -249,9 +265,9 @@ def connect_supervisor(
     if not supervisor_obj.local:
         resources.sync_resources(supervisor_obj)

-    print('supervisor connected successfully!')
+    print("supervisor connected successfully!")
     if url:
-        print(f'please visit {url} to view and manage your APIs')
+        print(f"please visit {url} to view and manage your APIs")
     return url


@@ -262,42 +278,47 @@ def delete_supervisor(
 ):
     ops_config = Operations.config()
     if not ops_config:
-        raise TypeError('Operations not configured')
+        raise TypeError("Operations not configured")
     if ops_config.node_id and ops_config.node_id != node_id:
-        raise ValueError(f'You are trying to delete supervisor: {repr(node_id)} under a different service')
+        raise ValueError(
+            f"You are trying to delete supervisor: {repr(node_id)} under a different service"
+        )
     supervisor: Supervisor = Supervisor.objects.filter(node_id=node_id).first()
     if not supervisor:
-        raise ValueError(f'Supervisor: {repr(node_id)} not exists')
-    print(f'deleting supervisor [{node_id}]...')
+        raise ValueError(f"Supervisor: {repr(node_id)} does not exist")
+    print(f"deleting supervisor [{node_id}]...")
     init_key = supervisor.init_key
     try:
         if init_key:
             supervisor.init_key = None
             # delete as a marker for deletion
-            supervisor.save(update_fields=['init_key'])
+            supervisor.save(update_fields=["init_key"])

         with SupervisorClient(
             base_url=supervisor.base_url,
             access_key=key,
             node_key=supervisor.public_key,
             node_id=node_id,
-            fail_silently=True
+            fail_silently=True,
         ) as cli:
             resp = cli.delete_node()
             if not resp.success:
-                if resp.state == 'node_not_exists':
-                    print(f'supervisor not exists in remote, delete local supervisor')
+                if resp.state == "node_not_exists":
+                    print("supervisor does not exist in remote, deleting local supervisor")
                 else:
-                    raise ValueError(f'connect to supervisor failed with error: {resp.message}')
+                    raise ValueError(
+                        f"connect to supervisor failed with error: {resp.message}"
+                    )
             else:
                 if resp.result != supervisor.node_id:
-                    raise ValueError(f'delete
supervisor failed: node id mismatch') + raise ValueError(f"delete supervisor failed: node id mismatch") from utilmeta.bin.utils import update_meta_ini_file + update_meta_ini_file(node=None) - supervisor.delete() # this is mostly an empty delete, just incase + supervisor.delete() # this is mostly an empty delete, just incase except Exception as e: if init_key: supervisor.init_key = init_key - supervisor.save(update_fields=['init_key']) + supervisor.save(update_fields=["init_key"]) raise e - print(f'supervisor[{supervisor.node_id}] deleted successfully!') + print(f"supervisor[{supervisor.node_id}] deleted successfully!") diff --git a/utilmeta/ops/key.py b/utilmeta/ops/key.py index 74be03d..1754996 100644 --- a/utilmeta/ops/key.py +++ b/utilmeta/ops/key.py @@ -4,11 +4,11 @@ from typing import Optional import base64 -RSA_ALGO = 'RSA-OAEP-256' +RSA_ALGO = "RSA-OAEP-256" def generate_key_pair(identifier: str): - key = jwk.JWK.generate(kty='RSA', size=2048, alg=RSA_ALGO, kid=identifier) + key = jwk.JWK.generate(kty="RSA", size=2048, alg=RSA_ALGO, kid=identifier) public_key = key.export_public() private_key = key.export_private() return public_key, private_key @@ -17,7 +17,7 @@ def generate_key_pair(identifier: str): def encrypt_data(payload, public_key: Union[str, dict]) -> str: if not isinstance(public_key, dict): if isinstance(public_key, str): - if not public_key.startswith('{') or not public_key.endswith('}'): + if not public_key.startswith("{") or not public_key.endswith("}"): # BASE64 public_key = base64.decodebytes(public_key.encode()).decode() @@ -34,10 +34,8 @@ def encrypt_data(payload, public_key: Union[str, dict]) -> str: payload = json_encode(payload) else: payload = str(payload) - payload = payload.encode('utf-8') - jwe_token = jwe.JWE(payload, - recipient=pubkey_obj, - protected=protected_header) + payload = payload.encode("utf-8") + jwe_token = jwe.JWE(payload, recipient=pubkey_obj, protected=protected_header) return jwe_token.serialize() @@ -51,7 +49,7 @@ def decrypt_data(encrypted: Union[str, bytes], private_key: Union[str, dict]) -> jwetoken.deserialize(encrypted, key=privkey_obj) payload = jwetoken.payload if isinstance(payload, bytes): - payload = payload.decode('utf-8') + payload = payload.decode("utf-8") return payload diff --git a/utilmeta/ops/log.py b/utilmeta/ops/log.py index e7b43b1..cec6f6b 100644 --- a/utilmeta/ops/log.py +++ b/utilmeta/ops/log.py @@ -5,9 +5,19 @@ from utilmeta.utils.context import ContextProperty, Property from typing import List, Optional, Union from utilmeta.core.server import ServiceMiddleware -from utilmeta.utils import (file_like, SECRET, HAS_BODY_METHODS, - hide_secret_values, normalize, time_now, Error, ignore_errors, - replace_null, parse_user_agents, HTTP_METHODS_LOWER) +from utilmeta.utils import ( + file_like, + SECRET, + HAS_BODY_METHODS, + hide_secret_values, + normalize, + time_now, + Error, + ignore_errors, + replace_null, + parse_user_agents, + HTTP_METHODS_LOWER, +) from .config import Operations import threading import contextvars @@ -28,8 +38,8 @@ _databases: dict = {} _caches: dict = {} _openapi = None -_path_prefix = '' -_logger = contextvars.ContextVar('_logger') +_path_prefix = "" +_logger = contextvars.ContextVar("_logger") class WorkerMetricsLogger: @@ -48,25 +58,26 @@ def __init__(self): self._total_time = 0 @ignore_errors - def log(self, - duration: float, - in_traffic: int = 0, - out_traffic: int = 0, - outbound: bool = False, - error: bool = False, - timeout: bool = False - ): + def log( + self, + duration: float, + 
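# Sketch: round-tripping the key.py helpers from the hunk above; the node
# identifier and payload are hypothetical:
from utilmeta.ops.key import generate_key_pair, encrypt_data, decrypt_data

public_key, private_key = generate_key_pair("node-demo")
token = encrypt_data({"hello": "world"}, public_key=public_key)  # JWE, RSA-OAEP-256
print(decrypt_data(token, private_key=private_key))  # the JSON payload string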
in_traffic: int = 0, + out_traffic: int = 0, + outbound: bool = False, + error: bool = False, + timeout: bool = False, + ): self._total_in += in_traffic self._total_out += out_traffic if outbound: self._total_outbound_requests += 1 - self._total_outbound_errors += (1 if error else 0) - self._total_outbound_timeouts += (1 if timeout else 0) + self._total_outbound_errors += 1 if error else 0 + self._total_outbound_timeouts += 1 if timeout else 0 self._total_outbound_request_time += duration else: self._total_requests += 1 - self._total_errors += (1 if error else 0) + self._total_errors += 1 if error else 0 self._total_time += duration def reset(self): @@ -91,14 +102,17 @@ def fetch(self, interval: int): rps=self._total_requests / interval, errors=self._total_errors, outbound_requests=self._total_outbound_requests, - outbound_avg_time=(self._total_outbound_request_time / self._total_outbound_requests) if - self._total_outbound_requests else 0, + outbound_avg_time=( + self._total_outbound_request_time / self._total_outbound_requests + ) + if self._total_outbound_requests + else 0, outbound_rps=self._total_outbound_requests / interval, outbound_errors=self._total_outbound_errors, outbound_timeouts=self._total_outbound_timeouts, ) - @ignore_errors(default=dict) # ignore cache errors + @ignore_errors(default=dict) # ignore cache errors def retrieve(self, inst) -> dict: if not inst: return {} @@ -119,25 +133,37 @@ def retrieve(self, inst) -> dict: ) if requests: values.update( - requests=models.F('requests') + requests, + requests=models.F("requests") + requests, rps=round(requests / (now - inst.time).total_seconds(), 4), - avg_time=((models.F('avg_time') * models.F('requests') + total_time) / - (models.F('requests') + requests)) if requests else models.F('avg_time'), - errors=models.F('errors') + errors, + avg_time=( + (models.F("avg_time") * models.F("requests") + total_time) + / (models.F("requests") + requests) + ) + if requests + else models.F("avg_time"), + errors=models.F("errors") + errors, ) if in_traffic: - values.update(in_traffic=models.F('in_traffic') + in_traffic) + values.update(in_traffic=models.F("in_traffic") + in_traffic) if out_traffic: - values.update(out_traffic=models.F('out_traffic') + out_traffic) + values.update(out_traffic=models.F("out_traffic") + out_traffic) if outbound_requests: values.update( - outbound_requests=models.F('outbound_requests') + outbound_requests, - outbound_errors=models.F('outbound_errors') + outbound_errors, - outbound_timeouts=models.F('outbound_timeouts') + outbound_timeouts, - outbound_rps=round(outbound_requests / (now - inst.time).total_seconds(), 4), - outbound_avg_time=((models.F('outbound_requests') * models.F('outbound_avg_time') + - total_outbound_request_time) / (models.F('outbound_requests') + outbound_requests)) - if outbound_requests else models.F('outbound_avg_time'), + outbound_requests=models.F("outbound_requests") + outbound_requests, + outbound_errors=models.F("outbound_errors") + outbound_errors, + outbound_timeouts=models.F("outbound_timeouts") + outbound_timeouts, + outbound_rps=round( + outbound_requests / (now - inst.time).total_seconds(), 4 + ), + outbound_avg_time=( + ( + models.F("outbound_requests") * models.F("outbound_avg_time") + + total_outbound_request_time + ) + / (models.F("outbound_requests") + outbound_requests) + ) + if outbound_requests + else models.F("outbound_avg_time"), ) return replace_null(values) @@ -153,18 +179,17 @@ def update_worker(self, record: bool = False, interval: int = None): if not _worker: 
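# Sketch: the F()-expression updates in retrieve() above keep a running average
# in the database without a read-modify-write race; the same arithmetic in
# plain Python:
def merged_avg(avg: float, n: int, added_total: float, k: int) -> float:
    # old sum (avg * n) plus the new batch's total time, over the new count
    return (avg * n + added_total) / (n + k) if k else avg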
return from .models import Worker - worker: Worker = _worker # noqa + + worker: Worker = _worker # noqa now = time_now() sys_metrics = worker.get_sys_metrics() - req_metrics = self.fetch(interval or max(1.0, (now - worker.time).total_seconds())) - self.save( - worker, - **sys_metrics, - connected=True, - time=now + req_metrics = self.fetch( + interval or max(1.0, (now - worker.time).total_seconds()) ) + self.save(worker, **sys_metrics, connected=True, time=now) if record: from .models import WorkerMonitor + WorkerMonitor.objects.create( worker=worker, interval=interval, @@ -175,7 +200,7 @@ def update_worker(self, record: bool = False, interval: int = None): worker_logger = WorkerMetricsLogger() -request_logger = var.RequestContextVar('_logger', cached=True, static=True) +request_logger = var.RequestContextVar("_logger", cached=True, static=True) class LogLevel: @@ -185,18 +210,19 @@ class LogLevel: ERROR = 3 -LOG_LEVELS = ['DEBUG', 'INFO', 'WARN', 'ERROR'] +LOG_LEVELS = ["DEBUG", "INFO", "WARN", "ERROR"] def level_log(f): lv = f.__name__.upper() if lv not in LOG_LEVELS: - raise ValueError(f'Invalid log level: {lv}') + raise ValueError(f"Invalid log level: {lv}") index = LOG_LEVELS.index(lv) @wraps(f) - def emit(self: 'Logger', brief: str, msg: str = None, **kwargs): + def emit(self: "Logger", brief: str, msg: str = None, **kwargs): return self.emit(brief, level=index, data=kwargs, msg=msg) + return emit @@ -218,8 +244,7 @@ def setup_locals(config: Operations, close_conn: bool = False): from .models import Resource, Worker, Supervisor from utilmeta import service - global _worker, _version, _supervisor, _instance, _server, \ - _endpoints_map, _openapi, _endpoints_patterns, _path_prefix, _databases, _caches + global _worker, _version, _supervisor, _instance, _server, _endpoints_map, _openapi, _endpoints_patterns, _path_prefix, _databases, _caches # node_id = config.node_id _supervisor = Supervisor.current().first() # reset supervisor @@ -231,43 +256,46 @@ def setup_locals(config: Operations, close_conn: bool = False): if not _server: _server = Resource.get_current_server() from .monitor import get_current_server + data = get_current_server() if not _server: from utilmeta.utils import get_mac_address + mac = get_mac_address() _server = Resource.objects.create( - type='server', + type="server", service=None, # server is a service-neutral resource node_id=node_id, ident=mac, data=data, - route=f'server/{mac}', + route=f"server/{mac}", ) else: if _server.data != data: _server.data = data - _server.save(update_fields=['data']) + _server.save(update_fields=["data"]) if not _instance: _instance = Resource.get_current_instance() from .schema import get_current_instance_data + data = get_current_instance_data() if not _instance: ident = config.address _instance = Resource.objects.create( - type='instance', + type="instance", service=service.name, node_id=node_id, ident=ident, - route=f'instance/{node_id}/{ident}' if node_id else f'instance/{ident}', + route=f"instance/{node_id}/{ident}" if node_id else f"instance/{ident}", server=_server, - data=data + data=data, ) else: if _instance.data != data: _instance.data = data - _instance.save(update_fields=['data']) + _instance.save(update_fields=["data"]) # if not _version: # if _instance: @@ -280,14 +308,13 @@ def setup_locals(config: Operations, close_conn: bool = False): if not _worker: import utilmeta + if not utilmeta._cmd_env: _worker = Worker.load() if not _endpoints_map: _endpoints = Resource.filter( - type='endpoint', - service=service.name, - 
deprecated=False + type="endpoint", service=service.name, deprecated=False ) if node_id: @@ -300,6 +327,7 @@ def setup_locals(config: Operations, close_conn: bool = False): _openapi = config.openapi from utilmeta.core.api.specs.openapi import get_operation_id from utilmeta.core.api.route import APIRoute + patterns = {} operation_ids = [] for path, path_item in _openapi.paths.items(): @@ -312,42 +340,47 @@ def setup_locals(config: Operations, close_conn: bool = False): operation = path_item.get(method) if not operation: continue - operation_id = operation.get('operationId') + operation_id = operation.get("operationId") if not operation_id: - operation_id = get_operation_id(method, path, excludes=operation_ids, attribute=True) + operation_id = get_operation_id( + method, path, excludes=operation_ids, attribute=True + ) operation_ids.append(operation_id) methods[method] = operation_id if methods: patterns[pattern] = methods except Exception as e: - warnings.warn(f'generate pattern operation Id at path {path} failed: {e}') + warnings.warn( + f"generate pattern operation Id at path {path} failed: {e}" + ) continue _endpoints_patterns = patterns if _openapi.servers: url = _openapi.servers[0].url from urllib.parse import urlparse - _path_prefix = urlparse(url).path.strip('/') + + _path_prefix = urlparse(url).path.strip("/") if not _databases: from utilmeta.core.orm import DatabaseConnections + db_config = DatabaseConnections.config() dbs = {} if db_config and db_config.databases: for alias, db in db_config.databases.items(): db_obj = Resource.filter( - type='database', - service=service.name, - ident=alias, - deprecated=False + type="database", service=service.name, ident=alias, deprecated=False ).first() if not db_obj: db_obj = Resource.objects.create( - type='database', + type="database", service=service.name, node_id=node_id, ident=alias, - route=f'database/{node_id}/{alias}' if node_id else f'database/{alias}', + route=f"database/{node_id}/{alias}" + if node_id + else f"database/{alias}", server=_server if db.local else None, ) dbs[alias] = db_obj @@ -355,6 +388,7 @@ def setup_locals(config: Operations, close_conn: bool = False): if not _caches: from utilmeta.core.cache import CacheConnections + cache_config = CacheConnections.config() caches = {} if cache_config and cache_config.caches: @@ -363,18 +397,17 @@ def setup_locals(config: Operations, close_conn: bool = False): # do not monitor memory cache for now continue cache_obj = Resource.filter( - type='cache', - service=service.name, - ident=alias, - deprecated=False + type="cache", service=service.name, ident=alias, deprecated=False ).first() if not cache_obj: cache_obj = Resource.objects.create( - type='cache', + type="cache", service=service.name, node_id=node_id, ident=alias, - route=f'cache/{node_id}/{alias}' if node_id else f'cache/{alias}', + route=f"cache/{node_id}/{alias}" + if node_id + else f"cache/{alias}", server=_server if cache.local else None, ) caches[alias] = cache_obj @@ -383,6 +416,7 @@ def setup_locals(config: Operations, close_conn: bool = False): if close_conn: # close connections from django.db import connections + # ops_conn = connections[config.db_alias] # if ops_conn: # ops_conn.close() @@ -406,16 +440,24 @@ def is_excluded(self, response: Response): request = response.request if request: if self.config.log.exclude_methods: - if request.adaptor.request_method.upper() in self.config.log.exclude_methods: + if ( + request.adaptor.request_method.upper() + in self.config.log.exclude_methods + ): return True if 
self.config.log.exclude_request_headers: - if any(h in self.config.log.exclude_request_headers for h in request.headers): + if any( + h in self.config.log.exclude_request_headers + for h in request.headers + ): return True if self.config.log.exclude_status: if response.status in self.config.log.exclude_status: return True if self.config.log.exclude_response_headers: - if any(h in self.config.log.exclude_response_headers for h in response.headers): + if any( + h in self.config.log.exclude_response_headers for h in response.headers + ): return True return False @@ -449,12 +491,7 @@ def process_response(self, response: Response): _responses_queue.append(response) if len(_responses_queue) >= self.config.max_backlog: - threading.Thread( - target=batch_save_logs, - kwargs=dict( - close=True - ) - ).start() + threading.Thread(target=batch_save_logs, kwargs=dict(close=True)).start() class Logger(Property): @@ -472,12 +509,10 @@ class Logger(Property): # STORE_RESULT_LEVEL = LogLevel.WARN # STORE_HEADERS_LEVEL = LogLevel.WARN - def __init__(self, - from_logger: 'Logger' = None, - span_data: dict = None - ): + def __init__(self, from_logger: "Logger" = None, span_data: dict = None): super().__init__() from utilmeta import service + self.service = service self.config = service.get_config(Operations) self.current_thread = threading.current_thread().ident @@ -509,11 +544,17 @@ def __init__(self, self._persist_level = self.config.log.persist_level self._persist_duration_limit = self.config.log.persist_duration_limit if self._store_data_level is None: - self._store_data_level = LogLevel.WARN if service.production else LogLevel.INFO + self._store_data_level = ( + LogLevel.WARN if service.production else LogLevel.INFO + ) if self._store_headers_level is None: - self._store_headers_level = LogLevel.WARN if service.production else LogLevel.INFO + self._store_headers_level = ( + LogLevel.WARN if service.production else LogLevel.INFO + ) if self._store_result_level is None: - self._store_result_level = LogLevel.WARN if service.production else LogLevel.INFO + self._store_result_level = ( + LogLevel.WARN if service.production else LogLevel.INFO + ) def relative_time(self, to=None): return max(int(((to or time.time()) - self.init_time) * 1000), 0) @@ -532,7 +573,12 @@ def events_only(self): @property def vacuum(self): - return not self._messages and not self._events and not self._exceptions and not self._span_logger + return ( + not self._messages + and not self._events + and not self._exceptions + and not self._span_logger + ) @property def level(self): @@ -564,14 +610,11 @@ def status_level(cls, status: int): def __call__(self, name: str, **kwargs): if self._span_logger: return self._span_logger(name, **kwargs) - assert name, f'Empty scope name' - self._span_data = dict( - name=name, - **kwargs - ) + assert name, f"Empty scope name" + self._span_data = dict(name=name, **kwargs) return self - def __enter__(self) -> 'Logger': + def __enter__(self) -> "Logger": if self._span_logger: return self._span_logger.__enter__() if not self._span_data: @@ -622,19 +665,24 @@ def span(self): def setup_request(self, request: Request): self._request = request if _supervisor: - supervisor_id = request.headers.get('x-utilmeta-node-id') or request.headers.get('x-node-id') - supervisor_hash = request.headers.get('X-utilmeta-supervisor-key-md5') - if supervisor_hash and supervisor_id == _supervisor.node_id: # noqa + supervisor_id = request.headers.get( + "x-utilmeta-node-id" + ) or request.headers.get("x-node-id") + supervisor_hash 
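# Sketch: spans are opened by calling a Logger instance and entering it, per
# the __call__/__enter__ pair above; `logger` and `run_query` are hypothetical:
with logger("db.query", table="user"):  # child span named "db.query"
    rows = run_query()                  # hypothetical timed work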
= request.headers.get("X-utilmeta-supervisor-key-md5") + if supervisor_hash and supervisor_id == _supervisor.node_id: # noqa import hashlib - if hashlib.md5(_supervisor.public_key) == supervisor_hash: # noqa + + if hashlib.md5(_supervisor.public_key) == supervisor_hash: # noqa self._supervised = True if self._supervised: - log_options = request.headers.get('x-utilmeta-log-options') + log_options = request.headers.get("x-utilmeta-log-options") if log_options: - options = [option.strip() for option in str(log_options).lower().split(',')] - if 'omit' in options: + options = [ + option.strip() for option in str(log_options).lower().split(",") + ] + if "omit" in options: self._omitted = True - if 'timing' in options or 'server-timing' in options: + if "timing" in options or "server-timing" in options: self._server_timing = True def omit(self, val: bool = True): @@ -647,11 +695,17 @@ def setup_response(self, response: Response): if self._supervised: if self._server_timing: duration = response.duration_ms or self.duration - ts = response.request.time.timestamp() if response.request else self.init_time + ts = ( + response.request.time.timestamp() + if response.request + else self.init_time + ) if duration: - response.set_header('Server-Timing', f'total;dur={duration};ts={ts}') + response.set_header( + "Server-Timing", f"total;dur={duration};ts={ts}" + ) - def generate_request_logs(self, context_type='service_log', context_id=None): + def generate_request_logs(self, context_type="service_log", context_id=None): if not self._client_responses: return [] @@ -659,38 +713,37 @@ def generate_request_logs(self, context_type='service_log', context_id=None): for resp in self._client_responses: log = self.generate_request_log( - resp, - context_type=context_type, - context_id=context_id + resp, context_type=context_type, context_id=context_id ) if log: objects.append(log) return objects - def generate_request_log(self, response: Response, - context_type='service_log', - context_id=None): + def generate_request_log( + self, response: Response, context_type="service_log", context_id=None + ): from .models import RequestLog - return RequestLog( - ) + return RequestLog() @classmethod def get_file_repr(cls, file): - return '' + return "" def parse_values(self, data): - return hide_secret_values(data, secret_names=self.config.secret_names, file_repr=self.get_file_repr) + return hide_secret_values( + data, secret_names=self.config.secret_names, file_repr=self.get_file_repr + ) @classmethod def get_endpoint_ident(cls, request: Request) -> Optional[str]: if not _endpoints_patterns: return None - path = str(request.path or '').strip('/') + path = str(request.path or "").strip("/") if _path_prefix: if not path.startswith(_path_prefix): return None - path = path[len(_path_prefix):].strip('/') + path = path[len(_path_prefix) :].strip("/") for pattern, methods in _endpoints_patterns.items(): if pattern.fullmatch(path): return methods.get(request.method) @@ -726,13 +779,13 @@ def generate_log(self, response: Response): try: data = self.parse_values(request.data) except Exception as e: # noqa: ignore - warnings.warn(f'load request data failed: {e}') + warnings.warn(f"load request data failed: {e}") if level >= self._store_result_level: try: result = self.parse_values(response.data) except Exception as e: # noqa: ignore - warnings.warn(f'load response data failed: {e}') + warnings.warn(f"load response data failed: {e}") try: public = request.ip_address.is_global @@ -750,11 +803,13 @@ def generate_log(self, response: 
Response): response_headers = {} if level >= self._store_headers_level: request_headers = self.parse_values(dict(request.headers)) - response_headers = self.parse_values(dict(response.prepare_headers(with_content_type=True))) + response_headers = self.parse_values( + dict(response.prepare_headers(with_content_type=True)) + ) operation_names = var.operation_names.getter(request) if operation_names: - endpoint_ident = '_'.join(operation_names) + endpoint_ident = "_".join(operation_names) else: # or find it by the generated openapi items (match method and path, find operationId) endpoint_ident = self.get_endpoint_ident(request) @@ -772,9 +827,9 @@ def generate_log(self, response: Response): service=self.service.name, instance=_instance, version=_version, - node_id=getattr(_supervisor, 'node_id', None), + node_id=getattr(_supervisor, "node_id", None), supervisor=_supervisor, - access_token_id=getattr(access_token, 'id', None), + access_token_id=getattr(access_token, "id", None), level=level_str, volatile=volatile, time=request.time, @@ -792,7 +847,7 @@ def generate_log(self, response: Response): result=result, user_id=user_id, ip=str(request.ip_address), - user_agent=parse_user_agents(request.headers.get('user-agent')), + user_agent=parse_user_agents(request.headers.get("user-agent")), status=status, request_type=request.content_type, response_type=response.content_type, @@ -808,7 +863,7 @@ def generate_log(self, response: Response): ) def get_trace(self): - self._events.sort(key=lambda v: v.get('init', 0)) + self._events.sort(key=lambda v: v.get("init", 0)) return normalize(self._events, _json=True) def exit(self): @@ -830,7 +885,9 @@ def exit(self): else: _logger.set(None) - def emit(self, brief: Union[str, Error], level: int, data: dict = None, msg: str = None): + def emit( + self, brief: Union[str, Error], level: int, data: dict = None, msg: str = None + ): if self._span_logger: return self._span_logger.emit(brief, level, data, msg=msg) @@ -860,13 +917,15 @@ def emit(self, brief: Union[str, Error], level: int, data: dict = None, msg: str self._exceptions.append(exception) name = LOG_LEVELS[level] - self._events.append(dict( - name=name, - init=self.relative_time(ts), - type=f'log.{name.lower()}', - msg=self._push_message(brief, msg=msg), - data=data, - )) + self._events.append( + dict( + name=name, + init=self.relative_time(ts), + type=f"log.{name.lower()}", + msg=self._push_message(brief, msg=msg), + data=data, + ) + ) def commit_error(self, e: Error): if e.exception in self._exceptions: @@ -906,11 +965,11 @@ def error(self, brief: Union[str, Exception], msg: str = None, **kwargs): @property def message(self) -> str: - return '\n'.join(self._messages) + return "\n".join(self._messages) @property def brief_message(self) -> str: - return '; '.join(self._briefs) + return "; ".join(self._briefs) def batch_save_logs(close: bool = False): @@ -929,12 +988,16 @@ def batch_save_logs(close: bool = False): if not _server: # not setup yet from .config import Operations + setup_locals(Operations.config()) if _supervisor: # update supervisor (connect / disconnect) from .models import Supervisor - supervisor = Supervisor.objects.filter(pk=getattr(_supervisor, 'pk', None)).first() + + supervisor = Supervisor.objects.filter( + pk=getattr(_supervisor, "pk", None) + ).first() if not supervisor: # check _supervisor before save logs _supervisor = None @@ -948,9 +1011,7 @@ def batch_save_logs(close: bool = False): if not logger: continue - service_log = logger.generate_log( - response - ) + service_log = 
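# Sketch: each emit() above appends an event of this shape to the trace (keys
# per the dict literal in that hunk; values are illustrative):
event = dict(
    name="WARN",                   # LOG_LEVELS[level]
    init=1284,                     # ms since logger init (relative_time)
    type="log.warn",
    msg="cache timeout; retrying",
    data={"key": "user:42"},
)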
logger.generate_log(response) if not service_log: continue @@ -980,6 +1041,7 @@ def batch_save_logs(close: bool = False): if close: from django.db import connections + connections.close_all() return diff --git a/utilmeta/ops/migrations/0001_initial.py b/utilmeta/ops/migrations/0001_initial.py index 81061be..43b1e73 100644 --- a/utilmeta/ops/migrations/0001_initial.py +++ b/utilmeta/ops/migrations/0001_initial.py @@ -26,8 +26,18 @@ class Migration(migrations.Migration): ), ), ("relieved_time", models.DateTimeField(default=None, null=True)), - ("trigger_times", models.JSONField(default=list, encoder=utype.utils.encode.JSONEncoder)), - ("trigger_values", models.JSONField(default=list, encoder=utype.utils.encode.JSONEncoder)), + ( + "trigger_times", + models.JSONField( + default=list, encoder=utype.utils.encode.JSONEncoder + ), + ), + ( + "trigger_values", + models.JSONField( + default=list, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "time", models.DateTimeField( @@ -41,7 +51,12 @@ class Migration(migrations.Migration): ), ), ("count", models.PositiveBigIntegerField(default=1)), - ("data", models.JSONField(default=None, encoder=utype.utils.encode.JSONEncoder, null=True)), + ( + "data", + models.JSONField( + default=None, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), ], options={ "db_table": "utilmeta_alert_log", @@ -71,7 +86,12 @@ class Migration(migrations.Migration): ("route", models.CharField(max_length=300)), ("remote_id", models.CharField(default=None, max_length=40, null=True)), ("created_time", models.DateTimeField(auto_now_add=True)), - ("data", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "data", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ("deleted", models.BooleanField(default=False)), ("deprecated", models.BooleanField(default=False)), ( @@ -104,7 +124,12 @@ class Migration(migrations.Migration): ("service", models.CharField(max_length=100)), ("node_id", models.CharField(default=None, max_length=40, null=True)), ("ident", models.CharField(default=None, max_length=20, null=True)), - ("backup_urls", models.JSONField(default=list, encoder=utype.utils.encode.JSONEncoder)), + ( + "backup_urls", + models.JSONField( + default=list, encoder=utype.utils.encode.JSONEncoder + ), + ), ("base_url", models.URLField()), ("local", models.BooleanField(default=False)), ("public_key", models.TextField(default=None, null=True)), @@ -123,13 +148,38 @@ class Migration(migrations.Migration): models.PositiveIntegerField(default=None, null=True), ), ("latency", models.PositiveIntegerField(default=None, null=True)), - ("settings", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "settings", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ("connected", models.BooleanField(default=False)), ("disabled", models.BooleanField(default=False)), - ("alert_settings", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), - ("task_settings", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), - ("aggregate_settings", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), - ("data", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "alert_settings", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), + ( + "task_settings", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), + ( + "aggregate_settings", + 
models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), + ( + "data", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "resources_etag", models.CharField(default=None, max_length=200, null=True), @@ -171,7 +221,12 @@ class Migration(migrations.Migration): ("open_files", models.PositiveBigIntegerField(default=None, null=True)), ("active_net_connections", models.PositiveIntegerField(default=0)), ("total_net_connections", models.PositiveIntegerField(default=0)), - ("net_connections_info", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "net_connections_info", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ("in_traffic", models.PositiveBigIntegerField(default=0, null=True)), ("out_traffic", models.PositiveBigIntegerField(default=0, null=True)), ("outbound_requests", models.PositiveIntegerField(default=0)), @@ -205,7 +260,12 @@ class Migration(migrations.Migration): models.DecimalField(decimal_places=2, default=0.0, max_digits=10), ), ("pid", models.PositiveIntegerField()), - ("memory_info", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "memory_info", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ("threads", models.PositiveIntegerField(default=0)), ( "start_time", @@ -223,8 +283,18 @@ class Migration(migrations.Migration): ("status", models.CharField(default=None, max_length=100, null=True)), ("user", models.CharField(default=None, max_length=100, null=True)), ("retire_time", models.DateTimeField(default=None, null=True)), - ("reload_params", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), - ("data", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "reload_params", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), + ( + "data", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "instance", models.ForeignKey( @@ -289,7 +359,12 @@ class Migration(migrations.Migration): ("open_files", models.PositiveBigIntegerField(default=None, null=True)), ("active_net_connections", models.PositiveIntegerField(default=0)), ("total_net_connections", models.PositiveIntegerField(default=0)), - ("net_connections_info", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "net_connections_info", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ("in_traffic", models.PositiveBigIntegerField(default=0, null=True)), ("out_traffic", models.PositiveBigIntegerField(default=0, null=True)), ("outbound_requests", models.PositiveIntegerField(default=0)), @@ -329,9 +404,19 @@ class Migration(migrations.Migration): ), ), ("interval", models.PositiveIntegerField(default=None, null=True)), - ("memory_info", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "memory_info", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ("threads", models.PositiveIntegerField(default=0)), - ("metrics", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "metrics", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "worker", models.ForeignKey( @@ -368,7 +453,12 @@ class Migration(migrations.Migration): ("time", models.DateTimeField(auto_now_add=True)), ("down_time", models.PositiveBigIntegerField(default=None, null=True)), ("success", 
models.BooleanField(default=None, null=True)), - ("restart_data", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "restart_data", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ("version", models.CharField(max_length=100)), ( "remote_id", @@ -400,14 +490,44 @@ class Migration(migrations.Migration): "response_type", models.CharField(default=None, max_length=200, null=True), ), - ("request_headers", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder, null=True)), - ("response_headers", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder, null=True)), - ("user_agent", models.JSONField(default=None, encoder=utype.utils.encode.JSONEncoder, null=True)), + ( + "request_headers", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), + ( + "response_headers", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), + ( + "user_agent", + models.JSONField( + default=None, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), ("status", models.PositiveSmallIntegerField(default=200, null=True)), ("length", models.PositiveBigIntegerField(default=0, null=True)), - ("query", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder, null=True)), - ("data", models.JSONField(default=None, encoder=utype.utils.encode.JSONEncoder, null=True)), - ("result", models.JSONField(default=None, encoder=utype.utils.encode.JSONEncoder, null=True)), + ( + "query", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), + ( + "data", + models.JSONField( + default=None, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), + ( + "result", + models.JSONField( + default=None, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), ("full_url", models.URLField(default=None, null=True)), ("path", models.URLField(default=None, null=True)), ("in_traffic", models.PositiveBigIntegerField(default=0)), @@ -432,8 +552,18 @@ class Migration(migrations.Migration): ), ), ("ip", models.GenericIPAddressField()), - ("trace", models.JSONField(default=list, encoder=utype.utils.encode.JSONEncoder)), - ("messages", models.JSONField(default=list, encoder=utype.utils.encode.JSONEncoder)), + ( + "trace", + models.JSONField( + default=list, encoder=utype.utils.encode.JSONEncoder + ), + ), + ( + "messages", + models.JSONField( + default=list, encoder=utype.utils.encode.JSONEncoder + ), + ), ("volatile", models.BooleanField(default=True)), ( "alert", @@ -522,7 +652,12 @@ class Migration(migrations.Migration): ("open_files", models.PositiveBigIntegerField(default=None, null=True)), ("active_net_connections", models.PositiveIntegerField(default=0)), ("total_net_connections", models.PositiveIntegerField(default=0)), - ("net_connections_info", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "net_connections_info", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "time", models.DateTimeField( @@ -549,7 +684,12 @@ class Migration(migrations.Migration): decimal_places=2, default=None, max_digits=8, null=True ), ), - ("metrics", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "metrics", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "server", models.ForeignKey( @@ -577,14 +717,44 @@ class Migration(migrations.Migration): "response_type", models.CharField(default=None, 
max_length=200, null=True), ), - ("request_headers", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder, null=True)), - ("response_headers", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder, null=True)), - ("user_agent", models.JSONField(default=None, encoder=utype.utils.encode.JSONEncoder, null=True)), + ( + "request_headers", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), + ( + "response_headers", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), + ( + "user_agent", + models.JSONField( + default=None, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), ("status", models.PositiveSmallIntegerField(default=200, null=True)), ("length", models.PositiveBigIntegerField(default=0, null=True)), - ("query", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder, null=True)), - ("data", models.JSONField(default=None, encoder=utype.utils.encode.JSONEncoder, null=True)), - ("result", models.JSONField(default=None, encoder=utype.utils.encode.JSONEncoder, null=True)), + ( + "query", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), + ( + "data", + models.JSONField( + default=None, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), + ( + "result", + models.JSONField( + default=None, encoder=utype.utils.encode.JSONEncoder, null=True + ), + ), ("full_url", models.URLField(default=None, null=True)), ("path", models.URLField(default=None, null=True)), ("in_traffic", models.PositiveBigIntegerField(default=0)), @@ -656,7 +826,12 @@ class Migration(migrations.Migration): ("duration", models.PositiveBigIntegerField(default=None, null=True)), ("message", models.TextField(default="")), ("operation", models.CharField(default=None, max_length=32, null=True)), - ("tables", models.JSONField(default=list, encoder=utype.utils.encode.JSONEncoder)), + ( + "tables", + models.JSONField( + default=list, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "context_type", models.CharField(default=None, max_length=40, null=True), @@ -730,7 +905,12 @@ class Migration(migrations.Migration): ("open_files", models.PositiveBigIntegerField(default=None, null=True)), ("active_net_connections", models.PositiveIntegerField(default=0)), ("total_net_connections", models.PositiveIntegerField(default=0)), - ("net_connections_info", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "net_connections_info", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ("in_traffic", models.PositiveBigIntegerField(default=0, null=True)), ("out_traffic", models.PositiveBigIntegerField(default=0, null=True)), ("outbound_requests", models.PositiveIntegerField(default=0)), @@ -784,7 +964,12 @@ class Migration(migrations.Migration): decimal_places=2, default=None, max_digits=10, null=True ), ), - ("metrics", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "metrics", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "instance", models.ForeignKey( @@ -825,7 +1010,12 @@ class Migration(migrations.Migration): ("current_connections", models.PositiveBigIntegerField(default=0)), ("server_connections", models.PositiveBigIntegerField(default=0)), ("new_transactions", models.PositiveBigIntegerField(default=0)), - ("metrics", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "metrics", + models.JSONField( 
+ default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ("queries_num", models.PositiveBigIntegerField(default=0)), ( "qps", @@ -837,7 +1027,12 @@ class Migration(migrations.Migration): "query_avg_time", models.DecimalField(decimal_places=2, default=0, max_digits=10), ), - ("operations", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "operations", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "database", models.ForeignKey( @@ -871,13 +1066,23 @@ class Migration(migrations.Migration): ("pid", models.PositiveIntegerField(default=None, null=True)), ("query", models.TextField(default="")), ("operation", models.CharField(default=None, max_length=32, null=True)), - ("tables", models.JSONField(default=list, encoder=utype.utils.encode.JSONEncoder)), + ( + "tables", + models.JSONField( + default=list, encoder=utype.utils.encode.JSONEncoder + ), + ), ("backend_start", models.DateTimeField(default=None, null=True)), ("transaction_start", models.DateTimeField(default=None, null=True)), ("wait_event", models.TextField(default=None, null=True)), ("query_start", models.DateTimeField(default=None, null=True)), ("state_change", models.DateTimeField(default=None, null=True)), - ("data", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "data", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "database", models.ForeignKey( @@ -937,7 +1142,12 @@ class Migration(migrations.Migration): decimal_places=2, default=None, max_digits=10, null=True ), ), - ("metrics", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "metrics", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ( "cache", models.ForeignKey( @@ -1065,7 +1275,12 @@ class Migration(migrations.Migration): ("last_activity", models.DateTimeField(default=None, null=True)), ("used_times", models.PositiveIntegerField(default=0)), ("ip", models.GenericIPAddressField(default=None, null=True)), - ("scope", models.JSONField(default=list, encoder=utype.utils.encode.JSONEncoder)), + ( + "scope", + models.JSONField( + default=list, encoder=utype.utils.encode.JSONEncoder + ), + ), ("revoked", models.BooleanField(default=False)), ( "issuer", diff --git a/utilmeta/ops/migrations/0003_aggregationlog.py b/utilmeta/ops/migrations/0003_aggregationlog.py index 4b6f118..1d4fd36 100644 --- a/utilmeta/ops/migrations/0003_aggregationlog.py +++ b/utilmeta/ops/migrations/0003_aggregationlog.py @@ -31,7 +31,12 @@ class Migration(migrations.Migration): db_index=True, default=None, max_length=100, null=True ), ), - ("data", models.JSONField(default=dict, encoder=utype.utils.encode.JSONEncoder)), + ( + "data", + models.JSONField( + default=dict, encoder=utype.utils.encode.JSONEncoder + ), + ), ("layer", models.PositiveSmallIntegerField(default=0)), ("from_time", models.DateTimeField()), ("to_time", models.DateTimeField()), diff --git a/utilmeta/ops/models.py b/utilmeta/ops/models.py index 9a4168e..b1b725b 100644 --- a/utilmeta/ops/models.py +++ b/utilmeta/ops/models.py @@ -37,7 +37,9 @@ class Supervisor(models.Model): # will require an update url = models.URLField(default=None, null=True) - operation_timeout = models.DecimalField(max_digits=8, decimal_places=3, default=None, null=True) + operation_timeout = models.DecimalField( + max_digits=8, decimal_places=3, default=None, null=True + ) # open_scopes = models.JSONField(default=list) # disabled_scopes = 
models.JSONField(default=list) @@ -70,30 +72,27 @@ class Supervisor(models.Model): # etag is generated from supervisor class Meta: - db_table = 'utilmeta_supervisor' + db_table = "utilmeta_supervisor" @classmethod def filter(cls, *args, **kwargs) -> models.QuerySet: - q = models.Q( - local=True - ) | models.Q( + q = models.Q(local=True) | models.Q( public_key__isnull=False, ) - kwargs.update( - node_id__isnull=False, - disabled=False - ) + kwargs.update(node_id__isnull=False, disabled=False) return cls.objects.filter(*args, q, **kwargs) @classmethod def current(cls) -> models.QuerySet: from .config import Operations + config = Operations.config() kwargs = {} if config and config.node_id: kwargs.update(node_id=config.node_id) else: from utilmeta import service + kwargs.update(service=service.name) return cls.filter(**kwargs) @@ -101,7 +100,9 @@ def current(cls) -> models.QuerySet: class AccessToken(models.Model): objects = models.Manager() - issuer = models.ForeignKey(Supervisor, related_name='access_tokens', on_delete=models.CASCADE) + issuer = models.ForeignKey( + Supervisor, related_name="access_tokens", on_delete=models.CASCADE + ) token_id = models.CharField(max_length=200, unique=True) issued_at = models.DateTimeField(default=None, null=True) subject = models.CharField(max_length=500, default=None, null=True) @@ -123,7 +124,7 @@ class AccessToken(models.Model): # revoke tokens of a subject if its permission is changed or revoked class Meta: - db_table = 'utilmeta_access_token' + db_table = "utilmeta_access_token" class Resource(models.Model): @@ -153,9 +154,11 @@ class Resource(models.Model): # supervisor = models.ForeignKey(Supervisor, related_name='resources', on_delete=models.CASCADE) server = models.ForeignKey( - 'self', related_name='resources', + "self", + related_name="resources", on_delete=models.SET_NULL, - default=None, null=True + default=None, + null=True, ) server_id: Optional[int] data: dict = models.JSONField(default=dict, encoder=JSONEncoder) @@ -169,7 +172,7 @@ class Resource(models.Model): # server / instance / database / cached -> disconnected class Meta: - db_table = 'utilmeta_resource' + db_table = "utilmeta_resource" @classmethod def filter(cls, *args, **kwargs): @@ -177,33 +180,35 @@ def filter(cls, *args, **kwargs): return cls.objects.filter(*args, **kwargs) @classmethod - def get_current_server(cls) -> Optional['Resource']: + def get_current_server(cls) -> Optional["Resource"]: # from utilmeta.utils import get_server_ip + from utilmeta.core.orm.backends.django.expressions import OrderBy from utilmeta.utils import get_mac_address from .config import Operations + config = Operations.config() qs = cls.filter( - type='server', + type="server", ident=get_mac_address(), ) if config.node_id: - qs = qs.filter( - models.Q(node_id=config.node_id) | models.Q(node_id=None) - ) + qs = qs.filter(models.Q(node_id=config.node_id) | models.Q(node_id=None)) return qs.order_by( - models.OrderBy(models.F('remote_id'), nulls_last=True), - models.OrderBy(models.F('node_id'), nulls_last=True), - 'deprecated', '-created_time' + OrderBy(models.F("remote_id"), nulls_last=True), + OrderBy(models.F("node_id"), nulls_last=True), + "deprecated", + "-created_time", ).first() # first go with the remote_id @classmethod - def get_current_instance(cls) -> Optional['Resource']: + def get_current_instance(cls) -> Optional["Resource"]: from utilmeta import service from .config import Operations + config = service.get_config(Operations) return cls.filter( - type='instance',
service=service.name, ident=config.address # server=cls.get_current_server(), @@ -230,15 +235,17 @@ class Meta: class DatabaseConnection(models.Model): objects = models.Manager() - database: Resource = models.ForeignKey(Resource, related_name='database_connections', on_delete=models.CASCADE) + database: Resource = models.ForeignKey( + Resource, related_name="database_connections", on_delete=models.CASCADE + ) # remote_id = CharField(max_length=100) status = models.CharField(max_length=40) active = models.BooleanField(default=False) - client_addr = models.GenericIPAddressField() # mysql use ADDR:PORT as HOST + client_addr = models.GenericIPAddressField() # mysql uses ADDR:PORT as HOST client_port = models.PositiveIntegerField() pid = models.PositiveIntegerField(default=None, null=True) - query = models.TextField(default='') + query = models.TextField(default="") operation = models.CharField(max_length=32, default=None, null=True) tables = models.JSONField(default=list, encoder=JSONEncoder) @@ -251,7 +258,7 @@ class DatabaseConnection(models.Model): data = models.JSONField(default=dict, encoder=JSONEncoder) class Meta: - db_table = 'utilmeta_database_connection' + db_table = "utilmeta_database_connection" # unique_together = ('database', 'remote_id') @@ -259,13 +266,20 @@ class ServiceMetrics(models.Model): """ request metrics that can simply be calculated in form of incr and divide """ - in_traffic = models.PositiveBigIntegerField(default=0, null=True) # in bytes - out_traffic = models.PositiveBigIntegerField(default=0, null=True) # in bytes + + in_traffic = models.PositiveBigIntegerField(default=0, null=True) # in bytes + out_traffic = models.PositiveBigIntegerField(default=0, null=True) # in bytes # avg process time of requests made from this service - outbound_requests = models.PositiveIntegerField(default=0) # total request log count + outbound_requests = models.PositiveIntegerField( + default=0 + ) # total request log count outbound_rps = models.DecimalField(max_digits=10, decimal_places=2, default=0.00) - outbound_timeouts = models.PositiveBigIntegerField(default=0) # total timeout outbound_requests - outbound_errors = models.PositiveBigIntegerField(default=0) # total error outbound_requests + outbound_timeouts = models.PositiveBigIntegerField( + default=0 + ) # total timeout outbound_requests + outbound_errors = models.PositiveBigIntegerField( + default=0 + ) # total error outbound_requests outbound_avg_time = models.DecimalField(max_digits=10, decimal_places=2, default=0) queries_num = models.PositiveBigIntegerField(default=0) @@ -286,23 +300,37 @@ class Meta: class Worker(SystemMetrics, ServiceMetrics): - server = models.ForeignKey(Resource, related_name='server_workers', on_delete=models.CASCADE) - instance = models.ForeignKey(Resource, related_name='instance_workers', on_delete=models.CASCADE) + server = models.ForeignKey( + Resource, related_name="server_workers", on_delete=models.CASCADE + ) + instance = models.ForeignKey( + Resource, related_name="instance_workers", on_delete=models.CASCADE + ) pid: int = models.PositiveIntegerField() memory_info = models.JSONField(default=dict, encoder=JSONEncoder) threads = models.PositiveIntegerField(default=0) start_time: datetime = models.DateTimeField(default=time_now) # utility = ForeignKey(ServiceUtility, related_name='workers', on_delete=SET_NULL, null=True, default=None) - master = models.ForeignKey('self', related_name='workers', on_delete=models.CASCADE, null=True, default=None) + master = models.ForeignKey( + "self", +
related_name="workers", + on_delete=models.CASCADE, + null=True, + default=None, + ) connected = models.BooleanField(default=True) # type = ChoiceField(WorkerType.gen(), retrieve_key=False, store_key=False, default=WorkerType.common) - time: datetime = models.DateTimeField(default=time_now) # latest metrics update time + time: datetime = models.DateTimeField( + default=time_now + ) # latest metrics update time status = models.CharField(max_length=100, default=None, null=True) user = models.CharField(max_length=100, default=None, null=True) - retire_time = models.DateTimeField(default=None, null=True) # only work for task worker for now + retire_time = models.DateTimeField( + default=None, null=True + ) # only work for task worker for now reload_params = models.JSONField(default=dict, encoder=JSONEncoder) # worker reload_on_rss (for uwsgi) # task.max_worker_memory (for task) @@ -310,8 +338,8 @@ class Worker(SystemMetrics, ServiceMetrics): data = models.JSONField(default=dict, encoder=JSONEncoder) class Meta: - db_table = 'utilmeta_worker' - unique_together = ('server', 'pid') + db_table = "utilmeta_worker" + unique_together = ("server", "pid") @classmethod def get(cls, pid: int): @@ -321,6 +349,7 @@ def get(cls, pid: int): def get_sys_metrics(self): import psutil + try: process = psutil.Process(self.pid) except psutil.Error: @@ -332,25 +361,30 @@ def get_sys_metrics(self): except psutil.Error: open_files = 0 try: - net_connections = getattr(process, 'net_connections', getattr(process, 'connections'))() + net_connections = getattr( + process, "net_connections", getattr(process, "connections") + )() except (psutil.Error, AttributeError): net_connections = [] return dict( - used_memory=getattr(mem_info, 'uss', getattr(mem_info, 'rss', 0)) or 0, - memory_info={f: getattr(mem_info, f) for f in getattr(mem_info, '_fields')}, + used_memory=getattr(mem_info, "uss", getattr(mem_info, "rss", 0)) or 0, + memory_info={f: getattr(mem_info, f) for f in getattr(mem_info, "_fields")}, total_net_connections=len(net_connections), - active_net_connections=len([c for c in net_connections if c.status != 'CLOSE_WAIT']), + active_net_connections=len( + [c for c in net_connections if c.status != "CLOSE_WAIT"] + ), file_descriptors=process.num_fds() if psutil.POSIX else None, cpu_percent=process.cpu_percent(interval=1), - memory_percent=round(process.memory_percent('uss'), 2), + memory_percent=round(process.memory_percent("uss"), 2), open_files=open_files, threads=process.num_threads(), ) @classmethod - def load(cls, pid: int = None, **kwargs) -> Optional['Worker']: + def load(cls, pid: int = None, **kwargs) -> Optional["Worker"]: import psutil import os + pid = pid or os.getpid() try: proc = psutil.Process(pid) @@ -375,13 +409,11 @@ def load(cls, pid: int = None, **kwargs) -> Optional['Worker']: data.update(kwargs) master = cls.objects.filter( - pid=parent_pid, - server=server, - instance=instance + pid=parent_pid, server=server, instance=instance ).first() if master: - data.update(master_id=master.pk) # do not user master= + data.update(master_id=master.pk) # do not user master= # prevent transaction worker_qs = cls.objects.filter( @@ -391,21 +423,14 @@ def load(cls, pid: int = None, **kwargs) -> Optional['Worker']: if worker_qs.exists(): # IMPORTANT: update time as soon as Worker.load is triggered # to provide realtime metrics for representative - data.update( - time=time_now(), - connected=True - ) + data.update(time=time_now(), connected=True) worker_qs.update(**data) return worker_qs.first() from 
django.db.utils import IntegrityError + try: - return cls.objects.create( - server=server, - instance=instance, - pid=pid, - **data - ) + return cls.objects.create(server=server, instance=instance, pid=pid, **data) except IntegrityError: return worker_qs.first() @@ -415,38 +440,48 @@ class ServerMonitor(SystemMetrics): # task_settings = ForeignKey(TaskSettings, on_delete=SET_NULL, default=None, null=True) layer = models.PositiveSmallIntegerField(default=0) interval = models.PositiveIntegerField(default=None, null=True) # in seconds - server = models.ForeignKey(Resource, related_name='server_metrics', on_delete=models.CASCADE) + server = models.ForeignKey( + Resource, related_name="server_metrics", on_delete=models.CASCADE + ) # version = ForeignKey(VersionLog, on_delete=SET_NULL, null=True, default=None) - load_avg_1 = models.DecimalField(max_digits=8, decimal_places=2, default=None, null=True) - load_avg_5 = models.DecimalField(max_digits=8, decimal_places=2, default=None, null=True) - load_avg_15 = models.DecimalField(max_digits=8, decimal_places=2, default=None, null=True) + load_avg_1 = models.DecimalField( + max_digits=8, decimal_places=2, default=None, null=True + ) + load_avg_5 = models.DecimalField( + max_digits=8, decimal_places=2, default=None, null=True + ) + load_avg_15 = models.DecimalField( + max_digits=8, decimal_places=2, default=None, null=True + ) # alert = ForeignKey(AlertLog, related_name='source_metrics', on_delete=SET_NULL, null=True, default=None) metrics = models.JSONField(default=dict, encoder=JSONEncoder) class Meta: - db_table = 'utilmeta_server_monitor' - ordering = ('time',) + db_table = "utilmeta_server_monitor" + ordering = ("time",) @classmethod - def current(cls) -> Optional['ServerMonitor']: - return cls.objects.last() # already order by time + def current(cls) -> Optional["ServerMonitor"]: + return cls.objects.last() # already ordered by time class WorkerMonitor(SystemMetrics, ServiceMetrics): time = models.DateTimeField(default=time_now) interval = models.PositiveIntegerField(default=None, null=True) # in seconds - worker = models.ForeignKey(Worker, related_name='worker_metrics', on_delete=models.CASCADE) + worker = models.ForeignKey( + Worker, related_name="worker_metrics", on_delete=models.CASCADE + ) memory_info = models.JSONField(default=dict, encoder=JSONEncoder) threads = models.PositiveIntegerField(default=0) metrics = models.JSONField(default=dict, encoder=JSONEncoder) # extra metrics class Meta: - db_table = 'utilmeta_worker_monitor' - ordering = ('time',) + db_table = "utilmeta_worker_monitor" + ordering = ("time",) @classmethod - def current(cls) -> Optional['WorkerMonitor']: - return cls.objects.last() # already ordered by time class InstanceMonitor(SystemMetrics, ServiceMetrics): @@ -454,20 +489,24 @@ class InstanceMonitor(SystemMetrics, ServiceMetrics): layer = models.PositiveSmallIntegerField(default=0) interval = models.PositiveIntegerField(default=None, null=True) # in seconds - instance = models.ForeignKey(Resource, related_name='instance_metrics', on_delete=models.CASCADE) + instance = models.ForeignKey( + Resource, related_name="instance_metrics", on_delete=models.CASCADE + ) threads = models.PositiveIntegerField(default=0) current_workers = models.PositiveIntegerField(default=0) avg_worker_lifetime = models.PositiveBigIntegerField(default=None, null=True) new_spawned_workers = models.PositiveBigIntegerField(default=0) - avg_workers =
models.DecimalField(default=None, null=True, max_digits=10, decimal_places=2) + avg_workers = models.DecimalField( + default=None, null=True, max_digits=10, decimal_places=2 + ) - metrics = models.JSONField(default=dict, encoder=JSONEncoder) # extra metrics + metrics = models.JSONField(default=dict, encoder=JSONEncoder) # extra metrics class Meta: - db_table = 'utilmeta_instance_monitor' - ordering = ('time',) + db_table = "utilmeta_instance_monitor" + ordering = ("time",) class DatabaseMonitor(models.Model): @@ -477,7 +516,9 @@ class DatabaseMonitor(models.Model): layer = models.PositiveSmallIntegerField(default=0) interval = models.PositiveIntegerField(default=None, null=True) # in seconds - database = models.ForeignKey(Resource, on_delete=models.CASCADE, related_name='database_metrics') + database = models.ForeignKey( + Resource, on_delete=models.CASCADE, related_name="database_metrics" + ) used_space = models.PositiveBigIntegerField(default=0) # used disk space server_used_space = models.PositiveBigIntegerField(default=0) # used disk space @@ -492,11 +533,13 @@ class DatabaseMonitor(models.Model): queries_num = models.PositiveBigIntegerField(default=0) qps = models.DecimalField(max_digits=10, decimal_places=2, default=None, null=True) query_avg_time = models.DecimalField(max_digits=10, decimal_places=2, default=0) - operations = models.JSONField(default=dict, encoder=JSONEncoder) # {'SELECT': 100, 'UPDATE': 21, ...} + operations = models.JSONField( + default=dict, encoder=JSONEncoder + ) # {'SELECT': 100, 'UPDATE': 21, ...} class Meta: - db_table = 'utilmeta_database_monitor' - ordering = ('time',) + db_table = "utilmeta_database_monitor" + ordering = ("time",) class CacheMonitor(models.Model): @@ -506,10 +549,16 @@ class CacheMonitor(models.Model): layer = models.PositiveSmallIntegerField(default=0) interval = models.PositiveIntegerField(default=None, null=True) # in seconds - cache = models.ForeignKey(Resource, on_delete=models.CASCADE, related_name='cache_metrics') + cache = models.ForeignKey( + Resource, on_delete=models.CASCADE, related_name="cache_metrics" + ) - cpu_percent = models.DecimalField(max_digits=6, decimal_places=2, default=None, null=True) - memory_percent = models.DecimalField(max_digits=6, decimal_places=2, default=None, null=True) + cpu_percent = models.DecimalField( + max_digits=6, decimal_places=2, default=None, null=True + ) + memory_percent = models.DecimalField( + max_digits=6, decimal_places=2, default=None, null=True + ) used_memory = models.PositiveBigIntegerField(default=0) file_descriptors = models.PositiveIntegerField(default=None, null=True) open_files = models.PositiveBigIntegerField(default=None, null=True) @@ -521,14 +570,15 @@ class CacheMonitor(models.Model): metrics = models.JSONField(default=dict, encoder=JSONEncoder) class Meta: - db_table = 'utilmeta_cache_monitor' - ordering = ('time',) + db_table = "utilmeta_cache_monitor" + ordering = ("time",) class WebMixin(models.Model): """ - Log data using http/https schemes + Log data using http/https schemes """ + scheme = models.CharField(default=None, null=True, max_length=20) method = models.CharField(default=None, null=True, max_length=20) # replace the unit property @@ -561,10 +611,12 @@ class VersionLog(models.Model): service = models.CharField(max_length=100) node_id = models.CharField(max_length=100, default=None, null=True, db_index=True) - instance: Resource = models.ForeignKey(Resource, on_delete=models.CASCADE, related_name='restart_records') + instance: Resource = models.ForeignKey( + 
Resource, on_delete=models.CASCADE, related_name="restart_records" + ) time: datetime = models.DateTimeField(auto_now_add=True) # finish_time: datetime = models.DateTimeField(default=None, null=True) - down_time = models.PositiveBigIntegerField(default=None, null=True) # ms + down_time = models.PositiveBigIntegerField(default=None, null=True) # ms # down time, None means no down time success = models.BooleanField(default=None, null=True) @@ -586,7 +638,7 @@ class VersionLog(models.Model): remote_id = models.CharField(max_length=100, default=None, null=True) class Meta: - db_table = 'utilmeta_version_log' + db_table = "utilmeta_version_log" # @property # def message(self): @@ -613,7 +665,7 @@ class AlertType(models.Model): # for downgrade types, configurable subcategory = models.CharField(max_length=200) - name = models.CharField(max_length=100) # settings name or custom name + name = models.CharField(max_length=100) # settings name or custom name target = models.TextField() # eg # type.category: resource_saturated @@ -626,19 +678,24 @@ class AlertType(models.Model): ident = models.CharField(max_length=500) - compress_window: int = models.PositiveBigIntegerField(null=True, default=None) # seconds + compress_window: int = models.PositiveBigIntegerField( + null=True, default=None + ) # seconds min_times: int = models.PositiveIntegerField(default=1) resource = models.ForeignKey( - Resource, on_delete=models.SET_NULL, - null=True, default=None, related_name='alert_types', + Resource, + on_delete=models.SET_NULL, + null=True, + default=None, + related_name="alert_types", ) created_time = models.DateTimeField(auto_now_add=True) class Meta: - db_table = 'utilmeta_alert_type' - unique_together = ('service', 'ident') + db_table = "utilmeta_alert_type" + unique_together = ("service", "ident") # @classmethod # def get(cls, ident: str) -> Optional['AlertSettings']: @@ -649,18 +706,29 @@ class Meta: class AlertLog(models.Model): objects = models.Manager() - type: AlertType = models.ForeignKey(AlertType, on_delete=models.CASCADE, related_name='alert_logs') + type: AlertType = models.ForeignKey( + AlertType, on_delete=models.CASCADE, related_name="alert_logs" + ) server: Resource = models.ForeignKey( - Resource, on_delete=models.CASCADE, - related_name='server_alert_logs', default=None, null=True, + Resource, + on_delete=models.CASCADE, + related_name="server_alert_logs", + default=None, + null=True, ) instance: Resource = models.ForeignKey( - Resource, on_delete=models.CASCADE, - related_name='instance_alert_logs', default=None, null=True, + Resource, + on_delete=models.CASCADE, + related_name="instance_alert_logs", + default=None, + null=True, ) version = models.ForeignKey( - VersionLog, related_name='alert_logs', - on_delete=models.SET_NULL, null=True, default=None, + VersionLog, + related_name="alert_logs", + on_delete=models.SET_NULL, + null=True, + default=None, ) # impact_requests = models.PositiveBigIntegerField(default=None, null=True) @@ -689,10 +757,10 @@ class AlertLog(models.Model): data = models.JSONField(default=None, null=True, encoder=JSONEncoder) class Meta: - db_table = 'utilmeta_alert_log' + db_table = "utilmeta_alert_log" @classmethod - def get(cls, id) -> Optional['AlertLog']: + def get(cls, id) -> Optional["AlertLog"]: if not id: return None return cls.objects.filter(id=id).first() @@ -705,7 +773,9 @@ def uncertain(self): def compressible(self): if not self.type.compress_window: return False - return (time_now() - self.latest_time).total_seconds() < self.type.compress_window + 
return ( + time_now() - self.latest_time + ).total_seconds() < self.type.compress_window def relieve(self): if self.relieved_time: @@ -715,8 +785,8 @@ def relieve(self): self.delete() return True self.relieved_time = time_now() - self.save(update_fields=['relieved_time']) - return self.compressible # only report legit relieve + self.save(update_fields=["relieved_time"]) + return self.compressible # only report legit relieve class ServiceLog(WebMixin): @@ -731,16 +801,25 @@ class ServiceLog(WebMixin): # global identifier version = models.ForeignKey( - VersionLog, related_name='service_logs', - on_delete=models.SET_NULL, null=True, default=None, + VersionLog, + related_name="service_logs", + on_delete=models.SET_NULL, + null=True, + default=None, ) instance = models.ForeignKey( - Resource, related_name='instance_logs', - on_delete=models.SET_NULL, null=True, default=None, + Resource, + related_name="instance_logs", + on_delete=models.SET_NULL, + null=True, + default=None, ) endpoint = models.ForeignKey( - Resource, on_delete=models.SET_NULL, - null=True, default=None, related_name='endpoint_logs', + Resource, + on_delete=models.SET_NULL, + null=True, + default=None, + related_name="endpoint_logs", ) # in case the endpoint is not loaded yet endpoint_ident = models.CharField(max_length=200, default=None, null=True) @@ -752,14 +831,19 @@ class ServiceLog(WebMixin): # units = ArrayField(CharField(max_length=10), default=list) # the orders that the process unit is called worker = models.ForeignKey( - Worker, related_name='logs', - on_delete=models.SET_NULL, null=True, default=None, + Worker, + related_name="logs", + on_delete=models.SET_NULL, + null=True, + default=None, ) # ----- level = models.CharField(max_length=30) # volatile log will be deleted when it is not counted by any aggregates - time = models.DateTimeField() # not auto_now_add, cache stored log may add after request for some time + time = ( + models.DateTimeField() + ) # not auto_now_add; cache-stored logs may be added some time after the request duration = models.PositiveBigIntegerField(default=None, null=True) # for http requests duration is the time between server receive request and send response # for ws requests duration is the time between client open a ws connection and close it @@ -774,24 +858,33 @@ class ServiceLog(WebMixin): messages = models.JSONField(default=list, encoder=JSONEncoder) alert = models.ForeignKey( - 'AlertLog', related_name='service_logs', - on_delete=models.SET_NULL, null=True, default=None, + "AlertLog", + related_name="service_logs", + on_delete=models.SET_NULL, + null=True, + default=None, ) volatile = models.BooleanField(default=True) # deleted_time = models.DateTimeField(default=None, null=True) # ---------- supervisor = models.ForeignKey( - Supervisor, related_name='logs', on_delete=models.SET_NULL, - default=None, null=True, + Supervisor, + related_name="logs", + on_delete=models.SET_NULL, + default=None, + null=True, ) access_token = models.ForeignKey( - AccessToken, related_name='logs', on_delete=models.SET_NULL, - default=None, null=True, + AccessToken, + related_name="logs", + on_delete=models.SET_NULL, + default=None, + null=True, ) class Meta: - db_table = 'utilmeta_service_log' + db_table = "utilmeta_service_log" class RequestLog(WebMixin): @@ -803,7 +896,9 @@ class RequestLog(WebMixin): # volatile = models.BooleanField(default=True) # requests made in other service request context - time = models.DateTimeField() # not auto_now_add, cache stored log may add after request for some time + time = (
models.DateTimeField() + ) # not auto_now_add; cache-stored logs may be added some time after the request # version = models.ForeignKey( # VersionLog, related_name='service_logs', # on_delete=models.SET_NULL, null=True, default=None # ) @@ -811,8 +906,11 @@ duration = models.PositiveBigIntegerField(default=None, null=True) worker = models.ForeignKey( - Worker, related_name='request_logs', - on_delete=models.SET_NULL, null=True, default=None, + Worker, + related_name="request_logs", + on_delete=models.SET_NULL, + null=True, + default=None, ) context_type = models.CharField(max_length=40, default=None, null=True) @@ -820,28 +918,43 @@ class RequestLog(WebMixin): context_id = models.CharField(max_length=200, default=None, null=True) # log id / execution id - host = models.URLField(default=None, null=True) # host of the requested host (ip or domain name) + host = models.URLField( + default=None, null=True + ) # host of the requested url (ip or domain name) - remote_log = models.TextField(default=None, null=True) # able to supply other type ident (eg. uuid) + remote_log = models.TextField( + default=None, null=True + ) # able to supply other type ident (e.g. uuid) # remote utilmeta log id (in target service) to support recursive tracing asynchronous = models.BooleanField(default=None, null=True) - timeout = models.DecimalField(max_digits=10, decimal_places=2, default=None, null=True) + timeout = models.DecimalField( + max_digits=10, decimal_places=2, default=None, null=True + ) # SINGLE REQUEST, RETRIES NOT INCLUDED - timeout_error = models.BooleanField(default=False) # request is timeout - server_error = models.BooleanField(default=False) # ssl cert error when query the target host - client_error = models.BooleanField(default=False) # ssl cert error when query the target host - ssl_error = models.BooleanField(default=False) # ssl cert error when query the target host + timeout_error = models.BooleanField(default=False) # request timed out + server_error = models.BooleanField( + default=False + ) # server-side error when querying the target host + client_error = models.BooleanField( + default=False + ) # client-side error when querying the target host + ssl_error = models.BooleanField( + default=False + ) # ssl cert error when querying the target host dns_error = models.BooleanField(default=False) alert = models.ForeignKey( - AlertLog, related_name='request_logs', - on_delete=models.SET_NULL, null=True, default=None, + AlertLog, + related_name="request_logs", + on_delete=models.SET_NULL, + null=True, + default=None, ) class Meta: - db_table = 'utilmeta_request_log' + db_table = "utilmeta_request_log" class QueryLog(models.Model): @@ -855,19 +968,26 @@ class QueryLog(models.Model): # on_delete=models.SET_NULL, null=True, default=None # ) database = models.ForeignKey( - Resource, on_delete=models.CASCADE, - related_name='database_query_logs', + Resource, + on_delete=models.CASCADE, + related_name="database_query_logs", ) model = models.ForeignKey( - Resource, on_delete=models.SET_NULL, default=None, null=True, - related_name='model_query_logs' + Resource, + on_delete=models.SET_NULL, + default=None, + null=True, + related_name="model_query_logs", ) query = models.TextField() duration = models.PositiveBigIntegerField(default=None, null=True) # ms - message = models.TextField(default='') + message = models.TextField(default="") worker = models.ForeignKey( - Worker, related_name='query_logs', on_delete=models.SET_NULL, - null=True, default=None, + Worker, + related_name="query_logs", +
on_delete=models.SET_NULL, + null=True, + default=None, ) operation = models.CharField(max_length=32, default=None, null=True) @@ -879,12 +999,15 @@ class QueryLog(models.Model): # log id / execution id alert = models.ForeignKey( - AlertLog, related_name='query_logs', - on_delete=models.SET_NULL, null=True, default=None, + AlertLog, + related_name="query_logs", + on_delete=models.SET_NULL, + null=True, + default=None, ) class Meta: - db_table = 'utilmeta_query_log' + db_table = "utilmeta_query_log" class AggregationLog(models.Model): @@ -893,8 +1016,11 @@ class AggregationLog(models.Model): service = models.CharField(max_length=100) node_id = models.CharField(max_length=100, default=None, null=True, db_index=True) supervisor = models.ForeignKey( - Supervisor, related_name='aggregation_logs', on_delete=models.SET_NULL, - default=None, null=True, + Supervisor, + related_name="aggregation_logs", + on_delete=models.SET_NULL, + default=None, + null=True, ) data = models.JSONField(default=dict, encoder=JSONEncoder) @@ -915,7 +1041,7 @@ class AggregationLog(models.Model): # remote id is not None means the report is successful class Meta: - db_table = 'utilmeta_aggregation_log' + db_table = "utilmeta_aggregation_log" supervisor_related_models = [ diff --git a/utilmeta/ops/monitor.py b/utilmeta/ops/monitor.py index a9b111e..60b5ea4 100644 --- a/utilmeta/ops/monitor.py +++ b/utilmeta/ops/monitor.py @@ -5,8 +5,18 @@ import utype -from utilmeta.utils import get_max_open_files, get_max_socket_conn, get_mac_address, get_sql_info, \ - get_sys_net_connections_info, get_system_fds, get_system_open_files, get_server_ip, DB, ignore_errors +from utilmeta.utils import ( + get_max_open_files, + get_max_socket_conn, + get_mac_address, + get_sql_info, + get_sys_net_connections_info, + get_system_fds, + get_system_open_files, + get_server_ip, + DB, + ignore_errors, +) from .schema import ServerSchema from .models import DatabaseConnection from utilmeta.core.orm import DatabaseConnections @@ -15,7 +25,7 @@ import sys -def get_current_server(unit: int = 1024 ** 2) -> ServerSchema: +def get_current_server(unit: int = 1024**2) -> ServerSchema: mem = psutil.virtual_memory() disk = psutil.disk_usage(os.getcwd()) devices = {} @@ -24,7 +34,7 @@ def get_num(n): return round(n / unit) * unit for device in psutil.disk_partitions(): - if 'loop' in device.device: + if "loop" in device.device: continue try: disk_usage = psutil.disk_usage(device.mountpoint) @@ -56,13 +66,14 @@ def get_num(n): release=platform.release(), machine=platform.machine(), processor=platform.processor(), - bits=platform.architecture()[0] - ) + bits=platform.architecture()[0], + ), ) def get_sys_metrics(cpu_interval: float = None, with_open_files: bool = True): from utilmeta.ops.query import SystemMetricsMixin + mem = psutil.virtual_memory() disk = psutil.disk_usage(os.getcwd()) total, active, info = get_sys_net_connections_info() @@ -78,7 +89,7 @@ def get_sys_metrics(cpu_interval: float = None, with_open_files: bool = True): open_files=open_files, active_net_connections=active, total_net_connections=total, - net_connections_info=info + net_connections_info=info, ) @@ -88,6 +99,7 @@ def get_redis_info(cache: Cache) -> dict: except ModuleNotFoundError: return {} from redis.exceptions import ConnectionError + try: con = Redis.from_url(cache.get_location()) return dict(con.info()) @@ -100,20 +112,23 @@ def get_cache_size(using: str) -> int: cache = CacheConnections.get(using) if not cache: return 0 - if cache.type == 'db': + if cache.type == "db": return 
get_db_size(using) - elif cache.type == 'file': + elif cache.type == "file": loc = cache.location return os.path.getsize(loc) - elif cache.type == 'redis': + elif cache.type == "redis": info = get_redis_info(cache) - return info.get('used_memory', 0) - elif cache.type == 'memcached': - if os.name == 'posix': + return info.get("used_memory", 0) + elif cache.type == "memcached": + if os.name == "posix": # echo only applies to unix systems try: - host, port = cache.location.split(':') - cmd = "echo \'stats\' | nc - w 1 %s %s | awk \'$2 == \"bytes\" { print $3 }\'" % (host, port) + host, port = cache.location.split(":") + cmd = ( + "echo 'stats' | nc -w 1 %s %s | awk '$2 == \"bytes\" { print $3 }'" + % (host, port) + ) res = os.popen(cmd).read() return int(res) except (OSError, TypeError): @@ -122,11 +137,15 @@ class CacheStatus(utype.Schema): - pid: int = utype.Field(alias_from=['process_id'], default=None, no_output=True) - used_memory: int = utype.Field(alias_from=['limit_maxbytes'], default=0) - current_connections: int = utype.Field(alias_from=['connected_clients', 'curr_connections'], default=0) - total_connections: int = utype.Field(alias_from=['total_connections_received'], default=0) - qps: float = utype.Field(alias_from=['instantaneous_ops_per_sec'], default=0) + pid: int = utype.Field(alias_from=["process_id"], default=None, no_output=True) + used_memory: int = utype.Field(alias_from=["limit_maxbytes"], default=0) + current_connections: int = utype.Field( + alias_from=["connected_clients", "curr_connections"], default=0 + ) + total_connections: int = utype.Field( + alias_from=["total_connections_received"], default=0 + ) + qps: float = utype.Field(alias_from=["instantaneous_ops_per_sec"], default=0) @ignore_errors(default=None) @@ -134,10 +153,11 @@ def get_cache_stats(using: str) -> Optional[CacheStatus]: cache = CacheConnections.get(using) if not cache: return None - if cache.type == 'redis': + if cache.type == "redis": return CacheStatus(get_redis_info(cache)) - elif cache.type == 'memcached': + elif cache.type == "memcached": from pymemcache.client.base import Client + mc = Client(cache.get_location()) return CacheStatus(mc.stats()) return None @@ -146,16 +166,17 @@ def get_cache_stats(using: str) -> Optional[CacheStatus]: @ignore_errors(default=list) def get_db_connections(using: str): db_sql = { - DB.PostgreSQL: "select pid, usename, client_addr, client_port, state," # noqa - " backend_start, query_start, state_change, xact_start, wait_event, query" # noqa - " from pg_stat_activity WHERE datname = '%s';", # noqa + DB.PostgreSQL: "select pid, usename, client_addr, client_port, state," # noqa + " backend_start, query_start, state_change, xact_start, wait_event, query" # noqa + " from pg_stat_activity WHERE datname = '%s';", # noqa DB.MySQL: "select * from information_schema.processlist where db = '%s';", # noqa - DB.Oracle: "select status from v$session where username='%s';" # noqa + DB.Oracle: "select status from v$session where username='%s';", # noqa } db = DatabaseConnections.get(using) if db.type not in db_sql: return [] from django.db import connections + with connections[db.alias].cursor() as cursor: db_type: str = str(cursor.db.display_name).lower() if db_type not in db_sql: @@ -167,11 +188,29 @@ def get_db_connections(using: str): result = cursor.fetchall() values = [] if db.type == DB.PostgreSQL: - for pid, usename, client_addr, client_port, state, \ - backend_start, query_start, state_change, xact_start, wait_event, query in
result: + for ( + pid, + usename, + client_addr, + client_port, + state, + backend_start, + query_start, + state_change, + xact_start, + wait_event, + query, + ) in result: if usename != db.user: continue - if not pid or not usename or not client_addr or not client_port or not state or not query: + if ( + not pid + or not usename + or not client_addr + or not client_port + or not state + or not query + ): continue # find = False # for conn in current_connections: @@ -184,35 +223,38 @@ def get_db_connections(using: str): operation, tables = get_sql_info(query) if not operation: continue - values.append(DatabaseConnection( - status=state, - active=state == 'active', - client_addr=client_addr, - client_port=client_port, - pid=pid, - backend_start=backend_start, - query_start=query_start, - state_change=state_change, - wait_event=wait_event, - transaction_start=xact_start, - query=query, - operation=operation, - tables=tables - )) + values.append( + DatabaseConnection( + status=state, + active=state == "active", + client_addr=client_addr, + client_port=client_port, + pid=pid, + backend_start=backend_start, + query_start=query_start, + state_change=state_change, + wait_event=wait_event, + transaction_start=xact_start, + query=query, + operation=operation, + tables=tables, + ) + ) return values @ignore_errors(default=0) def get_db_server_connections(using: str): db_sql = { - DB.PostgreSQL: "select count(*) from pg_stat_activity", # noqa - DB.MySQL: "select count(*) from information_schema.processlist", # noqa - DB.Oracle: "select count(*) from v$session" # noqa + DB.PostgreSQL: "select count(*) from pg_stat_activity", # noqa + DB.MySQL: "select count(*) from information_schema.processlist", # noqa + DB.Oracle: "select count(*) from v$session", # noqa } db = DatabaseConnections.get(using) if db.type not in db_sql: return [] from django.db import connections + with connections[db.alias].cursor() as cursor: db_type: str = str(cursor.db.display_name).lower() cursor.execute(db_sql[db_type]) @@ -222,13 +264,14 @@ def get_db_server_connections(using: str): @ignore_errors(default=0) def get_db_connections_num(using: str) -> Tuple[Optional[int], Optional[int]]: from django.db import connections + db = DatabaseConnections.get(using) if not db: return None, None db_sql = { - DB.PostgreSQL: "select state from pg_stat_activity WHERE datname = '%s';", # noqa - DB.MySQL: "select command from information_schema.processlist where db = '%s';", # noqa - DB.Oracle: "select status from v$session where username='%s';" # noqa + DB.PostgreSQL: "select state from pg_stat_activity WHERE datname = '%s';", # noqa + DB.MySQL: "select command from information_schema.processlist where db = '%s';", # noqa + DB.Oracle: "select status from v$session where username='%s';", # noqa } if db.type not in db_sql: return None, None @@ -240,7 +283,7 @@ def get_db_connections_num(using: str) -> Tuple[Optional[int], Optional[int]]: cursor.execute(db_sql[db_type] % db_name) status = [str(result[0]).lower() for result in cursor.fetchall()] # for MySQL, command=Query means active, for others, state/status = active means active - active_count = len([s for s in status if s in ('active', 'query')]) + active_count = len([s for s in status if s in ("active", "query")]) conn_count = len(status) return conn_count, active_count @@ -248,11 +291,12 @@ def get_db_connections_num(using: str) -> Tuple[Optional[int], Optional[int]]: @ignore_errors(default=None) def get_db_size(using: str) -> int: from django.db import connections + db_sql = { 
DB.PostgreSQL: "select pg_database_size('%s');", - DB.MySQL: "select sum(DATA_LENGTH)+sum(INDEX_LENGTH) " # noqa - "from information_schema.tables where table_schema='%s';", # noqa - DB.Oracle: "select sum(bytes) from dba_segments where owner='%s'" # noqa + DB.MySQL: "select sum(DATA_LENGTH)+sum(INDEX_LENGTH) " # noqa + "from information_schema.tables where table_schema='%s';", # noqa + DB.Oracle: "select sum(bytes) from dba_segments where owner='%s'", # noqa } db = DatabaseConnections.get(using) if db.is_sqlite: @@ -263,7 +307,7 @@ def get_db_size(using: str) -> int: return 0 db_name: str = db.name if db_type == DB.Oracle: - db_name = f'{db.user}/{db.name}' + db_name = f"{db.user}/{db.name}" cursor.execute(db_sql[db_type] % db_name) return int(cursor.fetchone()[0]) @@ -271,10 +315,11 @@ def get_db_size(using: str) -> int: @ignore_errors(default=None) def get_db_server_size(using: str) -> int: from django.db import connections + db_sql = { - DB.PostgreSQL: "select sum(pg_database_size(pg_database.datname)) from pg_database;", # noqa - DB.MySQL: "select sum(DATA_LENGTH)+sum(INDEX_LENGTH) from information_schema.tables;", # noqa - DB.Oracle: "select sum(bytes) from dba_segments;" # noqa + DB.PostgreSQL: "select sum(pg_database_size(pg_database.datname)) from pg_database;", # noqa + DB.MySQL: "select sum(DATA_LENGTH)+sum(INDEX_LENGTH) from information_schema.tables;", # noqa + DB.Oracle: "select sum(bytes) from dba_segments;", # noqa } db = DatabaseConnections.get(using) if db.is_sqlite: @@ -290,9 +335,10 @@ def get_db_server_size(using: str) -> int: @ignore_errors(default=None) def get_db_max_connections(using: str) -> int: from django.db import connections + db_sql = { DB.PostgreSQL: "SHOW max_connections;", - DB.MySQL: 'SHOW VARIABLES LIKE "max_connections";' + DB.MySQL: 'SHOW VARIABLES LIKE "max_connections";', } with connections[using].cursor() as cursor: db_type: str = str(cursor.db.display_name).lower() @@ -312,8 +358,9 @@ def get_db_max_connections(using: str) -> int: @ignore_errors(default=None) def get_db_transactions(using: str) -> int: from django.db import connections + db_sql = { - DB.PostgreSQL: "select xact_commit from pg_stat_database where datname='%s';", # noqa + DB.PostgreSQL: "select xact_commit from pg_stat_database where datname='%s';", # noqa } db = DatabaseConnections.get(using) with connections[using].cursor() as cursor: diff --git a/utilmeta/ops/proxy.py b/utilmeta/ops/proxy.py index 5627e36..5e756e8 100644 --- a/utilmeta/ops/proxy.py +++ b/utilmeta/ops/proxy.py @@ -1,23 +1,24 @@ import utype from utilmeta.core import cli, request, response + # from utilmeta.utils import url_join from utype import Schema, Field from utype.types import * class RegistrySchema(Schema): - name: str = utype.Field(alias_from=['service_name', 'service_id']) + name: str = utype.Field(alias_from=["service_name", "service_id"]) title: Optional[str] = utype.Field(default=None, defer_default=True) - description: str = '' - address: str # host + port + description: str = "" + address: str # host + port # host: str = utype.Field(alias_from=['ip']) # port: Optional[int] = None # host:port base_url: str # address + base_route ops_api: str - instance_id: str = utype.Field(alias_from=['resource_id']) + instance_id: str = utype.Field(alias_from=["resource_id"]) # this field will be checked by the proxy # server_id: Optional[str] = Field(default=None, defer_default=True) # remote_id: Optional[str] = Field(default=None, defer_default=True) @@ -32,21 +33,22 @@ class RegistrySchema(Schema): # python / 
java / go / javascript / php utilmeta_version: str # python version - backend: str = 'utilmeta' + backend: str = "utilmeta" # runtime framework backend_version: str = Field(required=False) resources: Optional[dict] = Field(default=None, defer_default=True) def get_metadata(self): from .client import NodeMetadata - data = (self.resources or {}).get('metadata') or dict( + + data = (self.resources or {}).get("metadata") or dict( ops_api=self.ops_api, base_url=self.base_url, name=self.name, title=self.title, description=self.description, version=self.version, - production=self.production + production=self.production, ) return NodeMetadata(data) @@ -96,7 +98,7 @@ class RegistryResponse(response.Response): class ProxyClient(cli.Client): proxy: cli.Client - @cli.post('registry') + @cli.post("registry") def register_service(self, data: RegistrySchema = request.Body) -> RegistryResponse: pass diff --git a/utilmeta/ops/query.py b/utilmeta/ops/query.py index 3ddd4b5..9edfe08 100644 --- a/utilmeta/ops/query.py +++ b/utilmeta/ops/query.py @@ -1,6 +1,16 @@ from utype.types import * -from .models import ServiceLog, AccessToken, Worker, WorkerMonitor, \ - ServerMonitor, InstanceMonitor, DatabaseConnection, Supervisor, DatabaseMonitor, CacheMonitor +from .models import ( + ServiceLog, + AccessToken, + Worker, + WorkerMonitor, + ServerMonitor, + InstanceMonitor, + DatabaseConnection, + Supervisor, + DatabaseMonitor, + CacheMonitor, +) from utilmeta.core import orm @@ -44,10 +54,12 @@ class DatabaseConnectionSchema(orm.Schema[DatabaseConnection]): # --------------------------------------------------- -class WebMixinSchema(orm.Schema): # not be utype.Schema + +class WebMixinSchema(orm.Schema): # not be utype.Schema """ - Log data using http/https schemes + Log data using http/https schemes """ + scheme: Optional[str] method: Optional[str] # replace the unit property @@ -97,12 +109,12 @@ class ServiceLogSchema(WebMixinSchema, orm.Schema[ServiceLog]): instance_id: Optional[int] endpoint_id: Optional[int] - endpoint_remote_id: Optional[str] = orm.Field('endpoint.remote_id') + endpoint_remote_id: Optional[str] = orm.Field("endpoint.remote_id") endpoint_ident: Optional[str] endpoint_ref: Optional[str] worker_id: Optional[int] - worker_pid: Optional[int] = orm.Field('worker.pid') + worker_pid: Optional[int] = orm.Field("worker.pid") # ----- level: str @@ -125,10 +137,11 @@ class ServiceLogSchema(WebMixinSchema, orm.Schema[ServiceLog]): def __validate__(self): from .api.utils import config + if config.log.hide_ip_address: - self.ip = '*.*.*.*' if self.ip else '' + self.ip = "*.*.*.*" if self.ip else "" if config.log.hide_user_id: - self.user_id = '***' if self.user_id else None + self.user_id = "***" if self.user_id else None class AccessTokenSchema(orm.Schema[AccessToken]): @@ -173,6 +186,7 @@ class ServiceMetricsMixin(orm.Schema): """ request metrics that can simply be calculated in form of incr and divide """ + in_traffic: Optional[int] out_traffic: Optional[int] @@ -206,7 +220,7 @@ class WorkerSchema(SystemMetricsMixin, ServiceMetricsMixin, orm.Schema[Worker]): start_time: int master_id: Optional[int] - master_pid: Optional[int] = orm.Field('master__pid') + master_pid: Optional[int] = orm.Field("master__pid") connected: bool time: int @@ -214,7 +228,9 @@ class WorkerSchema(SystemMetricsMixin, ServiceMetricsMixin, orm.Schema[Worker]): status: Optional[str] -class WorkerMonitorSchema(SystemMetricsMixin, ServiceMetricsMixin, orm.Schema[WorkerMonitor]): +class WorkerMonitorSchema( + SystemMetricsMixin, 
ServiceMetricsMixin, orm.Schema[WorkerMonitor] +): id: int time: float @@ -238,7 +254,9 @@ class ServerMonitorSchema(SystemMetricsMixin, orm.Schema[ServerMonitor]): metrics: dict -class InstanceMonitorSchema(SystemMetricsMixin, ServiceMetricsMixin, orm.Schema[InstanceMonitor]): +class InstanceMonitorSchema( + SystemMetricsMixin, ServiceMetricsMixin, orm.Schema[InstanceMonitor] +): id: int time: float diff --git a/utilmeta/ops/resources.py b/utilmeta/ops/resources.py index 108f249..5cfb611 100644 --- a/utilmeta/ops/resources.py +++ b/utilmeta/ops/resources.py @@ -6,10 +6,18 @@ from typing import Optional, List, Type from .models import Supervisor, Resource from .config import Operations -from .schema import NodeMetadata, ResourcesSchema, \ - InstanceSchema, TableSchema, DatabaseSchema, CacheSchema, ResourceData, language_version +from .schema import ( + NodeMetadata, + ResourcesSchema, + InstanceSchema, + TableSchema, + DatabaseSchema, + CacheSchema, + ResourceData, + language_version, +) from utilmeta import UtilMeta -from utilmeta.utils import (fast_digest, json_dumps, get_ip, time_now, ignore_errors) +from utilmeta.utils import fast_digest, json_dumps, get_ip, time_now, ignore_errors from django.db import models import utilmeta @@ -18,6 +26,7 @@ class ModelGenerator: def __init__(self, model, config: Operations): self.model = model from utilmeta.core.orm.backends.base import ModelAdaptor + self.adaptor = ModelAdaptor.dispatch(model) self.config = config @@ -27,6 +36,7 @@ def __init__(self, model, config: Operations): def generate_fields(self): from utype.specs.json_schema import JsonSchemaGenerator + fields = {} for f in self.adaptor.get_fields(many=False, no_inherit=True): name = f.column_name @@ -40,32 +50,36 @@ def generate_fields(self): if f.is_pk: secret = False schema = JsonSchemaGenerator(f.rule)() - if schema.get('type') == 'boolean': + if schema.get("type") == "boolean": secret = False - data = {k: v for k, v in dict( - schema=schema, - title=f.title, - description=f.description, - primary_key=f.is_pk, - foreign_key=f.is_fk, - readonly=not f.is_writable, - unique=f.is_unique, - index=f.is_db_index, - null=f.is_nullable, - required=not f.is_optional, - category=f.field.__class__.__name__, - to_model=to_model, - to_field=to_field, - relate_name=relate_name, - secret=secret - ).items() if v} + data = { + k: v + for k, v in dict( + schema=schema, + title=f.title, + description=f.description, + primary_key=f.is_pk, + foreign_key=f.is_fk, + readonly=not f.is_writable, + unique=f.is_unique, + index=f.is_db_index, + null=f.is_nullable, + required=not f.is_optional, + category=f.field.__class__.__name__, + to_model=to_model, + to_field=to_field, + relate_name=relate_name, + secret=secret, + ).items() + if v + } fields[name] = data return fields class ResourcesManager: - EXCLUDED_APPS = ['utilmeta.ops', 'django.contrib.contenttypes'] + EXCLUDED_APPS = ["utilmeta.ops", "django.contrib.contenttypes"] # we reserve other models like django users sessions UPDATE_BATCH_SIZE = 50 @@ -82,9 +96,9 @@ def get_metadata(self): base_url=self.ops_config.base_url, name=self.service.name, title=self.service.title, - description=self.service.description or '', + description=self.service.description or "", version=self.service.version_str, - production=self.service.production + production=self.service.production, ) # @cached_property @@ -93,30 +107,28 @@ def get_metadata(self): def get_instances(self, node_id) -> List[InstanceSchema]: from .schema import InstanceSchema + instances = [] for val in 
Resource.filter( - type='instance', + type="instance", node_id=node_id, - ).order_by('created_time'): + ).order_by("created_time"): val: Resource server: Optional[Resource] = None if val.server_id: - server = Resource.objects.filter( - id=val.server_id - ).first() + server = Resource.objects.filter(id=val.server_id).first() if not server: continue inst_data = val.data server_data = dict(server.data) if server.ident == self.ops_config.address: from .monitor import get_current_server + inst_data.update(self.instance_data) server_data.update(get_current_server()) try: inst = InstanceSchema( - remote_id=val.remote_id, - server=server_data, - **inst_data + remote_id=val.remote_id, server=server_data, **inst_data ) except utype.exc.ParseError: # does not meet the latest spec @@ -142,14 +154,13 @@ def instance_data(self): def get_current_instance(self) -> InstanceSchema: from .monitor import get_current_server - return InstanceSchema( - server=get_current_server(), - **self.instance_data - ) + + return InstanceSchema(server=get_current_server(), **self.instance_data) def get_tables(self, with_model: bool = False) -> List[TableSchema]: # from utilmeta.core.orm.backends.base import ModelAdaptor from utilmeta.core.orm.backends.django import DjangoModelAdaptor + # todo: support other than django from django.apps import apps, AppConfig from django.db.models.options import Options @@ -160,12 +171,12 @@ def get_tables(self, with_model: bool = False) -> List[TableSchema]: model_id_map = {} def get_first_base(model) -> Type[models.Model]: - meta: Options = getattr(model, '_meta') + meta: Options = getattr(model, "_meta") parents = meta.get_parent_list() return parents[0] if parents else None def register_model(mod, label): - meta: Options = getattr(mod, '_meta') + meta: Options = getattr(mod, "_meta") if meta.auto_created or meta.abstract or meta.swappable: # swappable: like django.contrib.auth.models.User return @@ -177,7 +188,7 @@ def register_model(mod, label): if base: base_id = model_id_map.get(base) if not base_id: - lb = getattr(base, '_meta').app_label + lb = getattr(base, "_meta").app_label base_id = register_model(base, label=lb) adaptor = DjangoModelAdaptor(mod) @@ -185,9 +196,9 @@ def register_model(mod, label): ident = adaptor.ident generator = ModelGenerator(mod, config=self.ops_config) obj = TableSchema( - ref=f'{mod.__module__}.{mod.__name__}', + ref=f"{mod.__module__}.{mod.__name__}", ident=ident, - model_backend='django', + model_backend="django", model_name=model_name, metadata=dict( app_label=label, @@ -216,6 +227,7 @@ def register_model(mod, label): def get_databases(self): from .monitor import get_db_max_connections from utilmeta.core.orm.databases.config import DatabaseConnections + db_config = self.service.get_config(DatabaseConnections) if not db_config: return [] @@ -229,9 +241,9 @@ def get_databases(self): user=db.user, name=db.database_name, hostname=db.host, - server=get_ip(db.host, True), # incase it is intranet + server=get_ip(db.host, True), # incase it is intranet ops=alias == self.ops_config.db_alias, - max_server_connections=get_db_max_connections(alias) + max_server_connections=get_db_max_connections(alias), ) ) return databases @@ -239,6 +251,7 @@ def get_databases(self): def get_caches(self): # from utilmeta.utils import get_ip from utilmeta.core.cache.config import CacheConnections + cache_config = self.service.get_config(CacheConnections) if not cache_config: return [] @@ -266,26 +279,27 @@ def get_resources(self, node_id, etag: str = None) -> 
Optional[ResourcesSchema]: data = ResourcesSchema( metadata=self.get_metadata(), - openapi=self.ops_config.load_openapi(), # use new openapi + openapi=self.ops_config.load_openapi(), # use new openapi instances=instances, tables=self.get_tables(), databases=self.get_databases(), - caches=self.get_caches() + caches=self.get_caches(), ) if etag: resources_etag = fast_digest( - json_dumps(data), - compress=True, - case_insensitive=False + json_dumps(data), compress=True, case_insensitive=False ) if etag == resources_etag: return None return data def save_resources(self, resources: List[ResourceData], supervisor: Supervisor): - remote_pk_map = {val['remote_id']: val['pk'] for val in Resource.objects.filter( - node_id=supervisor.node_id, - ).values('pk', 'remote_id')} + remote_pk_map = { + val["remote_id"]: val["pk"] + for val in Resource.objects.filter( + node_id=supervisor.node_id, + ).values("pk", "remote_id") + } now = time_now() remote_pks = [] @@ -301,24 +315,23 @@ def save_resources(self, resources: List[ResourceData], supervisor: Supervisor): deleted_time=None, service=supervisor.service, node_id=supervisor.node_id, - **resource + **resource, ) remote_pks.append(resource.remote_id) if resource.remote_id in remote_pk_map: updates.append( Resource( - id=remote_pk_map[resource.remote_id], - updated_time=now, - **res + id=remote_pk_map[resource.remote_id], updated_time=now, **res ) ) else: service_q = models.Q(service=supervisor.service) - if resource.type == 'server': + if resource.type == "server": service_q |= models.Q(service=None) obj = Resource.objects.filter( - models.Q(node_id__isnull=True) | models.Q(node_id=supervisor.node_id), + models.Q(node_id__isnull=True) + | models.Q(node_id=supervisor.node_id), service_q, type=resource.type, remote_id=None, @@ -331,49 +344,46 @@ def save_resources(self, resources: List[ResourceData], supervisor: Supervisor): _data.update(resource.data) resource.data = _data - updates.append( - Resource( - id=obj.pk, - updated_time=now, - **res - ) - ) + updates.append(Resource(id=obj.pk, updated_time=now, **res)) continue creates.append(Resource(**res)) if updates: - fields = ['server_id', 'ident', 'route', - 'deleted_time', 'updated_time', - 'node_id', 'service', - 'remote_id', 'ref', 'data'] + fields = [ + "server_id", + "ident", + "route", + "deleted_time", + "updated_time", + "node_id", + "service", + "remote_id", + "ref", + "data", + ] batch_size = None if self.ops_config.database and self.ops_config.database.is_sqlite: batch_size = self.UPDATE_BATCH_SIZE Resource.objects.bulk_update(updates, fields=fields, batch_size=batch_size) if creates: - Resource.objects.bulk_create( - creates, - ignore_conflicts=True - ) + Resource.objects.bulk_create(creates, ignore_conflicts=True) Resource.objects.filter( # models.Q(remote_id=None) | (~models.Q(remote_id__in=remote_pks)), node_id=supervisor.node_id, # includes remote_id=None - ).exclude(remote_id__in=remote_pks).update( - deleted_time=time_now() - ) + ).exclude(remote_id__in=remote_pks).update(deleted_time=time_now()) - Resource.objects.exclude( - server__in=Resource.filter(type='server') - ).exclude(server=None).update(server_id=None) + Resource.objects.exclude(server__in=Resource.filter(type="server")).exclude( + server=None + ).update(server_id=None) for remote_id, server_id in remote_servers.items(): server = Resource.filter( models.Q(node_id__isnull=True) | models.Q(node_id=supervisor.node_id), - type='server', + type="server", remote_id=server_id, ).first() if server: @@ -388,11 +398,12 @@ def 
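Note on save_resources above: bulk_update is issued with batch_size=UPDATE_BATCH_SIZE (50) when the operations database is SQLite, presumably to keep each statement under SQLite's bound-parameter limit, while bulk_create(ignore_conflicts=True) tolerates rows that already exist. A plain-Python illustration of what the batch_size chunking amounts to:

    # Django splits bulk_update into per-chunk statements; the chunking itself
    # is equivalent to this (UPDATE_BATCH_SIZE mirrors the constant above).
    UPDATE_BATCH_SIZE = 50

    def chunks(items, size):
        for i in range(0, len(items), size):
            yield items[i:i + size]

    rows = list(range(120))
    assert [len(c) for c in chunks(rows, UPDATE_BATCH_SIZE)] == [50, 50, 20]
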
update_supervisor_service(cls, service: str, node_id: str): return from utilmeta.ops import models from django.core.exceptions import EmptyResultSet + for model in models.supervisor_related_models: try: - model.objects.filter( - node_id=node_id - ).exclude(service=service).update(service=service) + model.objects.filter(node_id=node_id).exclude(service=service).update( + service=service + ) except EmptyResultSet: pass try: @@ -407,6 +418,7 @@ def update_supervisor_service(cls, service: str, node_id: str): def set_local_node_id(cls, node_id: str): from utilmeta.bin.utils import update_meta_ini_file from utilmeta import service + update_meta_ini_file(node=node_id) service.load_meta() ops_config = Operations.config() @@ -415,50 +427,62 @@ def set_local_node_id(cls, node_id: str): def sync_resources(self, supervisor: Supervisor = None, force: bool = False): from utilmeta import service + ops_config = Operations.config() if not ops_config: - raise TypeError('Operations not configured') - supervisors = [supervisor] if supervisor and not supervisor.local else ( - Supervisor.filter(connected=True, local=False, service=service.name)) + raise TypeError("Operations not configured") + supervisors = ( + [supervisor] + if supervisor and not supervisor.local + else (Supervisor.filter(connected=True, local=False, service=service.name)) + ) for supervisor in supervisors: if supervisor.service != service.name: - force = True # name changed + force = True # name changed if not supervisor.node_id: continue - print(f'sync resources of [{service.name}] to supervisor[{supervisor.node_id}]...') + print( + f"sync resources of [{service.name}] to supervisor[{supervisor.node_id}]..." + ) with SupervisorClient( base_url=supervisor.base_url, node_key=supervisor.public_key, node_id=supervisor.node_id, - fail_silently=True + fail_silently=True, ) as client: try: resources = self.get_resources( supervisor.node_id, - etag=supervisor.resources_etag if not force else None + etag=supervisor.resources_etag if not force else None, ) except Exception as e: - print('meta: load resources failed with error: {}'.format(e)) + print("meta: load resources failed with error: {}".format(e)) continue if not resources: - print('[etag] resources is identical to the remote supervisor, done') + print( + "[etag] resources is identical to the remote supervisor, done" + ) continue - resp = client.upload_resources( - data=resources - ) + resp = client.upload_resources(data=resources) if not resp.success: - raise ValueError(f'sync to supervisor[{supervisor.node_id}]' - f' failed with error: {resp.message}') + raise ValueError( + f"sync to supervisor[{supervisor.node_id}]" + f" failed with error: {resp.message}" + ) if supervisor.service != service.name: - print(f'update supervisor and resources service name to [{service.name}]') + print( + f"update supervisor and resources service name to [{service.name}]" + ) supervisor.service = service.name - supervisor.save(update_fields=['service']) - self.update_supervisor_service(service.name, node_id=supervisor.node_id) + supervisor.save(update_fields=["service"]) + self.update_supervisor_service( + service.name, node_id=supervisor.node_id + ) if not ops_config.node_id: self.set_local_node_id(supervisor.node_id) @@ -466,73 +490,84 @@ def sync_resources(self, supervisor: Supervisor = None, force: bool = False): # we set the local node_id if resp.status == 304: - print('[304] resources is identical to the remote supervisor, done') + print("[304] resources is identical to the remote supervisor, done") 
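The etag short-circuit above works as follows: get_resources() digests the serialized payload via fast_digest(json_dumps(data), compress=True, case_insensitive=False) and returns None when the digest equals the etag the supervisor already holds, which is what the "identical to the remote supervisor" branches in sync_resources rely on. A minimal sketch of the idea, assuming a plain JSON dump and MD5 in place of utilmeta's actual fast_digest:

    import hashlib
    import json

    def resources_etag(data: dict) -> str:
        # Stand-in for fast_digest(json_dumps(data), ...): hash a canonical dump.
        payload = json.dumps(data, sort_keys=True, default=str)
        return hashlib.md5(payload.encode("utf-8")).hexdigest()

    def should_upload(data: dict, stored_etag) -> bool:
        # Skip the upload when the supervisor already holds an identical payload.
        return stored_etag is None or resources_etag(data) != stored_etag

    assert should_upload({"a": 1}, None)
    assert not should_upload({"a": 1}, resources_etag({"a": 1}))
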
continue if resp.result.resources_etag: supervisor.resources_etag = resp.result.resources_etag - supervisor.save(update_fields=['resources_etag']) + supervisor.save(update_fields=["resources_etag"]) - self.save_resources( - resp.result.resources, - supervisor=supervisor - ) + self.save_resources(resp.result.resources, supervisor=supervisor) - print(f'sync resources to supervisor[{supervisor.node_id}] successfully') + print( + f"sync resources to supervisor[{supervisor.node_id}] successfully" + ) if resp.result.url: if supervisor.url != resp.result.url: supervisor.url = resp.result.url - supervisor.save(update_fields=['url']) + supervisor.save(update_fields=["url"]) - print(f'you can visit {resp.result.url} to view the updated resources') + print( + f"you can visit {resp.result.url} to view the updated resources" + ) def init_service_resources( self, supervisor: Supervisor = None, instance: Resource = None, - force: bool = False + force: bool = False, ): if self.ops_config.proxy: self.register_service(supervisor=supervisor, instance=instance) else: self.sync_resources(supervisor=supervisor, force=force) - def register_service(self, supervisor: Supervisor = None, instance: Resource = None): + def register_service( + self, supervisor: Supervisor = None, instance: Resource = None + ): if not self.ops_config.proxy: return if not instance: return resources = self.get_resources( supervisor.node_id if supervisor else None, - etag=supervisor.resources_etag if supervisor else None + etag=supervisor.resources_etag if supervisor else None, ) from .proxy import ProxyClient, RegistrySchema, RegistryResponse - with ProxyClient(base_url=self.ops_config.proxy.base_url, fail_silently=True) as client: - resp = client.register_service(data=RegistrySchema( - name=self.service.name, - instance_id=instance.pk, - # remote_id=instance.remote_id, - address=self.ops_config.address, - ops_api=self.ops_config.proxy_ops_api, - base_url=self.ops_config.proxy_base_url, - cwd=str(self.service.project_dir), - version=self.service.version_str, - title=self.service.title, - description=self.service.description, - production=self.service.production, - asynchronous=self.service.asynchronous, - backend=self.service.backend_name, - backend_version=self.service.backend_version, - language='python', - language_version=language_version, - utilmeta_version=utilmeta.__version__, - resources=resources - )) + + with ProxyClient( + base_url=self.ops_config.proxy.base_url, fail_silently=True + ) as client: + resp = client.register_service( + data=RegistrySchema( + name=self.service.name, + instance_id=instance.pk, + # remote_id=instance.remote_id, + address=self.ops_config.address, + ops_api=self.ops_config.proxy_ops_api, + base_url=self.ops_config.proxy_base_url, + cwd=str(self.service.project_dir), + version=self.service.version_str, + title=self.service.title, + description=self.service.description, + production=self.service.production, + asynchronous=self.service.asynchronous, + backend=self.service.backend_name, + backend_version=self.service.backend_version, + language="python", + language_version=language_version, + utilmeta_version=utilmeta.__version__, + resources=resources, + ) + ) if isinstance(resp, RegistryResponse): if resp.result.node_id: from utilmeta.bin.utils import update_meta_ini_file + update_meta_ini_file(node=resp.result.node_id) else: - warnings.warn(f'register service: [{self.service.name}] to proxy: ' - f'{self.ops_config.proxy.base_url} failed: {resp.text}') - raise ValueError(f'service register failed: 
{resp.text}') + warnings.warn( + f"register service: [{self.service.name}] to proxy: " + f"{self.ops_config.proxy.base_url} failed: {resp.text}" + ) + raise ValueError(f"service register failed: {resp.text}") diff --git a/utilmeta/ops/schema.py b/utilmeta/ops/schema.py index cd53eb4..09880c5 100644 --- a/utilmeta/ops/schema.py +++ b/utilmeta/ops/schema.py @@ -6,7 +6,9 @@ from utilmeta.core.api.specs.openapi import OpenAPISchema import sys -language_version = f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}' +language_version = ( + f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" +) def get_current_instance_data() -> dict: @@ -15,12 +17,13 @@ def get_current_instance_data() -> dict: except ImportError: return {} from .config import Operations + config = service.get_config(Operations) return dict( version=service.version_str, asynchronous=service.asynchronous, production=service.production, - language='python', + language="python", language_version=language_version, utilmeta_version=utilmeta.__version__, spec_version=__spec_version__, @@ -29,7 +32,7 @@ def get_current_instance_data() -> dict: cwd=str(service.project_dir), # host=config.host if config.host else service.host, port=config.port if config.host else service.host, - address=config.address + address=config.address, ) @@ -39,14 +42,14 @@ class SupervisorBasic(Schema): class SupervisorInfoSchema(Schema): - utilmeta: str # spec version - supervisor: str # supervisor ident + utilmeta: str # spec version + supervisor: str # supervisor ident timestamp: int class ServiceInfoSchema(Schema): - utilmeta: str # spec version - service: str # supervisor ident + utilmeta: str # spec version + service: str # supervisor ident timestamp: int @@ -59,13 +62,13 @@ class NodeMetadata(Schema): name: str base_url: str title: Optional[str] = None - description: str = '' + description: str = "" version: Optional[str] = None spec_version: str = __spec_version__ production: bool = False - language: str = 'python' + language: str = "python" language_version: str = language_version utilmeta_version: str = utilmeta.__version__ @@ -85,7 +88,7 @@ class SupervisorData(Schema): class ResourceBase(Schema): # __options__ = utype.Options(addition=True) - description: str = '' + description: str = "" deprecated: bool = False tags: list = Field(default_factory=list) metadata: dict = Field(default_factory=dict) @@ -95,11 +98,11 @@ class ResourceBase(Schema): class TableSchema(ResourceBase): model_name: Optional[str] = None model_backend: Optional[str] = None - name: str # table name - ref: str # model ref - ident: str # ident (name or app_label.model_name) + name: str # table name + ref: str # model ref + ident: str # ident (name or app_label.model_name) - base: Optional[str] = None # base ident + base: Optional[str] = None # base ident database: Optional[str] = None # select database alias fields: dict @@ -133,7 +136,7 @@ class InstanceSchema(ResourceBase): version: str = Field(default=None, defer_default=True) asynchronous: bool = Field(default=None, defer_default=True) production: bool = Field(default=None, defer_default=True) - language: str = 'python' + language: str = "python" utilmeta_version: str = utilmeta.__version__ language_version: str = language_version backend: str = Field(default=None, defer_default=True) @@ -149,7 +152,9 @@ class DatabaseSchema(ResourceBase): port: int name: str user: str - server: Optional[str] = utype.Field(alias_from=['ip', 'server_ip'], default=None) # ip + server: Optional[str] 
= utype.Field( + alias_from=["ip", "server_ip"], default=None + ) # ip hostname: Optional[str] = None ops: bool = False test: bool = False @@ -161,7 +166,9 @@ class CacheSchema(ResourceBase): alias: str engine: str port: int - server: Optional[str] = utype.Field(alias_from=['ip', 'server_ip'], default=None) # ip + server: Optional[str] = utype.Field( + alias_from=["ip", "server_ip"], default=None + ) # ip hostname: Optional[str] = None max_memory: Optional[int] = None @@ -195,8 +202,8 @@ class ResourcesData(utype.Schema): url: Optional[str] = None resources: List[ResourceData] resources_etag: str - - + + class SupervisorPatchSchema(Schema): id: int = Field(no_input=True) node_id: str diff --git a/utilmeta/ops/task.py b/utilmeta/ops/task.py index 24f9e09..a4412e1 100644 --- a/utilmeta/ops/task.py +++ b/utilmeta/ops/task.py @@ -48,7 +48,9 @@ def start(self): if self._stopped: break - wait_for = max(0.0, self.interval - (time_now() - self._last_exec).total_seconds()) + wait_for = max( + 0.0, self.interval - (time_now() - self._last_exec).total_seconds() + ) if wait_for: time.sleep(wait_for) @@ -86,6 +88,7 @@ def __init__(self, config: Operations): self.supervisor = None from utilmeta import service + self.service = service self.hourly_aggregation = None @@ -106,6 +109,7 @@ def __call__(self, *args, **kwargs): def clear_connections(cls): # close all connections from django.db import connections + connections.close_all() @property @@ -119,9 +123,13 @@ def handle_error(self, e): if self.config.task_error_log: try: from utilmeta.utils import write_to - content = (f'[{os.getpid()}] {self._last_exec} Operations task worker execute cycle failed with error\n' - + err.full_info + '\n') - write_to(self.config.task_error_log, content, mode='a') + + content = ( + f"[{os.getpid()}] {self._last_exec} Operations task worker execute cycle failed with error\n" + + err.full_info + + "\n" + ) + write_to(self.config.task_error_log, content, mode="a") except Exception as e: pass @@ -139,6 +147,7 @@ def worker_cycle(self): # try to set up locals before from .log import _server, _worker, _instance, _supervisor + self.worker = _worker self.server = _server self.instance = _instance @@ -148,19 +157,19 @@ def worker_cycle(self): # 1. save logs batch_save_logs() except Exception as e: - warnings.warn(f'Save logs failed with error: {e}') + warnings.warn(f"Save logs failed with error: {e}") try: # 2. 
update worker worker_logger.update_worker( record=not self.config.monitor.worker_disabled, - interval=self.config.worker_cycle + interval=self.config.worker_cycle, ) # update worker from every worker # to make sure that the connected workers has the primary role to execute the following self.update_workers() except Exception as e: - warnings.warn(f'Update workers failed with error: {e}') + warnings.warn(f"Update workers failed with error: {e}") if self._stopped: # if this worker is stopped @@ -176,10 +185,12 @@ def worker_cycle(self): # 1st cycle manager = self.config.resources_manager_cls(self.service) try: - manager.init_service_resources(self.supervisor, instance=self.instance) + manager.init_service_resources( + self.supervisor, instance=self.instance + ) # ignore errors except Exception as e: - warnings.warn(f'sync resources failed with error: {e}') + warnings.warn(f"sync resources failed with error: {e}") self._sync_retries += 1 else: self._synced = True @@ -195,17 +206,19 @@ def worker_cycle(self): @property def connected_workers(self): from .models import Worker + if not self.instance or not self.server: return Worker.objects.none() - + return Worker.objects.filter( instance=self.instance, server=self.server, connected=True, ) - + def update_workers(self): from .models import Worker + if not self.instance or not self.server: return @@ -219,99 +232,115 @@ def update_workers(self): disconnected.append(worker.pk) continue - Worker.objects.filter( - instance=self.instance - ).filter( - models.Q( - time__lte=time_now() - timedelta( - seconds=self.interval * 2 - ) - ) | models.Q( - pk__in=disconnected - ) + Worker.objects.filter(instance=self.instance).filter( + models.Q(time__lte=time_now() - timedelta(seconds=self.interval * 2)) + | models.Q(pk__in=disconnected) ).update(connected=False) def get_total_memory(self): mem = 0 try: - for pss, uss in self.connected_workers.values_list('memory_info__pss', 'memory_info__uss'): + for pss, uss in self.connected_workers.values_list( + "memory_info__pss", "memory_info__uss" + ): mem += pss or uss or 0 - except Exception: # noqa + except Exception: # noqa # field error / Operational error # maybe sqlite, not support json lookup - for mem_info in self.connected_workers.values_list('memory_info', flat=True): - mem += mem_info.get('pss') or mem_info.get('uss') or 0 + for mem_info in self.connected_workers.values_list( + "memory_info", flat=True + ): + mem += mem_info.get("pss") or mem_info.get("uss") or 0 return mem def get_instance_metrics(self): total = self.connected_workers.aggregate( - outbound_requests=models.Sum('outbound_requests'), - queries_num=models.Sum('queries_num'), - requests=models.Sum('requests') + outbound_requests=models.Sum("outbound_requests"), + queries_num=models.Sum("queries_num"), + requests=models.Sum("requests"), ) avg_aggregates = {} - if total['outbound_requests']: + if total["outbound_requests"]: + avg_aggregates.update( + outbound_avg_time=models.Sum( + models.F("outbound_avg_time") * models.F("outbound_requests"), + output_field=models.DecimalField(), + ) + / total["outbound_requests"] + ) + if total["queries_num"]: avg_aggregates.update( - outbound_avg_time=models.Sum(models.F('outbound_avg_time') * models.F('outbound_requests'), - output_field=models.DecimalField()) / total['outbound_requests']) - if total['queries_num']: + query_avg_time=models.Sum( + models.F("query_avg_time") * models.F("queries_num"), + output_field=models.DecimalField(), + ) + / total["queries_num"] + ) + if total["requests"]: 
avg_aggregates.update( - query_avg_time=models.Sum(models.F('query_avg_time') * models.F('queries_num'), - output_field=models.DecimalField()) / total['queries_num']) - if total['requests']: - avg_aggregates.update(avg_time=models.Sum(models.F('avg_time') * models.F('requests'), - output_field=models.DecimalField()) / total['requests']) + avg_time=models.Sum( + models.F("avg_time") * models.F("requests"), + output_field=models.DecimalField(), + ) + / total["requests"] + ) used_memory = self.get_total_memory() import psutil + sys_total_memory = psutil.virtual_memory().total sys_cpu_num = os.cpu_count() total.update( used_memory=used_memory, - memory_percent=round(100 * used_memory / sys_total_memory, 3) + memory_percent=round(100 * used_memory / sys_total_memory, 3), ) - return replace_null(dict(**self.connected_workers.aggregate( - total_net_connections=models.Sum('total_net_connections'), - active_net_connections=models.Sum('active_net_connections'), - file_descriptors=models.Sum('file_descriptors'), - cpu_percent=models.Sum('cpu_percent') / sys_cpu_num, - threads=models.Sum('threads'), - open_files=models.Sum('open_files'), - in_traffic=models.Sum('in_traffic'), - out_traffic=models.Sum('out_traffic'), - outbound_rps=models.Sum('outbound_rps'), - outbound_timeouts=models.Sum('outbound_timeouts'), - outbound_errors=models.Sum('outbound_errors'), - qps=models.Sum('qps'), - errors=models.Sum('errors'), - rps=models.Sum('rps'), - **avg_aggregates, - ), **total)) + return replace_null( + dict( + **self.connected_workers.aggregate( + total_net_connections=models.Sum("total_net_connections"), + active_net_connections=models.Sum("active_net_connections"), + file_descriptors=models.Sum("file_descriptors"), + cpu_percent=models.Sum("cpu_percent") / sys_cpu_num, + threads=models.Sum("threads"), + open_files=models.Sum("open_files"), + in_traffic=models.Sum("in_traffic"), + out_traffic=models.Sum("out_traffic"), + outbound_rps=models.Sum("outbound_rps"), + outbound_timeouts=models.Sum("outbound_timeouts"), + outbound_errors=models.Sum("outbound_errors"), + qps=models.Sum("qps"), + errors=models.Sum("errors"), + rps=models.Sum("rps"), + **avg_aggregates, + ), + **total, + ) + ) def monitor(self): if not self.config.monitor.server_disabled: try: self.server_monitor() except Exception as e: - warnings.warn(f'utilmeta.ops.task: server monitor failed: {e}') + warnings.warn(f"utilmeta.ops.task: server monitor failed: {e}") if not self.config.monitor.instance_disabled: try: self.instance_monitor() except Exception as e: - warnings.warn(f'utilmeta.ops.task: instance monitor failed: {e}') + warnings.warn(f"utilmeta.ops.task: instance monitor failed: {e}") if not self.config.monitor.database_disabled: try: self.database_monitor() except Exception as e: - warnings.warn(f'utilmeta.ops.task: database monitor failed: {e}') + warnings.warn(f"utilmeta.ops.task: database monitor failed: {e}") if not self.config.monitor.cache_disabled: try: self.cache_monitor() except Exception as e: - warnings.warn(f'utilmeta.ops.task: cache monitor failed: {e}') + warnings.warn(f"utilmeta.ops.task: cache monitor failed: {e}") def instance_monitor(self): if not self.instance: @@ -321,6 +350,7 @@ def instance_monitor(self): # no workers return from .models import InstanceMonitor + metrics = self.get_instance_metrics() # now = time_now() # last: InstanceMonitor = InstanceMonitor.objects.filter( @@ -330,17 +360,18 @@ def instance_monitor(self): # data = dict(self.instance.data or {}) # data.update(time=self._last_exec.timestamp()) 
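The avg_aggregates expressions above fold per-worker means into an instance-level mean weighted by each worker's volume: avg_time = Σ(avg_time_i × requests_i) / Σ requests_i, and likewise for query_avg_time and outbound_avg_time. The same formula in plain Python:

    # Request-weighted mean, matching Sum(F("avg_time") * F("requests")) / total.
    workers = [
        {"avg_time": 12.0, "requests": 300},
        {"avg_time": 40.0, "requests": 100},
    ]
    total_requests = sum(w["requests"] for w in workers)
    avg_time = sum(w["avg_time"] * w["requests"] for w in workers) / total_requests
    assert avg_time == 19.0  # (12*300 + 40*100) / 400

Summing the products before dividing keeps the busy worker's 12.0 mean dominant, instead of naively averaging the two means to 26.0.
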
self.instance.updated_time = self._last_exec - self.instance.save(update_fields=['updated_time']) + self.instance.save(update_fields=["updated_time"]) InstanceMonitor.objects.create( time=self._last_exec, instance=self.instance, interval=self.interval, current_workers=workers_num, - **metrics + **metrics, ) def server_monitor(self): from .models import ServerMonitor + if not self.server: return metrics = get_sys_metrics(cpu_interval=self.DEFAULT_CPU_INTERVAL) @@ -348,24 +379,28 @@ def server_monitor(self): l1, l5, l15 = psutil.getloadavg() except (AttributeError, OSError): l1, l5, l15 = 0, 0, 0 - loads = dict( - load_avg_1=l1, - load_avg_5=l5, - load_avg_15=l15 - ) + loads = dict(load_avg_1=l1, load_avg_5=l5, load_avg_15=l15) ServerMonitor.objects.create( server=self.server, interval=self.interval, time=self._last_exec, **metrics, - **loads + **loads, ) def database_monitor(self): - from .monitor import (get_db_size, get_db_transactions, get_db_connections, get_db_server_size, - get_db_server_connections, get_db_connections_num, get_db_max_connections) + from .monitor import ( + get_db_size, + get_db_transactions, + get_db_connections, + get_db_server_size, + get_db_server_connections, + get_db_connections_num, + get_db_max_connections, + ) from utilmeta.core.orm import DatabaseConnections from .models import Resource, DatabaseMonitor, DatabaseConnection + db_config = DatabaseConnections.config() if not db_config: return @@ -375,9 +410,7 @@ def database_monitor(self): update_conn = [] create_conn = [] for database in Resource.filter( - type='database', - node_id=self.node_id, - ident__in=list(db_config.databases) + type="database", node_id=self.node_id, ident__in=list(db_config.databases) ): database: Resource db = DatabaseConnections.get(database.ident) @@ -388,7 +421,7 @@ def database_monitor(self): size = get_db_size(db.alias) connected = size is not None db_data = dict(database.data) - current_transactions = db_data.get('transactions') or 0 + current_transactions = db_data.get("transactions") or 0 new_transactions = max(0, transactions - current_transactions) db_metrics = dict( max_server_connections=max_conn, @@ -405,24 +438,26 @@ def database_monitor(self): # database.save(update_fields=update_fields) update_databases.append(database) current, active = get_db_connections_num(db.alias) - db_monitors.append(DatabaseMonitor( - database=database, - interval=self.interval, - time=self._last_exec, - used_space=size or 0, - server_used_space=get_db_server_size(db.alias) or 0, - server_connections=get_db_server_connections(db.alias) or 0, - current_connections=current or 0, - active_connections=active or 0, - new_transactions=new_transactions, - metrics=db_metrics, - )) + db_monitors.append( + DatabaseMonitor( + database=database, + interval=self.interval, + time=self._last_exec, + used_space=size or 0, + server_used_space=get_db_server_size(db.alias) or 0, + server_connections=get_db_server_connections(db.alias) or 0, + current_connections=current or 0, + active_connections=active or 0, + new_transactions=new_transactions, + metrics=db_metrics, + ) + ) connections = get_db_connections(db.alias) if connections: - current_connections = list(DatabaseConnection.objects.filter( - database=database - )) + current_connections = list( + DatabaseConnection.objects.filter(database=database) + ) for conn in connections: for c in current_connections: c: DatabaseConnection @@ -436,32 +471,47 @@ def database_monitor(self): if db_monitors: DatabaseMonitor.objects.bulk_create(db_monitors) if 
update_databases: - Resource.objects.bulk_update(update_databases, fields=['updated_time', 'data']) + Resource.objects.bulk_update( + update_databases, fields=["updated_time", "data"] + ) if create_conn: DatabaseConnection.objects.bulk_create(create_conn) if update_conn: DatabaseConnection.objects.bulk_update( - update_conn, fields=['status', 'active', 'client_addr', 'client_port', - 'pid', 'backend_start', 'query_start', 'state_change', - 'wait_event', 'transaction_start', 'query', 'operation', 'tables']) + update_conn, + fields=[ + "status", + "active", + "client_addr", + "client_port", + "pid", + "backend_start", + "query_start", + "state_change", + "wait_event", + "transaction_start", + "query", + "operation", + "tables", + ], + ) if update_databases: - DatabaseConnection.objects.filter( - database__in=update_databases - ).exclude(pk__in=[conn.pk for conn in create_conn + update_conn]).delete() + DatabaseConnection.objects.filter(database__in=update_databases).exclude( + pk__in=[conn.pk for conn in create_conn + update_conn] + ).delete() def cache_monitor(self): from .monitor import get_cache_stats from .models import CacheMonitor, Resource from utilmeta.core.cache import CacheConnections + cache_config = CacheConnections.config() if not cache_config: return updated_caches = [] cache_monitors = [] for cache_obj in Resource.filter( - type='cache', - node_id=self.node_id, - ident__in=list(cache_config.caches) + type="cache", node_id=self.node_id, ident__in=list(cache_config.caches) ): cache_obj: Resource cache = CacheConnections.get(cache_obj.ident) @@ -472,7 +522,7 @@ def cache_monitor(self): connected = stats is not None cache_data = dict(connected=connected) data = dict(stats or {}) - pid = data.get('pid') + pid = data.get("pid") # cpu_percent = memory_percent = fds = open_files = None if pid and cache.local: try: @@ -485,7 +535,7 @@ def cache_monitor(self): cpu_percent=cpu_percent, memory_percent=memory_percent, file_descriptors=fds, - open_files=open_files + open_files=open_files, ) except psutil.Error: pass @@ -497,25 +547,42 @@ def cache_monitor(self): cache_obj.data = cache_data # update_fields.append('data') updated_caches.append(cache_obj) - cache_monitors.append(CacheMonitor( - time=self._last_exec, - interval=self.interval, - cache=cache_obj, - **data - )) + cache_monitors.append( + CacheMonitor( + time=self._last_exec, + interval=self.interval, + cache=cache_obj, + **data, + ) + ) if updated_caches: - Resource.objects.bulk_update(updated_caches, fields=['updated_time', 'data']) + Resource.objects.bulk_update( + updated_caches, fields=["updated_time", "data"] + ) if cache_monitors: CacheMonitor.objects.bulk_create(cache_monitors) @ignore_errors def clear(self): - from .models import ServiceLog, RequestLog, QueryLog, VersionLog, WorkerMonitor, CacheMonitor, \ - InstanceMonitor, ServerMonitor, DatabaseMonitor, AggregationLog, AlertLog, Worker, Resource + from .models import ( + ServiceLog, + RequestLog, + QueryLog, + VersionLog, + WorkerMonitor, + CacheMonitor, + InstanceMonitor, + ServerMonitor, + DatabaseMonitor, + AggregationLog, + AlertLog, + Worker, + Resource, + ) + now = self._last_exec or time_now() ServiceLog.objects.filter( - time__lt=now - self.config.log.volatile_maintain, - volatile=True + time__lt=now - self.config.log.volatile_maintain, volatile=True ).delete() # MAX RETENTION ------------------ @@ -544,44 +611,33 @@ def clear(self): time__lt=now - self.config.monitor.worker_retention ).delete() Worker.objects.filter( - time__lt=now - 
self.DISCONNECTED_WORKER_RETENTION, - connected=False + time__lt=now - self.DISCONNECTED_WORKER_RETENTION, connected=False ).delete() # MONITOR RETENTION ---------------- - Resource.objects.filter( - type='instance', - node_id=self.node_id, - ).annotate( - latest_time=models.Max('instance_metrics__time') - ).filter( - latest_time__lt=now - self.DISCONNECTED_INSTANCE_RETENTION - ).update(deleted_time=now, deprecated=True) + Resource.objects.filter(type="instance", node_id=self.node_id,).annotate( + latest_time=models.Max("instance_metrics__time") + ).filter(latest_time__lt=now - self.DISCONNECTED_INSTANCE_RETENTION).update( + deleted_time=now, deprecated=True + ) InstanceMonitor.objects.filter( - layer=0, - time__lt=now - self.config.monitor.instance_retention + layer=0, time__lt=now - self.config.monitor.instance_retention ).delete() - Resource.objects.filter( - type='server', - node_id=self.node_id, - ).annotate( - latest_time=models.Max('server_metrics__time') - ).filter( - latest_time__lt=now - self.DISCONNECTED_SERVER_RETENTION - ).update(deleted_time=now, deprecated=True) + Resource.objects.filter(type="server", node_id=self.node_id,).annotate( + latest_time=models.Max("server_metrics__time") + ).filter(latest_time__lt=now - self.DISCONNECTED_SERVER_RETENTION).update( + deleted_time=now, deprecated=True + ) ServerMonitor.objects.filter( - layer=0, - time__lt=now - self.config.monitor.server_retention + layer=0, time__lt=now - self.config.monitor.server_retention ).delete() DatabaseMonitor.objects.filter( - layer=0, - time__lt=now - self.config.monitor.database_retention + layer=0, time__lt=now - self.config.monitor.database_retention ).delete() CacheMonitor.objects.filter( - layer=0, - time__lt=now - self.config.monitor.cache_retention + layer=0, time__lt=now - self.config.monitor.cache_retention ).delete() def alert(self): @@ -601,22 +657,13 @@ def is_worker_primary(self): @property def current_day(self) -> datetime: t = self._last_exec or time_now() - return datetime( - year=t.year, - month=t.month, - day=t.day, - tzinfo=t.tzinfo - ) + return datetime(year=t.year, month=t.month, day=t.day, tzinfo=t.tzinfo) @property def current_hour(self) -> datetime: t = self._last_exec or time_now() return datetime( - year=t.year, - month=t.month, - day=t.day, - hour=t.hour, - tzinfo=t.tzinfo + year=t.year, month=t.month, day=t.day, hour=t.hour, tzinfo=t.tzinfo ) @property @@ -669,37 +716,40 @@ def logs_aggregation(self, layer: int = 0): supervisor=self.supervisor, from_time=last_time, to_time=current_time, - layer=layer + layer=layer, ).first() if not aggregation: service_data = aggregate_logs( - service=self.service.name, - to_time=current_time, - layer=layer + service=self.service.name, to_time=current_time, layer=layer + ) + endpoints = ( + aggregate_endpoint_logs( + service=self.service.name, to_time=current_time, layer=layer + ) + if service_data + else None ) - endpoints = aggregate_endpoint_logs( - service=self.service.name, - to_time=current_time, - layer=layer - ) if service_data else None aggregation = AggregationLog.objects.create( service=self.service.name, node_id=self.node_id, supervisor=self.supervisor, - data=normalize(dict( - service=service_data, - endpoints=endpoints, - ), _json=True), + data=normalize( + dict( + service=service_data, + endpoints=endpoints, + ), + _json=True, + ), layer=layer, from_time=last_time, to_time=current_time, - reported_time=self._last_exec if not service_data else None + reported_time=self._last_exec if not service_data else None, ) # check daily 
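current_day and current_hour above truncate the last execution time to the bucket boundary an aggregation window is keyed on (layer 0 is the hourly bucket held in self.hourly_aggregation; higher layers cover longer windows). A small sketch of that truncation:

    from datetime import datetime, timezone

    def floor_hour(t: datetime) -> datetime:
        # Same truncation as current_hour: keep year/month/day/hour, drop the rest.
        return datetime(t.year, t.month, t.day, t.hour, tzinfo=t.tzinfo)

    def floor_day(t: datetime) -> datetime:
        # Same truncation as current_day.
        return datetime(t.year, t.month, t.day, tzinfo=t.tzinfo)

    t = datetime(2024, 12, 10, 16, 56, 15, tzinfo=timezone.utc)
    assert floor_hour(t) == datetime(2024, 12, 10, 16, tzinfo=timezone.utc)
    assert floor_day(t) == datetime(2024, 12, 10, tzinfo=timezone.utc)
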
--------------------------------------------- else: - service_data = (aggregation.data or {}).get('service') + service_data = (aggregation.data or {}).get("service") if layer == 0: self.hourly_aggregation = aggregation @@ -730,7 +780,9 @@ def logs_aggregation(self, layer: int = 0): # already ahead of this layer 1 aggregation report = True else: - prob_1 = ((self._last_exec - current_time).total_seconds() + self.interval * 2) / layer_seconds + prob_1 = ( + (self._last_exec - current_time).total_seconds() + self.interval * 2 + ) / layer_seconds prob = (2 * self.interval / layer_seconds) + prob_1 if prob_1 >= 1: report = True @@ -749,14 +801,14 @@ def logs_aggregation(self, layer: int = 0): with SupervisorClient( node_id=self.node_id, default_timeout=self.config.default_timeout, - fail_silently=True + fail_silently=True, ) as client: resp = client.report_analytics( data=dict( time=current_time.astimezone(timezone.utc), layer=layer, interval=layer_seconds, - **aggregation.data + **aggregation.data, ) ) updates = {} @@ -765,26 +817,28 @@ def logs_aggregation(self, layer: int = 0): else: aggregation.reported_time = resp.time or self._last_exec updates.update( - remote_id=resp.result.id, - reported_time=aggregation.reported_time + remote_id=resp.result.id, reported_time=aggregation.reported_time ) - AggregationLog.objects.filter( - pk=aggregation.pk - ).update(**updates) + AggregationLog.objects.filter(pk=aggregation.pk).update(**updates) success = isinstance(resp, SupervisorReportResponse) and resp.success if not success: return # if this report is successful, we can check if there are missing reports - missing_reports = AggregationLog.objects.filter( - supervisor=self.supervisor, - # layer=layer, - # no restrict on the layer - reported_time=None, - created_time__gte=self._last_exec - self.AGGREGATION_EXPIRE_TIME[layer] - ).order_by('to_time').exclude(pk=aggregation.pk) + missing_reports = ( + AggregationLog.objects.filter( + supervisor=self.supervisor, + # layer=layer, + # no restrict on the layer + reported_time=None, + created_time__gte=self._last_exec + - self.AGGREGATION_EXPIRE_TIME[layer], + ) + .order_by("to_time") + .exclude(pk=aggregation.pk) + ) missing_count = missing_reports.count() if missing_count: @@ -798,55 +852,61 @@ def logs_aggregation(self, layer: int = 0): # using batch to handle history missing report # avoid sending massive reports values = [] - for obj in list(missing_reports[offset: offset + batch_size]): + for obj in list(missing_reports[offset : offset + batch_size]): obj: AggregationLog - service = obj.data.get('service') + service = obj.data.get("service") if not service: empty_missing_reports.append(obj) continue batch_missing_reports.append(obj) - values.append(dict( - time=obj.to_time.astimezone(timezone.utc), - layer=obj.layer, - interval=layer_seconds, - **obj.data - )) - if values: - resp = client.batch_report_analytics( - data=values + values.append( + dict( + time=obj.to_time.astimezone(timezone.utc), + layer=obj.layer, + interval=layer_seconds, + **obj.data, + ) ) + if values: + resp = client.batch_report_analytics(data=values) if isinstance(resp.result, list): for res, report in zip(resp.result, batch_missing_reports): - remote_id = res.get('id') if isinstance(res, dict) else None + remote_id = ( + res.get("id") if isinstance(res, dict) else None + ) if remote_id: - updates.append(AggregationLog( - id=report.pk, - remote_id=remote_id, - reported_time=resp.time or self._last_exec - )) + updates.append( + AggregationLog( + id=report.pk, + remote_id=remote_id, 
+ reported_time=resp.time or self._last_exec, + ) + ) else: - errors.append(AggregationLog( - id=report.pk, - error=res.get('error', str(res)) if isinstance(res, dict) else str(res) - )) + errors.append( + AggregationLog( + id=report.pk, + error=res.get("error", str(res)) + if isinstance(res, dict) + else str(res), + ) + ) if updates: AggregationLog.objects.bulk_update( - updates, fields=['remote_id', 'reported_time'], - batch_size=self.UPDATE_BATCH_MAX_SIZE + updates, + fields=["remote_id", "reported_time"], + batch_size=self.UPDATE_BATCH_MAX_SIZE, ) if errors: AggregationLog.objects.bulk_update( - errors, fields=['error'], - batch_size=self.UPDATE_BATCH_MAX_SIZE + errors, fields=["error"], batch_size=self.UPDATE_BATCH_MAX_SIZE ) if empty_missing_reports: AggregationLog.objects.filter( pk__in=[obj.pk for obj in empty_missing_reports] - ).update( - reported_time=self._last_exec or time_now() - ) + ).update(reported_time=self._last_exec or time_now()) def heartbeat(self): pass diff --git a/utilmeta/utils/__init__.py b/utilmeta/utils/__init__.py index 131c072..dceb60d 100644 --- a/utilmeta/utils/__init__.py +++ b/utilmeta/utils/__init__.py @@ -8,4 +8,5 @@ from .error import Error from .logical import LogicUtil from .plugin import PluginEvent, PluginBase, PluginTarget + Plugin = PluginBase diff --git a/utilmeta/utils/adaptor.py b/utilmeta/utils/adaptor.py index 6885004..67b8ffa 100644 --- a/utilmeta/utils/adaptor.py +++ b/utilmeta/utils/adaptor.py @@ -28,22 +28,28 @@ def dispatch(cls, obj, *args, **kwargs): name = cls.get_module_name(obj) if name: - ref = f'{cls.__backends_package__}.{name}' if cls.__backends_package__ else name + ref = ( + f"{cls.__backends_package__}.{name}" + if cls.__backends_package__ + else name + ) import_obj(ref) else: cls.load_from_base() if cls.qualify(obj): - return cls(obj, *args, **kwargs) # noqa + return cls(obj, *args, **kwargs) # noqa to = cls.recursively_dispatch(cls, obj, *args, **kwargs) if to: return to - raise NotImplementedError(f'{cls}: adaptor for {obj}: {repr(name)} is not implemented') + raise NotImplementedError( + f"{cls}: adaptor for {obj}: {repr(name)} is not implemented" + ) @classmethod def recursively_dispatch(cls, base, obj, *args, **kwargs): for impl in base.__subclasses__(): - impl: Type['BaseAdaptor'] + impl: Type["BaseAdaptor"] try: if impl.qualify(obj): return impl(obj, *args, **kwargs) # noqa @@ -56,7 +62,7 @@ def recursively_dispatch(cls, base, obj, *args, **kwargs): return None @classmethod - def reconstruct(cls, adaptor: 'BaseAdaptor'): + def reconstruct(cls, adaptor: "BaseAdaptor"): raise NotImplementedError @classmethod @@ -69,7 +75,11 @@ def load_from_base(cls): return if cls.__backends_names__: for name in cls.__backends_names__: - ref = f'{cls.__backends_package__}.{name}' if cls.__backends_package__ else name + ref = ( + f"{cls.__backends_package__}.{name}" + if cls.__backends_package__ + else name + ) import_obj(ref) cls.__path_loaded__ = True @@ -77,13 +87,13 @@ def load_from_base(cls): def set_backends_pkg(cls): if cls.__backends_package__: return - module_parts = cls.__module__.split('.') + module_parts = cls.__module__.split(".") module = sys.modules[cls.__module__] - if not hasattr(module, '__path__'): + if not hasattr(module, "__path__"): module_parts = module_parts[:-1] if cls.__backends_route__: module_parts.append(cls.__backends_route__) - cls.__backends_package__ = '.'.join(module_parts) + cls.__backends_package__ = ".".join(module_parts) def __init_subclass__(cls, **kwargs): cls.set_backends_pkg() diff --git 
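The adaptor.py hunks above only reflow BaseAdaptor.dispatch, which imports a backend module named after the object and then walks subclasses until one's qualify() accepts it (recursively_dispatch). A stripped-down sketch of that subclass-dispatch pattern, with hypothetical names and the lazy-import machinery omitted:

    class Adaptor:
        def __init__(self, obj):
            self.obj = obj

        @classmethod
        def qualify(cls, obj) -> bool:
            return False  # the base class matches nothing

        @classmethod
        def _search(cls, obj):
            # Depth-first over subclasses, like recursively_dispatch above.
            for impl in cls.__subclasses__():
                if impl.qualify(obj):
                    return impl(obj)
                found = impl._search(obj)
                if found:
                    return found
            return None

        @classmethod
        def dispatch(cls, obj):
            found = cls._search(obj)
            if found is None:
                raise NotImplementedError(f"no adaptor implemented for {obj!r}")
            return found

    class StrAdaptor(Adaptor):
        @classmethod
        def qualify(cls, obj) -> bool:
            return isinstance(obj, str)

    assert isinstance(Adaptor.dispatch("some string"), StrAdaptor)
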
a/utilmeta/utils/base.py b/utilmeta/utils/base.py index 0c8a0c0..57e2f20 100644 --- a/utilmeta/utils/base.py +++ b/utilmeta/utils/base.py @@ -2,9 +2,9 @@ from typing import Dict, Any, TypeVar, List import inspect -T = TypeVar('T') +T = TypeVar("T") -__all__ = ['Util', 'Meta'] +__all__ = ["Util", "Meta"] # class UtilKey: @@ -19,7 +19,7 @@ class Meta(type): def __init__(cls, name, bases: tuple, attrs: dict, **kwargs): super().__init__(name, bases, attrs) - __init = attrs.get(Attr.INIT) # only track current init + __init = attrs.get(Attr.INIT) # only track current init cls._kwargs = kwargs cls._pos_var = None @@ -72,11 +72,11 @@ def __init__(cls, name, bases: tuple, attrs: dict, **kwargs): cls._defaults = ImmutableDict(defaults) cls._requires = requires - cls._attr_names = [a for a in attrs if not a.startswith('_')] + cls._attr_names = [a for a in attrs if not a.startswith("_")] @property def cls_path(cls): - return f'{cls.__module__}.{cls.__name__}' + return f"{cls.__module__}.{cls.__name__}" @property def kw_keys(cls): @@ -107,7 +107,7 @@ def __init__(self, __params__: Dict[str, Any]): if isinstance(val, dict): _kwargs = {k: v for k, v in val.items() if not k.startswith(SEG)} kwargs.update(_kwargs) - spec.update(_kwargs) # also update spec + spec.update(_kwargs) # also update spec continue elif key in self._pos_keys: args.append(key) @@ -115,7 +115,9 @@ def __init__(self, __params__: Dict[str, Any]): kwargs[key] = val else: continue - if val != self._defaults.get(key): # for key_var or pos_var the default is None + if val != self._defaults.get( + key + ): # for key_var or pos_var the default is None spec[key] = val self.__args__ = tuple(args) @@ -125,12 +127,15 @@ def __init__(self, __params__: Dict[str, Any]): def __hash__(self): return hash(repr(self)) - def __eq__(self, other: 'Util'): + def __eq__(self, other: "Util"): if inspect.isclass(self): return super().__eq__(other) if not isinstance(other, self.__class__): return False - return self.__spec_kwargs__ == other.__spec_kwargs__ and self.__args__ == other.__args__ + return ( + self.__spec_kwargs__ == other.__spec_kwargs__ + and self.__args__ == other.__args__ + ) def __bool__(self): # !! 
return not self.vacuum @@ -169,7 +174,9 @@ def __copy__(self): # pop(attrs, Attr.LOCK) # pop __lock__ cls: type = self.__class__ return cls(self.__name__, bases, self._copy(attrs)) - return self.__class__(*self._copy(self.__args__), **self._copy(self.__spec_kwargs__)) + return self.__class__( + *self._copy(self.__args__), **self._copy(self.__spec_kwargs__) + ) @property def _cls_name(self): @@ -189,12 +196,12 @@ def _repr(self, params: List[str] = None, excludes: List[str] = None): for k, v in self.__spec_kwargs__.items(): # if not isinstance(v, bool) and any([s in str(k).lower() for s in self._secret_names]) and v: # v = SECRET - if k.startswith('_'): + if k.startswith("_"): continue if params is not None and k not in params: continue if excludes is not None and k in excludes: continue - attrs.append(k + '=' + represent(v)) # str(self.display(v))) - s = ', '.join([represent(v) for v in self.__args__] + attrs) - return f'{self._cls_name}({s})' + attrs.append(k + "=" + represent(v)) # str(self.display(v))) + s = ", ".join([represent(v) for v in self.__args__] + attrs) + return f"{self._cls_name}({s})" diff --git a/utilmeta/utils/constant/data.py b/utilmeta/utils/constant/data.py index 1f133cb..887210c 100644 --- a/utilmeta/utils/constant/data.py +++ b/utilmeta/utils/constant/data.py @@ -1,35 +1,35 @@ from ..datastructure import Static -SECRET = '*' * 8 -PY = '.py' -ID = 'id' -PK = 'pk' -SEG = '__' -UTF_8 = 'utf-8' -SHA_256 = 'sha256' -ELEMENTS = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +SECRET = "*" * 8 +PY = ".py" +ID = "id" +PK = "pk" +SEG = "__" +UTF_8 = "utf-8" +SHA_256 = "sha256" +ELEMENTS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" MAX_BASE = len(ELEMENTS) class Logic(Static): - ALL = '*' - AND = '&' - OR = '|' - XOR = '^' - NOT = '~' + ALL = "*" + AND = "&" + OR = "|" + XOR = "^" + NOT = "~" class Reg(Static): - META = '.^$#~&*+?{}[]\\|()' - IP = '((2(5[0-5]|[0-4]\\d))|[0-1]?\\d{1,2})(\\.((2(5[0-5]|[0-4]\\d))|[0-1]?\\d{1,2})){3}' - ALNUM = '[0-9a-zA-Z]+' - ALNUM_SCORE = '[0-9a-zA-Z_]+' - ALNUM_SEP = '[0-9a-zA-Z-]+' - ALNUM_SCORE_SEP = '[0-9a-zA-Z_-]+' - - ALL = '.+' - URL_ROUTE = '[^/]+' # match all string except / - PATH_REGEX = '{(%s)}' % ALNUM_SCORE + META = ".^$#~&*+?{}[]\\|()" + IP = "((2(5[0-5]|[0-4]\\d))|[0-1]?\\d{1,2})(\\.((2(5[0-5]|[0-4]\\d))|[0-1]?\\d{1,2})){3}" + ALNUM = "[0-9a-zA-Z]+" + ALNUM_SCORE = "[0-9a-zA-Z_]+" + ALNUM_SEP = "[0-9a-zA-Z-]+" + ALNUM_SCORE_SEP = "[0-9a-zA-Z_-]+" + + ALL = ".+" + URL_ROUTE = "[^/]+" # match all string except / + PATH_REGEX = "{(%s)}" % ALNUM_SCORE # EMAIL_FULL = '^(?:[a-zA-Z0-9!#$%&\'*+/=?^_`{|}~-]+(?:\\.[a-zA-Z0-9!#$%&\'' \ # '*+/=?^_`{|}~-]+)*|(?:[\x01-\x08\x0b\x0c\x0e-\x1f!#-[]-' \ # '\x7f]|\\[\x01-\t\x0b\x0c\x0e-\x7f])*)@(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]' \ @@ -37,141 +37,158 @@ class Reg(Static): # '[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\\.){3}(?:(2(5[0-5]|[0-4][0-9])' \ # '|1[0-9][0-9]|[1-9]?[0-9])|[a-zA-Z0-9-]*[a-zA-Z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f' \ # '!-ZS-\x7f]|\\[\x01-\t\x0b\x0c\x0e-\x7f])+)\\])&' - EMAIL_ALNUM = '^[a-zA-Z0-9]+@[a-zA-Z0-9]+(\\.[a-zA-Z0-9]+)+[a-z0-9A-Z]+$' - EMAIL = '^[a-z0-9A-Z]+[a-z0-9A-Z._-]+@[a-zA-Z0-9_-]+(\\.[a-zA-Z0-9_-]+)+[a-z0-9A-Z]+$' + EMAIL_ALNUM = "^[a-zA-Z0-9]+@[a-zA-Z0-9]+(\\.[a-zA-Z0-9]+)+[a-z0-9A-Z]+$" + EMAIL = ( + "^[a-z0-9A-Z]+[a-z0-9A-Z._-]+@[a-zA-Z0-9_-]+(\\.[a-zA-Z0-9_-]+)+[a-z0-9A-Z]+$" + ) EMAIL_SIMPLE = r"^\S+@\S+\.\S+$" - must_contains_letter_number = '(?=.*[0-9])(?=.*[a-zA-Z]).+' - must_contains_letter_number_special = 
r'(?=.*[0-9])(?=.*[a-zA-Z])(?=.*[^a-zA-Z0-9]).+' - must_contains_hybrid_letter_number_special = '(?=.*[0-9])(?=.*[A-Z])(?=.*[a-z])(?=.*[^a-zA-Z0-9]).+' + must_contains_letter_number = "(?=.*[0-9])(?=.*[a-zA-Z]).+" + must_contains_letter_number_special = ( + r"(?=.*[0-9])(?=.*[a-zA-Z])(?=.*[^a-zA-Z0-9]).+" + ) + must_contains_hybrid_letter_number_special = ( + "(?=.*[0-9])(?=.*[A-Z])(?=.*[a-z])(?=.*[^a-zA-Z0-9]).+" + ) class DateFormat(Static): DATETIME = "%Y-%m-%d %H:%M:%S" - DATETIME_DF = '%Y-%m-%d %H:%M:%S.%f' - DATETIME_F = '%Y-%m-%d %H:%M:%S %f' - DATETIME_P = '%Y-%m-%d %I:%M:%S %p' + DATETIME_DF = "%Y-%m-%d %H:%M:%S.%f" + DATETIME_F = "%Y-%m-%d %H:%M:%S %f" + DATETIME_P = "%Y-%m-%d %I:%M:%S %p" DATETIME_T = "%Y-%m-%dT%H:%M:%S" DATETIME_TZ = "%Y-%m-%dT%H:%M:%SZ" DATETIME_TFZ = "%Y-%m-%dT%H:%M:%S.%fZ" DATETIME_TF = "%Y-%m-%dT%H:%M:%S.%f" DATETIME_ISO = "%Y-%m-%dT%H:%M:%S.%fTZD" - DATETIME_GMT = '%a, %d %b %Y %H:%M:%S GMT' - DATETIME_PS = '%a %b %d %H:%M:%S %Y' - DATETIME_GMT2 = '%b %d %H:%M:%S %Y GMT' - DATE = '%Y-%m-%d' + DATETIME_GMT = "%a, %d %b %Y %H:%M:%S GMT" + DATETIME_PS = "%a %b %d %H:%M:%S %Y" + DATETIME_GMT2 = "%b %d %H:%M:%S %Y GMT" + DATE = "%Y-%m-%d" # TIME = '%H:%M:%S' class Key(Static): - USER_ID = '_user_id' - IP_KEY = '_ip' - UA_KEY = '_ua' + USER_ID = "_user_id" + IP_KEY = "_ip" + UA_KEY = "_ua" - USER_HASH = '_user_hash' - USER_CACHE = '_cached_user' - PK_LIST = '_pk_list' + USER_HASH = "_user_hash" + USER_CACHE = "_cached_user" + PK_LIST = "_pk_list" - DATA = '_data' - HINTS = '_hints' - Q = '_q' - ID = '_id' - MERGE = '_merge' - META = '_meta' - ROUTER = '_Router' - DISABLE_CACHE = '_disable_cache' - INSTANCE = '_instance' + DATA = "_data" + HINTS = "_hints" + Q = "_q" + ID = "_id" + MERGE = "_merge" + META = "_meta" + ROUTER = "_Router" + DISABLE_CACHE = "_disable_cache" + INSTANCE = "_instance" class Attr(Static): - GT = '__gt__' - GE = '__ge__' - LT = '__lt__' - LE = '__le__' - - NEXT = '__next__' - COMMAND = '__command__' - SPEC = '__spec__' - LEN = '__len__' - DOC = '__doc__' - HASH = '__hash__' - CODE = '__code__' - MODULE = '__module__' - BASES = '__bases__' - NAME = '__name__' - FUNC = '__func__' - CALL = '__call__' - ARGS = '__args__' - INIT = '__init__' - DICT = '__dict__' - MAIN = '__main__' - ITER = '__iter__' - LOCK = '__locked__' - INNER = '__inner__' - PROXY = '__proxy__' - RELATED = '__related__' - STATUS = '__status__' - ORIGIN = '__origin__' - TARGET = '__target__' - CATEGORY = '__category__' - CAUSES = '__causes__' - CAUSE = '__cause__' - CLS = '__class__' - PARSER = '__parser__' - BUILTINS = '__builtins__' - ANNOTATES = '__annotations__' - ISOLATE = '__isolate__' - CONFIG = '__config__' - OPTIONS = '__options__' - CACHE = '__cache__' - VACUUM = '__vacuum__' - VALIDATE = '__validate__' - TEMPLATE = '__template__' - DATA = '__data__' - EXTRA = '__extra__' - ADD = '__ADD__' - MOD = '__MOD__' - REM = '__REM__' - - GETATTR = '__getattr__' - GETATTRIBUTE = '__getattribute__' - - GET = '__get__' - SET = '__set__' - DELETE = '__delete__' + GT = "__gt__" + GE = "__ge__" + LT = "__lt__" + LE = "__le__" + + NEXT = "__next__" + COMMAND = "__command__" + SPEC = "__spec__" + LEN = "__len__" + DOC = "__doc__" + HASH = "__hash__" + CODE = "__code__" + MODULE = "__module__" + BASES = "__bases__" + NAME = "__name__" + FUNC = "__func__" + CALL = "__call__" + ARGS = "__args__" + INIT = "__init__" + DICT = "__dict__" + MAIN = "__main__" + ITER = "__iter__" + LOCK = "__locked__" + INNER = "__inner__" + PROXY = "__proxy__" + RELATED = "__related__" + STATUS = 
"__status__" + ORIGIN = "__origin__" + TARGET = "__target__" + CATEGORY = "__category__" + CAUSES = "__causes__" + CAUSE = "__cause__" + CLS = "__class__" + PARSER = "__parser__" + BUILTINS = "__builtins__" + ANNOTATES = "__annotations__" + ISOLATE = "__isolate__" + CONFIG = "__config__" + OPTIONS = "__options__" + CACHE = "__cache__" + VACUUM = "__vacuum__" + VALIDATE = "__validate__" + TEMPLATE = "__template__" + DATA = "__data__" + EXTRA = "__extra__" + ADD = "__ADD__" + MOD = "__MOD__" + REM = "__REM__" + + GETATTR = "__getattr__" + GETATTRIBUTE = "__getattribute__" + + GET = "__get__" + SET = "__set__" + DELETE = "__delete__" class EndpointAttr(Static): - method = 'method' - alias = 'alias' - hook = 'hook' + method = "method" + alias = "alias" + hook = "hook" - main = 'main' - unit = 'unit' + main = "main" + unit = "unit" - before_hook = 'before_hook' - after_hook = 'after_hook' - error_hook = 'error_hook' - errors = 'errors' - excludes = 'excludes' + before_hook = "before_hook" + after_hook = "after_hook" + error_hook = "error_hook" + errors = "errors" + excludes = "excludes" ATOM_TYPES = (str, int, bool, float, type(None)) JSON_TYPES = (*ATOM_TYPES, list, dict) # types thar can directly dump to json COMMON_TYPES = (*JSON_TYPES, set, tuple, bytes) -COMMON_ERRORS = (AttributeError, TypeError, ValueError, IndexError, KeyError, UnicodeDecodeError) -HOOK_TYPES = (EndpointAttr.before_hook, EndpointAttr.after_hook, EndpointAttr.error_hook) +COMMON_ERRORS = ( + AttributeError, + TypeError, + ValueError, + IndexError, + KeyError, + UnicodeDecodeError, +) +HOOK_TYPES = ( + EndpointAttr.before_hook, + EndpointAttr.after_hook, + EndpointAttr.error_hook, +) UNIT_TYPES = (EndpointAttr.main, *HOOK_TYPES) DEFAULT_SECRET_NAMES = ( - 'password', - 'secret', - 'dsn', - 'sessionid', - 'pwd', - 'passphrase', - 'cookie', - 'authorization', - '_token', - '_key', + "password", + "secret", + "dsn", + "sessionid", + "pwd", + "passphrase", + "cookie", + "authorization", + "_token", + "_key", ) diff --git a/utilmeta/utils/constant/i18n.py b/utilmeta/utils/constant/i18n.py index b79b97b..2903e7a 100644 --- a/utilmeta/utils/constant/i18n.py +++ b/utilmeta/utils/constant/i18n.py @@ -192,10 +192,10 @@ class Locale(Static): UZ = "UZ" VN = "VN" CN = "CN" - HK = "HK" # HongKong (China) + HK = "HK" # HongKong (China) MO = "MO" SG = "SG" - TW = "TW" # Taiwan (China) + TW = "TW" # Taiwan (China) @classmethod def language(cls, locale: str, single: bool = True) -> Union[List[str], str]: @@ -209,130 +209,180 @@ def language(cls, locale: str, single: bool = True) -> Union[List[str], str]: class TimeZone(Static, ignore_duplicate=True): - UTC = 'UTC' - GMT = 'UTC' - WET = 'WET' - CET = 'CET' - MET = 'CET' - ECT = 'CET' - EET = 'EET' - MIT = 'Pacific/Apia' - HST = 'Pacific/Honolulu' - AST = 'America/Anchorage' - PST = 'America/Los_Angeles' - LOS_ANGELES = 'America/Los_Angeles' - MST = 'America/Denver' - PNT = 'America/Phoenix' - CST = 'America/Chicago' - CHICAGO = 'America/Chicago' - EST = 'America/New_York' - NEW_YORK = 'America/New_York' - IET = 'America/Indiana/Indianapolis' - PRT = 'America/Puerto_Rico' - CNT = 'America/St_Johns' - AGT = 'America/Argentina/Buenos_Aires' - BET = 'America/Sao_Paulo' - ART = 'Africa/Cairo' - CAT = 'Africa/Harare' - EAT = 'Africa/Addis_Ababa' - NET = 'Asia/Yerevan' - PLT = 'Asia/Karachi' - IST = 'Asia/Kolkata' - BST = 'Asia/Dhaka' - VST = 'Asia/Ho_Chi_Minh' - CTT = 'Asia/Shanghai' - SHANGHAI = 'Asia/Shanghai' - JST = 'Asia/Tokyo' - TOKYO = 'Asia/Tokyo' - ACT = 'Australia/Darwin' - DARWIN = 
'Australia/Darwin' - AET = 'Australia/Sydney' - SYDNEY = 'Australia/Sydney' - SST = 'Pacific/Guadalcanal' - NST = 'Pacific/Auckland' + UTC = "UTC" + GMT = "UTC" + WET = "WET" + CET = "CET" + MET = "CET" + ECT = "CET" + EET = "EET" + MIT = "Pacific/Apia" + HST = "Pacific/Honolulu" + AST = "America/Anchorage" + PST = "America/Los_Angeles" + LOS_ANGELES = "America/Los_Angeles" + MST = "America/Denver" + PNT = "America/Phoenix" + CST = "America/Chicago" + CHICAGO = "America/Chicago" + EST = "America/New_York" + NEW_YORK = "America/New_York" + IET = "America/Indiana/Indianapolis" + PRT = "America/Puerto_Rico" + CNT = "America/St_Johns" + AGT = "America/Argentina/Buenos_Aires" + BET = "America/Sao_Paulo" + ART = "Africa/Cairo" + CAT = "Africa/Harare" + EAT = "Africa/Addis_Ababa" + NET = "Asia/Yerevan" + PLT = "Asia/Karachi" + IST = "Asia/Kolkata" + BST = "Asia/Dhaka" + VST = "Asia/Ho_Chi_Minh" + CTT = "Asia/Shanghai" + SHANGHAI = "Asia/Shanghai" + JST = "Asia/Tokyo" + TOKYO = "Asia/Tokyo" + ACT = "Australia/Darwin" + DARWIN = "Australia/Darwin" + AET = "Australia/Sydney" + SYDNEY = "Australia/Sydney" + SST = "Pacific/Guadalcanal" + NST = "Pacific/Auckland" LANGUAGE_LOCALE_MAP = { - 'af': ['ZA'], - 'ar': ['AE', 'BH', 'DZ', 'EG', 'IQ', 'JO', 'KW', 'LB', 'LY', 'MA', 'OM', - 'QA', 'SA', 'SY', 'TN', 'YE'], - 'az': ['AZ', 'AZ'], - 'be': ['BY'], - 'bg': ['BG'], - 'bs': ['BA'], - 'ca': ['ES'], - 'cs': ['CZ'], - 'cy': ['GB'], - 'da': ['DK'], - 'de': ['AT', 'CH', 'DE', 'LI', 'LU'], - 'dv': ['MV'], - 'el': ['GR'], - 'en': ['AU', 'BZ', 'CA', 'CB', 'GB', 'IE', 'JM', 'NZ', 'PH', 'TT', 'US', 'ZA', 'ZW'], - 'es': ['AR', 'BO', 'CL', 'CO', 'CR', 'DO', 'EC', 'ES', 'ES', 'GT', 'HN', 'MX', 'NI', - 'PA', 'PE', 'PR', 'PY', 'SV', 'UY', 'VE'], - 'et': ['EE'], - 'eu': ['ES'], - 'fa': ['IR'], - 'fi': ['FI'], - 'fo': ['FO'], - 'fr': ['BE', 'CA', 'CH', 'FR', 'LU', 'MC'], - 'gl': ['ES'], - 'gu': ['IN'], - 'he': ['IL'], - 'hi': ['IN'], - 'hr': ['BA', 'HR'], - 'hu': ['HU'], - 'hy': ['AM'], - 'id': ['ID'], - 'is': ['IS'], - 'it': ['CH', 'IT'], - 'ja': ['JP'], - 'ka': ['GE'], - 'kk': ['KZ'], - 'kn': ['IN'], - 'ko': ['KR'], - 'ky': ['KG'], - 'lt': ['LT'], - 'lv': ['LV'], - 'mi': ['NZ'], - 'mk': ['MK'], - 'mn': ['MN'], - 'mr': ['IN'], - 'ms': ['BN', 'MY'], - 'mt': ['MT'], - 'nb': ['NO'], - 'nl': ['BE', 'NL'], - 'nn': ['NO'], - 'ns': ['ZA'], - 'pa': ['IN'], - 'pl': ['PL'], - 'pt': ['BR', 'PT'], - 'qu': ['BO', 'EC', 'PE'], - 'ro': ['RO'], - 'ru': ['RU'], - 'sa': ['IN'], - 'se': ['FI', 'NO', 'SE'], - 'sk': ['SK'], - 'sl': ['SI'], - 'sq': ['AL'], - 'sr': ['BA', 'SP'], - 'sv': ['FI', 'SE'], - 'sw': ['KE'], - 'ta': ['IN'], - 'te': ['IN'], - 'th': ['TH'], - 'tl': ['PH'], - 'tn': ['ZA'], - 'tr': ['TR'], - 'tt': ['RU'], - 'uk': ['UA'], - 'ur': ['PK'], - 'uz': ['UZ', 'UZ'], - 'vi': ['VN'], - 'xh': ['ZA'], - 'zh': ['CN', 'HK', 'MO', 'SG', 'TW'], - 'zu': ['ZA'], + "af": ["ZA"], + "ar": [ + "AE", + "BH", + "DZ", + "EG", + "IQ", + "JO", + "KW", + "LB", + "LY", + "MA", + "OM", + "QA", + "SA", + "SY", + "TN", + "YE", + ], + "az": ["AZ", "AZ"], + "be": ["BY"], + "bg": ["BG"], + "bs": ["BA"], + "ca": ["ES"], + "cs": ["CZ"], + "cy": ["GB"], + "da": ["DK"], + "de": ["AT", "CH", "DE", "LI", "LU"], + "dv": ["MV"], + "el": ["GR"], + "en": [ + "AU", + "BZ", + "CA", + "CB", + "GB", + "IE", + "JM", + "NZ", + "PH", + "TT", + "US", + "ZA", + "ZW", + ], + "es": [ + "AR", + "BO", + "CL", + "CO", + "CR", + "DO", + "EC", + "ES", + "ES", + "GT", + "HN", + "MX", + "NI", + "PA", + "PE", + "PR", + "PY", + "SV", + "UY", + "VE", + ], + "et": ["EE"], + "eu": ["ES"], + 
"fa": ["IR"], + "fi": ["FI"], + "fo": ["FO"], + "fr": ["BE", "CA", "CH", "FR", "LU", "MC"], + "gl": ["ES"], + "gu": ["IN"], + "he": ["IL"], + "hi": ["IN"], + "hr": ["BA", "HR"], + "hu": ["HU"], + "hy": ["AM"], + "id": ["ID"], + "is": ["IS"], + "it": ["CH", "IT"], + "ja": ["JP"], + "ka": ["GE"], + "kk": ["KZ"], + "kn": ["IN"], + "ko": ["KR"], + "ky": ["KG"], + "lt": ["LT"], + "lv": ["LV"], + "mi": ["NZ"], + "mk": ["MK"], + "mn": ["MN"], + "mr": ["IN"], + "ms": ["BN", "MY"], + "mt": ["MT"], + "nb": ["NO"], + "nl": ["BE", "NL"], + "nn": ["NO"], + "ns": ["ZA"], + "pa": ["IN"], + "pl": ["PL"], + "pt": ["BR", "PT"], + "qu": ["BO", "EC", "PE"], + "ro": ["RO"], + "ru": ["RU"], + "sa": ["IN"], + "se": ["FI", "NO", "SE"], + "sk": ["SK"], + "sl": ["SI"], + "sq": ["AL"], + "sr": ["BA", "SP"], + "sv": ["FI", "SE"], + "sw": ["KE"], + "ta": ["IN"], + "te": ["IN"], + "th": ["TH"], + "tl": ["PH"], + "tn": ["ZA"], + "tr": ["TR"], + "tt": ["RU"], + "uk": ["UA"], + "ur": ["PK"], + "uz": ["UZ", "UZ"], + "vi": ["VN"], + "xh": ["ZA"], + "zh": ["CN", "HK", "MO", "SG", "TW"], + "zu": ["ZA"], } LANGUAGES = Language.gen() -LOCALES = Locale.gen() \ No newline at end of file +LOCALES = Locale.gen() diff --git a/utilmeta/utils/constant/vendor.py b/utilmeta/utils/constant/vendor.py index 04dc625..918e082 100644 --- a/utilmeta/utils/constant/vendor.py +++ b/utilmeta/utils/constant/vendor.py @@ -2,30 +2,30 @@ class DB(Static): - PostgreSQL = 'postgresql' - MySQL = 'mysql' - Oracle = 'oracle' - SQLite = 'sqlite' + PostgreSQL = "postgresql" + MySQL = "mysql" + Oracle = "oracle" + SQLite = "sqlite" class AgentDevice(Static): - pc = 'pc' - mobile = 'mobile' - bot = 'bot' - tablet = 'tablet' - email = 'email' + pc = "pc" + mobile = "mobile" + bot = "bot" + tablet = "tablet" + email = "email" class AgentBrowser(Static): - chrome = 'chrome' - firefox = 'firefox' - safari = 'safari' - edge = 'edge' - opera = 'opera' - ie = 'ie' # Internet Explorer + chrome = "chrome" + firefox = "firefox" + safari = "safari" + edge = "edge" + opera = "opera" + ie = "ie" # Internet Explorer class AgentOS(Static): - mac = 'mac' - windows = 'windows' - linux = 'linux' + mac = "mac" + windows = "windows" + linux = "linux" diff --git a/utilmeta/utils/constant/web.py b/utilmeta/utils/constant/web.py index 57be0ad..8746d1e 100644 --- a/utilmeta/utils/constant/web.py +++ b/utilmeta/utils/constant/web.py @@ -1,69 +1,69 @@ from ..datastructure import Static import re -SCHEME = '://' -LOCAL = 'localhost' -LOCAL_IP = '127.0.0.1' -ALL_IP = '0.0.0.0' +SCHEME = "://" +LOCAL = "localhost" +LOCAL_IP = "127.0.0.1" +ALL_IP = "0.0.0.0" PATH_REGEX = re.compile("{([a-zA-Z][a-zA-Z0-9_]*)}") class HTTPMethod(Static): - GET = 'GET' - PUT = 'PUT' - POST = 'POST' - PATCH = 'PATCH' - DELETE = 'DELETE' - HEAD = 'HEAD' - TRACE = 'TRACE' - OPTIONS = 'OPTIONS' - CONNECT = 'CONNECT' + GET = "GET" + PUT = "PUT" + POST = "POST" + PATCH = "PATCH" + DELETE = "DELETE" + HEAD = "HEAD" + TRACE = "TRACE" + OPTIONS = "OPTIONS" + CONNECT = "CONNECT" class CommonMethod(Static): - GET = 'get' - PUT = 'put' - POST = 'post' - PATCH = 'patch' - DELETE = 'delete' + GET = "get" + PUT = "put" + POST = "post" + PATCH = "patch" + DELETE = "delete" class MetaMethod(Static): - HEAD = 'head' - TRACE = 'trace' - OPTIONS = 'options' - CONNECT = 'connect' + HEAD = "head" + TRACE = "trace" + OPTIONS = "options" + CONNECT = "connect" class RequestType(Static): - PLAIN = 'text/plain' - JSON = 'application/json' - FORM_URLENCODED = 'application/x-www-form-urlencoded' - FORM_DATA = 'multipart/form-data' - XML 
= 'text/xml' - HTML = 'text/html' - APP_XML = 'application/xml' - OCTET_STREAM = 'application/octet-stream' + PLAIN = "text/plain" + JSON = "application/json" + FORM_URLENCODED = "application/x-www-form-urlencoded" + FORM_DATA = "multipart/form-data" + XML = "text/xml" + HTML = "text/html" + APP_XML = "application/xml" + OCTET_STREAM = "application/octet-stream" class GeneralType(Static): - JSON = 'json' - TEXT = 'text' - IMAGE = 'image' - AUDIO = 'audio' - VIDEO = 'video' - HTML = 'html' - OCTET_STREAM = 'octet-stream' + JSON = "json" + TEXT = "text" + IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" + HTML = "html" + OCTET_STREAM = "octet-stream" @classmethod def get(cls, content_type: str): if not content_type: return None - if '/' not in content_type: + if "/" not in content_type: return content_type - if ';' in content_type: - content_type = content_type.split(';')[0] - per, suf = content_type.split('/') + if ";" in content_type: + content_type = content_type.split(";")[0] + per, suf = content_type.split("/") if suf in GENERAL_TYPES: return suf elif per in GENERAL_TYPES: @@ -72,102 +72,102 @@ def get(cls, content_type: str): @classmethod def content_type(cls, type: str): - if ';' in type: - type = type.split(';')[0] - if '/' in type: + if ";" in type: + type = type.split(";")[0] + if "/" in type: return type content_map = { - cls.HTML: 'text/html', - cls.JSON: 'application/json', + cls.HTML: "text/html", + cls.JSON: "application/json", # cls.AUDIO: 'audio/*', # cls.VIDEO: 'video/*', # cls.IMAGE: 'image/*', # cls.TEXT: 'text/*', - cls.OCTET_STREAM: 'application/octet-stream', + cls.OCTET_STREAM: "application/octet-stream", } return content_map.get(type) class Scheme(Static): - HTTP = 'http' - HTTPS = 'https' - FTP = 'ftp' - FTPS = 'ftps' - SFTP = 'sftp' - SSH = 'ssh' - WS = 'ws' - WSS = 'wss' - MQTT = 'mqtt' - SMTP = 'smtp' - POP = 'pop' - UDP = 'udp' + HTTP = "http" + HTTPS = "https" + FTP = "ftp" + FTPS = "ftps" + SFTP = "sftp" + SSH = "ssh" + WS = "ws" + WSS = "wss" + MQTT = "mqtt" + SMTP = "smtp" + POP = "pop" + UDP = "udp" class WebSocketEventType(Static): - open = 'open' - close = 'close' - error = 'error' - message = 'message' + open = "open" + close = "close" + error = "error" + message = "message" class AuthScheme(Static): - BASIC = 'basic' - DIGEST = 'digest' - BEARER = 'bearer' - TOKEN = 'token' - HOBA = 'hoba' - MUTUAL = 'mutual' + BASIC = "basic" + DIGEST = "digest" + BEARER = "bearer" + TOKEN = "token" + HOBA = "hoba" + MUTUAL = "mutual" class Header(Static): - CONNECTION = 'Connection' - COOKIE = 'Cookie' + CONNECTION = "Connection" + COOKIE = "Cookie" - REMOTE_ADDR = 'REMOTE_ADDR' + REMOTE_ADDR = "REMOTE_ADDR" FORWARDED_FOR = "HTTP_X_FORWARDED_FOR" - AUTHORIZATION = 'Authorization' - - WWW_AUTH = 'WWW-Authenticate' - - ACCEPT = 'Accept' - ACCEPT_LANGUAGE = 'Accept-Language' - ACCEPT_ENCODING = 'Accept-Encoding' - CONTENT_LANGUAGE = 'Content-Language' - - REFERER = 'Referer' - UPGRADE = 'Upgrade' - - SET_COOKIE = 'Set-Cookie' - USER_AGENT = 'User-Agent' - - VARY = 'Vary' - EXPIRES = 'Expires' - PRAGMA = 'Pragma' - CACHE_CONTROL = 'Cache-Control' - ETAG = 'Etag' - LAST_MODIFIED = 'Last-Modified' - - IF_UNMODIFIED_SINCE = 'If-Unmodified-Since' - IF_MODIFIED_SINCE = 'If-Modified-Since' - IF_NONE_MATCH = 'If-None-Match' - IF_MATCH = 'If-Match' - - LENGTH = 'Content-Length' - TYPE = 'Content-Type' - ALLOW = 'Allow' - ORIGIN = 'Origin' - ALLOW_ORIGIN = 'Access-Control-Allow-Origin' - ACCESS_MAX_AGE = 'Access-Control-Max-Age' - ALLOW_CREDENTIALS = 
'Access-Control-Allow-Credentials' - ALLOW_METHODS = 'Access-Control-Allow-Methods' - EXPOSE_HEADERS = 'Access-Control-Expose-Headers' - ALLOW_HEADERS = 'Access-Control-Allow-Headers' - OPTIONS_METHOD = 'Access-Control-Request-Method' - OPTIONS_HEADERS = 'Access-Control-Request-Headers' + AUTHORIZATION = "Authorization" + + WWW_AUTH = "WWW-Authenticate" + + ACCEPT = "Accept" + ACCEPT_LANGUAGE = "Accept-Language" + ACCEPT_ENCODING = "Accept-Encoding" + CONTENT_LANGUAGE = "Content-Language" + + REFERER = "Referer" + UPGRADE = "Upgrade" + + SET_COOKIE = "Set-Cookie" + USER_AGENT = "User-Agent" + + VARY = "Vary" + EXPIRES = "Expires" + PRAGMA = "Pragma" + CACHE_CONTROL = "Cache-Control" + ETAG = "Etag" + LAST_MODIFIED = "Last-Modified" + + IF_UNMODIFIED_SINCE = "If-Unmodified-Since" + IF_MODIFIED_SINCE = "If-Modified-Since" + IF_NONE_MATCH = "If-None-Match" + IF_MATCH = "If-Match" + + LENGTH = "Content-Length" + TYPE = "Content-Type" + ALLOW = "Allow" + ORIGIN = "Origin" + ALLOW_ORIGIN = "Access-Control-Allow-Origin" + ACCESS_MAX_AGE = "Access-Control-Max-Age" + ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials" + ALLOW_METHODS = "Access-Control-Allow-Methods" + EXPOSE_HEADERS = "Access-Control-Expose-Headers" + ALLOW_HEADERS = "Access-Control-Allow-Headers" + OPTIONS_METHOD = "Access-Control-Request-Method" + OPTIONS_HEADERS = "Access-Control-Request-Headers" @classmethod def attr_name(cls, key: str): - return key.replace('-', '_').lower() + return key.replace("-", "_").lower() class TCPStatus(Static): @@ -192,36 +192,49 @@ class TCPStatus(Static): TCPStatus.TIME_WAIT, TCPStatus.FIN_WAIT1, TCPStatus.FIN_WAIT2, - TCPStatus.LAST_ACK + TCPStatus.LAST_ACK, ] HTTP = Scheme.HTTP + SCHEME HTTPS = Scheme.HTTPS + SCHEME REQUEST_TYPES = RequestType.gen() CONTENT_TYPE = Header.attr_name(Header.TYPE) -DICT_TYPES = (RequestType.JSON, RequestType.XML, RequestType.FORM_URLENCODED, RequestType.FORM_DATA) +DICT_TYPES = ( + RequestType.JSON, + RequestType.XML, + RequestType.FORM_URLENCODED, + RequestType.FORM_DATA, +) GENERAL_TYPES = GeneralType.gen() -STREAM_TYPES = (GeneralType.IMAGE, GeneralType.AUDIO, GeneralType.VIDEO, GeneralType.OCTET_STREAM) +STREAM_TYPES = ( + GeneralType.IMAGE, + GeneralType.AUDIO, + GeneralType.VIDEO, + GeneralType.OCTET_STREAM, +) SCHEMES = Scheme.gen() ALLOW_HEADERS = (Header.TYPE, Header.LENGTH, Header.ORIGIN) HTTP_METHODS = HTTPMethod.gen() ISOLATED_HEADERS = {Header.LENGTH, Header.TYPE, Header.CONNECTION} SAFE_METHODS = (HTTPMethod.GET, HTTPMethod.OPTIONS, HTTPMethod.HEAD, HTTPMethod.TRACE) DEFAULT_IDEMPOTENT_METHODS = (*SAFE_METHODS, HTTPMethod.PUT, HTTPMethod.DELETE) -UNSAFE_METHODS = HAS_BODY_METHODS = (CommonMethod.POST, CommonMethod.PUT, CommonMethod.PATCH, CommonMethod.DELETE) +UNSAFE_METHODS = HAS_BODY_METHODS = ( + CommonMethod.POST, + CommonMethod.PUT, + CommonMethod.PATCH, + CommonMethod.DELETE, +) HTTP_METHODS_LOWER = [m.lower() for m in HTTP_METHODS] COMMON_METHODS = CommonMethod.gen() META_METHODS = MetaMethod.gen() METHODS = COMMON_METHODS + META_METHODS -SECURE_SCHEMES = { - 'http': 'https', - 'ws': 'wss', - 'ftp': 'ftps' -} +SECURE_SCHEMES = {"http": "https", "ws": "wss", "ftp": "ftps"} STATUS_WITHOUT_BODY = (204, 205, 304) MESSAGE_STATUSES = list(range(100, 103)) SUCCESS_STATUSES = list(range(200, 208)) REDIRECT_STATUSES = list(range(300, 308)) -REQUEST_ERROR_STATUSES = list(range(400, 419)) + list(range(421, 427)) + [428, 429, 431, 449, 451] +REQUEST_ERROR_STATUSES = ( + list(range(400, 419)) + list(range(421, 427)) + [428, 429, 431, 449, 451] +) 
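
The status-range constants that close this hunk partition HTTP status codes into the classes used elsewhere in the framework. A minimal self-contained sketch of how the ranges combine; the classify_status helper is hypothetical (not part of utilmeta) and its values simply mirror the constants defined in this hunk:

MESSAGE_STATUSES = list(range(100, 103))
SUCCESS_STATUSES = list(range(200, 208))
REDIRECT_STATUSES = list(range(300, 308))
REQUEST_ERROR_STATUSES = (
    list(range(400, 419)) + list(range(421, 427)) + [428, 429, 431, 449, 451]
)
SERVER_ERROR_STATUSES = list(range(500, 511)) + [600]


def classify_status(status: int) -> str:
    # hypothetical helper: check the disjoint ranges in ascending order
    if status in MESSAGE_STATUSES:
        return "informational"
    if status in SUCCESS_STATUSES:
        return "success"
    if status in REDIRECT_STATUSES:
        return "redirect"
    if status in REQUEST_ERROR_STATUSES:
        return "client error"
    if status in SERVER_ERROR_STATUSES:
        return "server error"
    return "unknown"


assert classify_status(429) == "client error"
assert classify_status(502) == "server error"
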
SERVER_ERROR_STATUSES = list(range(500, 511)) + [600] DEFAULT_RETRY_ON_STATUSES = (408, 429, 500, 502, 503, 504) ERROR_STATUS = { diff --git a/utilmeta/utils/context.py b/utilmeta/utils/context.py index f0ff566..3982dd7 100644 --- a/utilmeta/utils/context.py +++ b/utilmeta/utils/context.py @@ -9,7 +9,7 @@ class ParserProperty: - def __init__(self, prop: Union[Type['Property'], 'Property'], field: ParserField): + def __init__(self, prop: Union[Type["Property"], "Property"], field: ParserField): self.prop = prop self.field = field self.get = partial(self.prop.getter, field=self.field) @@ -75,7 +75,7 @@ def setter(self, obj, value, field: ParserField = None): class DuplicateContextProperty(ValueError): - def __init__(self, msg='', ident: str = None): + def __init__(self, msg="", ident: str = None): super().__init__(msg) self.ident = ident @@ -84,12 +84,16 @@ class ContextWrapper: """ A universal context parser, often used to process Request context """ + context_cls = object default_property = None - def __init__(self, parser: BaseParser, - default_properties: dict = None, - excluded_names: List[str] = None): + def __init__( + self, + parser: BaseParser, + default_properties: dict = None, + excluded_names: List[str] = None, + ): properties = {} attrs = {} ident_props = {} @@ -104,12 +108,14 @@ def __init__(self, parser: BaseParser, prop = None # detect property from type input including Union and Optional for origin in val.input_origins: - prop_field = getattr(origin, '__field__', None) + prop_field = getattr(origin, "__field__", None) if prop_field: if isinstance(prop_field, Property): prop = prop_field break - elif inspect.isclass(prop_field) and issubclass(prop_field, Property): + elif inspect.isclass(prop_field) and issubclass( + prop_field, Property + ): prop = prop_field() break if not prop: @@ -129,7 +135,7 @@ def __init__(self, parser: BaseParser, self.attrs: Dict[str, ParserProperty] = attrs self.parser = parser - def init_prop(self, prop, val) -> ParserProperty: # noqa, to be inherit + def init_prop(self, prop, val) -> ParserProperty: # noqa, to be inherited return prop.init(val) def parse_context(self, context: object) -> dict: diff --git a/utilmeta/utils/datastructure.py b/utilmeta/utils/datastructure.py index b218e48..30f42aa 100644 --- a/utilmeta/utils/datastructure.py +++ b/utilmeta/utils/datastructure.py @@ -12,10 +12,10 @@ def __error__(self, *args, **kwargs): __setitem__ = __error__ def __str__(self): - return f'{self.__class__.__name__}({super().__repr__()})' + return f"{self.__class__.__name__}({super().__repr__()})" def __repr__(self): - return f'{self.__class__.__name__}({super().__repr__()})' + return f"{self.__class__.__name__}({super().__repr__()})" setdefault = __error__ pop = __error__ @@ -29,10 +29,10 @@ def error(self, *args, **kwargs): raise AttributeError("ImmutableList can not modify value") def __str__(self): - return f'{self.__class__.__name__}({super().__repr__()})' + return f"{self.__class__.__name__}({super().__repr__()})" def __repr__(self): - return f'{self.__class__.__name__}({super().__repr__()})' + return f"{self.__class__.__name__}({super().__repr__()})" append = error clear = error @@ -108,20 +108,25 @@ class Static: def __init_subclass__(cls, ignore_duplicate: bool = False, **kwargs): attrs = [] for name, attr in cls.__dict__.items(): - if name.startswith('_'): + if name.startswith("_"): continue if attr in attrs and not ignore_duplicate: - raise ValueError(f'Static value cannot be duplicated, got {attr}') + raise ValueError(f"Static value cannot be 
duplicated, got {attr}") attrs.append(attr) @classmethod def gen(cls) -> tuple: attrs = [] for name, attr in cls.__dict__.items(): - if '__' in name: + if "__" in name: continue - if callable(attr) or inspect.isfunction(attr) or inspect.ismethod(attr) or\ - isinstance(attr, classmethod) or isinstance(attr, staticmethod): + if ( + callable(attr) + or inspect.isfunction(attr) + or inspect.ismethod(attr) + or isinstance(attr, classmethod) + or isinstance(attr, staticmethod) + ): continue attrs.append(attr) return tuple(attrs) @@ -130,10 +135,15 @@ def gen(cls) -> tuple: def dict(cls, reverse: bool = False, lower: bool = False): attrs = {} for name, attr in cls.__dict__.items(): - if '__' in name: + if "__" in name: continue - if callable(attr) or inspect.isfunction(attr) or inspect.ismethod(attr) or \ - isinstance(attr, classmethod) or isinstance(attr, staticmethod): + if ( + callable(attr) + or inspect.isfunction(attr) + or inspect.ismethod(attr) + or isinstance(attr, classmethod) + or isinstance(attr, staticmethod) + ): continue name = name.lower() if lower else name if reverse: @@ -268,13 +278,18 @@ class LazyLoadObject(SimpleLazyObject): def __init__(self, ref: str): def _load_func(): from .functional import import_obj + return import_obj(ref) - self.__dict__['_ref'] = ref + + self.__dict__["_ref"] = ref super().__init__(_load_func) except (ModuleNotFoundError, ImportError): + class classonlymethod(classmethod): def __get__(self, instance, cls=None): if instance is not None: - raise AttributeError("This method is available only on the class, not on instances.") + raise AttributeError( + "This method is available only on the class, not on instances." + ) return super().__get__(instance, cls) diff --git a/utilmeta/utils/decorator.py b/utilmeta/utils/decorator.py index d1d3e0f..eda0915 100644 --- a/utilmeta/utils/decorator.py +++ b/utilmeta/utils/decorator.py @@ -9,15 +9,34 @@ from utilmeta.utils.error import Error import warnings -__all__ = ['omit', 'error_convert', 'handle_retries', 'cached_property', - 'awaitable', 'async_to_sync', 'adapt_async', - 'handle_parse', 'handle_timeout', 'ignore_errors', 'static_require'] - - -def ignore_errors(_f=None, *, default=None, log: bool = True, - log_detail: bool = True, errors=(Exception,), on_finally=None): +__all__ = [ + "omit", + "error_convert", + "handle_retries", + "cached_property", + "awaitable", + "async_to_sync", + "adapt_async", + "handle_parse", + "handle_timeout", + "ignore_errors", + "static_require", +] + + +def ignore_errors( + _f=None, + *, + default=None, + log: bool = True, + log_detail: bool = True, + errors=(Exception,), + on_finally=None, +): if on_finally: - assert callable(on_finally), f'@ignore_errors on_finally must be a callable, got {on_finally}' + assert callable( + on_finally + ), f"@ignore_errors on_finally must be a callable, got {on_finally}" def decorator(f): @wraps(f) @@ -26,7 +45,7 @@ def wrapper(*args, **kwargs): return f(*args, **kwargs) except errors as e: if log: - warnings.warn(f'IGNORED ERROR for {f.__name__}: {e}') + warnings.warn(f"IGNORED ERROR for {f.__name__}: {e}") if log_detail: Error(e).log(console=True) # to avoid a public mutable value (like dict) cause unpredictable result @@ -35,7 +54,9 @@ def wrapper(*args, **kwargs): finally: if on_finally: on_finally() + return wrapper + if _f: return decorator(_f) return decorator @@ -43,7 +64,7 @@ def wrapper(*args, **kwargs): def static_require(*args_func: Callable, runtime: bool = True): def not_implement(*_, **__): - raise NotImplementedError('you current 
settings does not support this method') + raise NotImplementedError("your current settings do not support this method") @ignore_errors(default=False, log=False) def satisfied(): @@ -59,12 +80,15 @@ def decorator(f): else: return not_implement else: + @wraps(f) def wrapper(*args, **kwargs): if satisfied(): return f(*args, **kwargs) return not_implement() + return wrapper + return decorator @@ -81,7 +105,9 @@ def wrapper(*args, **kwargs): if not isinstance(err, tuple(errors)): raise Error().throw() raise Error().throw(target) + return wrapper + return deco @@ -98,16 +124,22 @@ def wrapper(*args, **kwargs): r = async_result.get(timeout.total_seconds()) except multiprocessing.context.TimeoutError: # pool.terminate() - raise TimeoutError(f"function <{f.__name__}> execute beyond expect" - f" time limit {timeout.total_seconds()} seconds") + raise TimeoutError( + f"function <{f.__name__}> execute beyond expect" + f" time limit {timeout.total_seconds()} seconds" + ) finally: pool.close() return r + return wrapper + return decorator -def handle_retries(retries: int = 2, on_errors=None, retry_interval: Union[float, Callable] = None): +def handle_retries( + retries: int = 2, on_errors=None, retry_interval: Union[float, Callable] = None +): assert retries > 1 on_errors = on_errors or Exception @@ -126,11 +158,13 @@ def wrapper(*args, **kwargs): except on_errors as e: errors.append(e) if not errors: - raise RuntimeError('Invalid retry status') + raise RuntimeError("Invalid retry status") if len(errors) == 1: raise Error(errors[0]).throw() raise CombinedError(*errors) + return wrapper + return decorator @@ -154,7 +188,10 @@ def handler(func, *args, **kwargs): @wraps(f) def wrapper(*args, **kwargs): args = [f] + [a for a in args] - threading.Thread(target=handler, args=args, kwargs=kwargs, daemon=daemon).start() + threading.Thread( + target=handler, args=args, kwargs=kwargs, daemon=daemon + ).start() + return wrapper @@ -168,18 +205,19 @@ class cached_property: The optional ``name`` argument is obsolete as of Python 3.6 and will be deprecated in Django 4.0 (#30127). """ + name = None @staticmethod def func(instance): raise TypeError( - 'Cannot use cached_property instance without calling ' - '__set_name__() on it.' + "Cannot use cached_property instance without calling " + "__set_name__() on it." ) def __init__(self, func, name=None): self.real_func = func - self.__doc__ = getattr(func, '__doc__') + self.__doc__ = getattr(func, "__doc__") def __set_name__(self, owner, name): if self.name is None: @@ -204,12 +242,16 @@ def __get__(self, instance, cls=None): import inspect + _CO_NESTED = inspect.CO_NESTED -_CO_FROM_COROUTINE = inspect.CO_COROUTINE | inspect.CO_ITERABLE_COROUTINE | inspect.CO_ASYNC_GENERATOR +_CO_FROM_COROUTINE = ( + inspect.CO_COROUTINE | inspect.CO_ITERABLE_COROUTINE | inspect.CO_ASYNC_GENERATOR +) def from_coroutine(level=2, _cache={}): from sys import _getframe + f_code = _getframe(level).f_code if f_code in _cache: return _cache[f_code] @@ -228,7 +270,7 @@ def from_coroutine(level=2, _cache={}): # # Where func() is some function that we've wrapped with one of the decorators # below. 
If so, the code object is nested and has a name such as or - if f_code.co_flags & _CO_NESTED and f_code.co_name[0] == '<': + if f_code.co_flags & _CO_NESTED and f_code.co_name[0] == "<": return from_coroutine(level + 2) else: _cache[f_code] = False @@ -236,7 +278,6 @@ def from_coroutine(level=2, _cache={}): def adapt_async(f=None, close_conn=True): - def decorator(func): if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): return func @@ -247,6 +288,7 @@ def close_wrapper(*args, **kwargs): finally: if close_conn: from django.db import connections + if isinstance(close_conn, str): conn = connections[close_conn] if conn: @@ -264,18 +306,21 @@ def wrapper(*args, **kwargs): if service.asynchronous: return service.pool.get_result(close_wrapper, *args, **kwargs) return func(*args, **kwargs) + return wrapper + if f: return decorator(f) return decorator from contextvars import ContextVar -from_thread = ContextVar('from_thread') + +from_thread = ContextVar("from_thread") def awaitable(syncfunc, bind_service: bool = False, close_conn: bool = False): - ''' + """ Decorator that allows an asynchronous function to be paired with a synchronous function in a single function call. The selection of which function executes depends on the calling context. For example: @@ -294,10 +339,13 @@ async def bar(): ... r = await spam(s, 1024) # Calls async function (B) above ... - ''' + """ + def decorate(asyncfunc): - if not inspect.iscoroutinefunction(asyncfunc) and not inspect.isasyncgenfunction(asyncfunc): - raise TypeError(f'{asyncfunc} must be async def function') + if not inspect.iscoroutinefunction( + asyncfunc + ) and not inspect.isasyncgenfunction(asyncfunc): + raise TypeError(f"{asyncfunc} must be async def function") # origin = None if isinstance(syncfunc, (classmethod, staticmethod)): @@ -307,7 +355,9 @@ def decorate(asyncfunc): sync_func = syncfunc if inspect.signature(sync_func) != inspect.signature(asyncfunc): - raise TypeError(f'{sync_func.__name__} and async {asyncfunc.__name__} have different signatures') + raise TypeError( + f"{sync_func.__name__} and async {asyncfunc.__name__} have different signatures" + ) @wraps(asyncfunc) def wrapper(*args, **kwargs): @@ -322,7 +372,9 @@ def wrapper(*args, **kwargs): else: if service.asynchronous: import utilmeta - if not getattr(utilmeta, '_cmd_env', False): + + if not getattr(utilmeta, "_cmd_env", False): + def sync_func_wrapper(*_, **__): from_thread.set(True) try: @@ -331,22 +383,30 @@ def sync_func_wrapper(*_, **__): from_thread.set(False) if close_conn: from django.db import connections + connections.close_all() - return service.pool.get_result(sync_func_wrapper, *args, **kwargs) + + return service.pool.get_result( + sync_func_wrapper, *args, **kwargs + ) return sync_func(*args, **kwargs) + wrapper._syncfunc = sync_func wrapper._asyncfunc = asyncfunc wrapper._awaitable = True wrapper.__doc__ = sync_func.__doc__ or asyncfunc.__doc__ return wrapper + return decorate try: from asgiref.sync import async_to_sync except ImportError: + def async_to_sync(to_await): import asyncio + async_response = [] def wrapper(*args, **kwargs): @@ -369,4 +429,5 @@ async def run_and_capture_result(): coroutine = run_and_capture_result() loop.run_until_complete(coroutine) return async_response[0] + return wrapper diff --git a/utilmeta/utils/error.py b/utilmeta/utils/error.py index ac47105..535471f 100644 --- a/utilmeta/utils/error.py +++ b/utilmeta/utils/error.py @@ -8,11 +8,13 @@ if TYPE_CHECKING: from utilmeta.core.request.base import Request 
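
The `TYPE_CHECKING` guard above pairs with the quoted `"Request"` annotation in the hunk that follows: the import is evaluated by type checkers only, never at runtime, which is the standard way to keep the annotation resolvable without paying the import cost or risking an import cycle. A minimal self-contained sketch of the pattern, with hypothetical module and class names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated by type checkers only; skipped at runtime, so importing
    # this module never pulls in the request module
    from mypackage.request import Request


class Error:
    def __init__(self, e: Exception = None, request: "Request" = None):
        # the string annotation defers name resolution to the type checker
        self.exc = e
        self.request = request
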
-CAUSE_DIVIDER = '\n# The above exception was the direct cause of the following exception:\n' +CAUSE_DIVIDER = ( + "\n# The above exception was the direct cause of the following exception:\n" +) class Error: - def __init__(self, e: Exception = None, request: 'Request' = None): + def __init__(self, e: Exception = None, request: "Request" = None): if isinstance(e, Exception): self.exc = e self.type = e.__class__ @@ -28,10 +30,10 @@ def __init__(self, e: Exception = None, request: 'Request' = None): self.exc_traceback = exc_traceback self.locals = {} - self.current_traceback = '' - self.traceback = '' - self.variable_info = '' - self.full_info = '' + self.current_traceback = "" + self.traceback = "" + self.variable_info = "" + self.full_info = "" self.ts = time.time() # request context @@ -41,7 +43,7 @@ def setup(self, from_errors: list = ()): if self.current_traceback: return # FIXME: lots of performance cost in this function - self.current_traceback = ''.join(traceback.format_tb(self.exc_traceback)) + self.current_traceback = "".join(traceback.format_tb(self.exc_traceback)) self.traceback = self.current_traceback if not from_errors: try: @@ -63,20 +65,20 @@ def setup(self, from_errors: list = ()): variables = [] if self.locals: - variables.append('Exception Local Variables:') + variables.append("Exception Local Variables:") for key, val in self.locals.items(): if key.startswith(SEG) and key.endswith(SEG): continue try: - variables.append(f'{key} = {readable(val, max_length=100)}') + variables.append(f"{key} = {readable(val, max_length=100)}") except Exception as e: - print(f'Variable <{key}> serialize error: {e}') - self.variable_info = '\n'.join(variables) - self.full_info = '\n'.join([self.message, *variables]) + print(f"Variable <{key}> serialize error: {e}") + self.variable_info = "\n".join(variables) + self.full_info = "\n".join([self.message, *variables]) # self.record_disabled = getattr(self.exc, 'record_disabled', False) def __str__(self): - return f'<{self.type.__name__}: {str(self.exc)}>' + return f"<{self.type.__name__}: {str(self.exc)}>" @property def exception(self): @@ -84,11 +86,7 @@ def exception(self): @property def message(self) -> str: - return '{0}{1}: {2}'.format( - self.traceback, - self.type.__name__, - self.exc - ) + return "{0}{1}: {2}".format(self.traceback, self.type.__name__, self.exc) # @property # def root_cause(self) -> Exception: @@ -128,22 +126,22 @@ def status(self) -> int: return self.get_status(default=500) def get_status(self, default=None): - status = getattr(self.exc, 'status', None) + status = getattr(self.exc, "status", None) if isinstance(status, int) and 100 <= status <= 600: return status return default @property def result(self): - return getattr(self.exc, 'result', None) + return getattr(self.exc, "result", None) @property def state(self): - return getattr(self.exc, 'state', None) + return getattr(self.exc, "state", None) @property def headers(self): - return getattr(self.exc, 'headers', None) + return getattr(self.exc, "headers", None) def log(self, console: bool = False) -> int: if not self.full_info: @@ -160,12 +158,12 @@ def cause_func(self): stk = traceback.extract_tb(self.exc_traceback, 1) return stk[0][2] - def throw(self, type=None, prepend: str = '', **kwargs): + def throw(self, type=None, prepend: str = "", **kwargs): if not (inspect.isclass(type) and issubclass(type, Exception)): type = None type = type or self.type if prepend or not isinstance(self.exc, type): - e = type(f'{prepend}{self.exc}', **kwargs) # noqa + e = 
type(f"{prepend}{self.exc}", **kwargs) # noqa e.__cause__ = self.exc # setattr(e, Attr.CAUSES, self.get_causes()) else: @@ -174,7 +172,9 @@ def throw(self, type=None, prepend: str = '', **kwargs): # cause in that way can track the original variables return e - def get_hook(self, hooks: Dict[Type[Exception], Callable], exact: bool = False) -> Optional[Callable]: + def get_hook( + self, hooks: Dict[Type[Exception], Callable], exact: bool = False + ) -> Optional[Callable]: if not hooks: return None diff --git a/utilmeta/utils/exceptions/config.py b/utilmeta/utils/exceptions/config.py index ddeb94c..fa28d4a 100644 --- a/utilmeta/utils/exceptions/config.py +++ b/utilmeta/utils/exceptions/config.py @@ -1,5 +1,3 @@ - - class NotConfigured(NotImplementedError): def __init__(self, config_cls): self.config_cls = config_cls @@ -7,7 +5,7 @@ def __init__(self, config_cls): @property def msg(self): - return f'Config: {self.config_cls} not configured' + return f"Config: {self.config_cls} not configured" class SettingNotConfigured(NotConfigured): @@ -17,7 +15,7 @@ def __init__(self, config_cls, item: str): @property def msg(self): - return f'Config: {self.config_cls}.{self.item} not configured' + return f"Config: {self.config_cls}.{self.item} not configured" class ConfigError(Exception): diff --git a/utilmeta/utils/exceptions/http.py b/utilmeta/utils/exceptions/http.py index 7c7dbe4..15de549 100644 --- a/utilmeta/utils/exceptions/http.py +++ b/utilmeta/utils/exceptions/http.py @@ -9,19 +9,22 @@ class HttpError(Exception): record_disabled = False def __str__(self): - msg = str(self.message or '') + msg = str(self.message or "") head = self.__class__.__name__ if msg.startswith(head): return msg - return f'{head}: {self.message}' - - def __init__(self, message: str = None, *, - state: Union[str, int] = None, - status: int = None, - result=None, - extra: dict = None, - detail: Union[dict, list] = (), - ): + return f"{head}: {self.message}" + + def __init__( + self, + message: str = None, + *, + state: Union[str, int] = None, + status: int = None, + result=None, + extra: dict = None, + detail: Union[dict, list] = (), + ): self.message = str(message) self.state = state self.result = result @@ -82,8 +85,14 @@ class BadRequest(RequestError): class Unauthorized(RequestError): status = 401 - def __init__(self, message: str = None, *, state: Union[str, int] = None, - auth_scheme: str = None, auth_params: dict = None): + def __init__( + self, + message: str = None, + *, + state: Union[str, int] = None, + auth_scheme: str = None, + auth_params: dict = None, + ): self.message = str(message) self.state = state super().__init__(message=message, state=state) @@ -91,7 +100,7 @@ def __init__(self, message: str = None, *, state: Union[str, int] = None, return value = auth_scheme.capitalize() if isinstance(auth_params, dict): - value += ' ' + ','.join([f'{k}={v}' for k, v in auth_params.items()]) + value += " " + ",".join([f"{k}={v}" for k, v in auth_params.items()]) self.append_headers = {Header.WWW_AUTH: value} @@ -102,7 +111,15 @@ class PaymentRequired(RequestError): class PermissionDenied(RequestError): status = 403 - def __init__(self, msg: str = None, *, scope=None, required_scope=None, name: str = None, **kwargs): + def __init__( + self, + msg: str = None, + *, + scope=None, + required_scope=None, + name: str = None, + **kwargs, + ): super().__init__(msg, **kwargs) self.scope = scope self.required_scope = required_scope @@ -112,27 +129,34 @@ def __init__(self, msg: str = None, *, scope=None, required_scope=None, name: 
st class NotFound(RequestError): status = 404 - def __init__(self, message: str = None, *, path: str = None, query: dict = None, **kwargs): + def __init__( + self, message: str = None, *, path: str = None, query: dict = None, **kwargs + ): from urllib.parse import urlencode + msg = [message] if message else [] if path: - msg.append(f'path: <{path}> not found') + msg.append(f"path: <{path}> not found") if query: - msg.append('query: <%s> not found' % urlencode(query)) + msg.append("query: <%s> not found" % urlencode(query)) if not msg: - msg = ['not found'] + msg = ["not found"] self.path = path self.query = query - super().__init__(message=';'.join(msg), **kwargs) + super().__init__(message=";".join(msg), **kwargs) class MethodNotAllowed(RequestError): status = 405 - def __init__(self, message: str = None, method: str = None, allows: List[str] = None): - self.message = message or f'Method: {method} is not allowed (use methods in {allows})' + def __init__( + self, message: str = None, method: str = None, allows: List[str] = None + ): + self.message = ( + message or f"Method: {method} is not allowed (use methods in {allows})" + ) self.forbid = method - self.allows = ', '.join([str(a) for a in allows]) if allows else '' + self.allows = ", ".join([str(a) for a in allows]) if allows else "" self.append_headers = {Header.ALLOW: self.allows} super().__init__(self.message) @@ -214,17 +238,18 @@ class UpgradeRequired(RequestError): def __init__(self, message: str = None, scheme: str = None): from utilmeta.utils.constant import Header, Scheme + self.message = message self.scheme = scheme if scheme == Scheme.WS: - scheme = 'websocket' + scheme = "websocket" elif isinstance(scheme, str): scheme = scheme.upper() else: scheme = Scheme.HTTPS.upper() self.append_headers = { Header.UPGRADE: scheme, - Header.CONNECTION: Header.UPGRADE + Header.CONNECTION: Header.UPGRADE, } super().__init__(message) @@ -256,13 +281,13 @@ def __init__(self, message=None, response=None): class MaxRetriesExceed(ServerError, RuntimeError): def __init__(self, msg: str = None, max_retries: int = None): - super().__init__(msg or f'Max retries exceeded: {max_retries}') + super().__init__(msg or f"Max retries exceeded: {max_retries}") self.max_retries = max_retries class MaxRetriesTimeoutExceed(ServerError, TimeoutError): def __init__(self, msg: str = None, max_retries_timeout: float = None): - super().__init__(msg or f'Max retries timeout exceeded: {max_retries_timeout}') + super().__init__(msg or f"Max retries timeout exceeded: {max_retries_timeout}") self.max_retries_timeout = max_retries_timeout @@ -325,10 +350,12 @@ class NetworkAuthenticationRequired(ServerError): def http_error(status: int = 400, message: str = None): if status < 400 or status > 600: - raise ValueError(f'Invalid HTTP error status: {status} must in 400~600') + raise ValueError(f"Invalid HTTP error status: {status} must in 400~600") error = HttpError.STATUS_EXCEPTIONS.get(status) if not error: + class error(HttpError): pass + error.status = status return error(message=message) diff --git a/utilmeta/utils/exceptions/runtime.py b/utilmeta/utils/exceptions/runtime.py index ada9fef..b03bb67 100644 --- a/utilmeta/utils/exceptions/runtime.py +++ b/utilmeta/utils/exceptions/runtime.py @@ -1,5 +1,3 @@ - - class NoAvailableInstances(RuntimeError): pass @@ -16,6 +14,7 @@ class CombinedError(Exception): Error util will recognize this class and derive it's children errors (along with there traceback) so developer can do a much better logging and self-defined handling in 
error hooks """ + def __init__(self, *errors: Exception): self.errors = [] messages = [] @@ -28,9 +27,8 @@ def __init__(self, *errors: Exception): if str(err) not in messages: messages.append(str(err)) self.errors.append(err) - self.message = ';'.join(messages) + self.message = ";".join(messages) super().__init__(self.message) def __str__(self): return self.message - diff --git a/utilmeta/utils/functional/data.py b/utilmeta/utils/functional/data.py index dceadf8..4a57f46 100644 --- a/utilmeta/utils/functional/data.py +++ b/utilmeta/utils/functional/data.py @@ -10,34 +10,53 @@ __all__ = [ - 'repeat', 'multi', 'duplicate', - 'pop', - 'distinct', - 'gen_key', 'order_dict', 'parse_list', - 'keys_or_args', - 'order_list', 'setval', 'dict_list', - 'regular', 'hide_secret_values', - 'readable', - 'readable_size', - 'make_percent', - 'restrict_keys', - 'map_dict', - 'copy_value', 'iterable', 'is_number', - 'dict_number_sum', - 'get_arg', - 'distinct_add', - 'merge_list', - 'dict_number_add', - 'make_dict_by', - 'reduce_value', - 'get_number', 'is_sub_dict', 'convert_data_frame', - 'based_number', 'get_based_number', 'list_or_args', 'bi_search', 'replace_null', - 'make_hash', 'avg', 'pop_null', 'dict_list_merge', 'normalize_title' + "repeat", + "multi", + "duplicate", + "pop", + "distinct", + "gen_key", + "order_dict", + "parse_list", + "keys_or_args", + "order_list", + "setval", + "dict_list", + "regular", + "hide_secret_values", + "readable", + "readable_size", + "make_percent", + "restrict_keys", + "map_dict", + "copy_value", + "iterable", + "is_number", + "dict_number_sum", + "get_arg", + "distinct_add", + "merge_list", + "dict_number_add", + "make_dict_by", + "reduce_value", + "get_number", + "is_sub_dict", + "convert_data_frame", + "based_number", + "get_based_number", + "list_or_args", + "bi_search", + "replace_null", + "make_hash", + "avg", + "pop_null", + "dict_list_merge", + "normalize_title", ] def normalize_title(title: str): - return ' '.join(re.sub(r'\s', ' ', title).split()) + return " ".join(re.sub(r"\s", " ", title).split()) def _list_dict(val): @@ -99,7 +118,14 @@ def duplicate(lst): return set(_list) -def bi_search(targets: list, val, key=lambda x: x, sort: bool = False, start: int = 0, end: int = None): +def bi_search( + targets: list, + val, + key=lambda x: x, + sort: bool = False, + start: int = 0, + end: int = None, +): """ as small as possible """ @@ -128,10 +154,10 @@ def bi_search(targets: list, val, key=lambda x: x, sort: bool = False, start: in def regular(s: str) -> str: - rs = '' + rs = "" for char in s: if char in constant.Reg.META: - rs += f'\\{char}' + rs += f"\\{char}" else: rs += char return rs @@ -150,18 +176,22 @@ def make_percent(val, tot, fix=1, _100=True): def multi(f): - return isinstance(f, (list, set, frozenset, tuple, type({}.values()), type({}.keys()))) + return isinstance( + f, (list, set, frozenset, tuple, type({}.values()), type({}.keys())) + ) def pop(data, key, default=None): if isinstance(data, (dict, Mapping)): - return data.pop(key) if key in data and hasattr(data, 'pop') else default + return data.pop(key) if key in data and hasattr(data, "pop") else default elif isinstance(data, list): return data.pop(key) if key < len(data) else default return default -def make_dict_by(values: List[dict], key: str, formatter: Callable = lambda x: x) -> Dict[Any, List[dict]]: +def make_dict_by( + values: List[dict], key: str, formatter: Callable = lambda x: x +) -> Dict[Any, List[dict]]: """ make a dict that keys is the key value of every items in values 
make_dict_by([{'k': 1, 'v': 2}, {'k': 1, 'v': 3}, {'k': 2, 'v': 2}], key='k') @@ -220,11 +250,11 @@ def dict_number_add(base: dict, data: dict, nested: bool = False, flag: int = 1) result = {} for key, val in base.items(): if key in data: - result[key] = dict_number_add( - val, data[key], - nested=nested, - flag=flag - ) if nested else (data[key] + val * flag) + result[key] = ( + dict_number_add(val, data[key], nested=nested, flag=flag) + if nested + else (data[key] + val * flag) + ) else: result[key] = val for key, val in data.items(): @@ -286,7 +316,12 @@ def order_list(data: list, orders: list, by: str, join_rest: bool = False) -> li def order_dict(data: dict, orders: tuple) -> OrderedDict: # consider the origin data is already ordered (odict_items) - return OrderedDict(sorted(list(data.items()), key=lambda item: orders.index(item[0]) if item[0] in orders else 0)) + return OrderedDict( + sorted( + list(data.items()), + key=lambda item: orders.index(item[0]) if item[0] in orders else 0, + ) + ) def distinct(data: list, key: str = None, val_type: type = None) -> list: @@ -326,7 +361,7 @@ def get_number(num_str: str, ignore: bool = True) -> Union[int, float, None]: except (ValueError, TypeError): if ignore: return None - raise TypeError(f'Invalid number: {num_str}') + raise TypeError(f"Invalid number: {num_str}") else: if value.is_integer(): return int(value) @@ -335,14 +370,14 @@ def get_number(num_str: str, ignore: bool = True) -> Union[int, float, None]: def reduce_value(data, max_length: int) -> dict: result = {} - t = 'string' + t = "string" items = None if multi(data): - t = 'array' + t = "array" length = len(str(data)) items = len(data) elif isinstance(data, (dict, Mapping)): - t = 'object' + t = "object" length = len(str(data)) items = len(data) elif isinstance(data, (str, bytes)): @@ -352,13 +387,10 @@ def reduce_value(data, max_length: int) -> dict: if length <= max_length: return data if isinstance(data, bytes): - data = data.decode('utf-8', 'ignore') - result['$reduced'] = True + data = data.decode("utf-8", "ignore") + result["$reduced"] = True result.update( - type=t, - length=length, - content=str(data)[:max_length], - content_length=max_length + type=t, length=length, content=str(data)[:max_length], content_length=max_length ) if items is not None: result.update(items=items) @@ -371,25 +403,25 @@ def readable(data, max_length: int = 20, more: bool = True) -> str: _bytes = False if isinstance(data, bytes): _bytes = True - data = data.decode('utf-8', 'ignore') + data = data.decode("utf-8", "ignore") if multi(data): # if not rep: # if len(str(data)) <= max_length: # return str(data) # return str(data)[:max_length] + ('...' 
if more else '') - form = {list: '[%s]', tuple: '(%s)', set: '{%s}'} + form = {list: "[%s]", tuple: "(%s)", set: "{%s}"} items = [] total = 0 for d in data: total += len(repr(d)) if total > max_length: - items.append(f'...({len(data)} items)' if more else '') + items.append(f"...({len(data)} items)" if more else "") break items.append(repr(d)) for t, fmt in form.items(): if isinstance(data, t): - return fmt % ', '.join(items) - return '[%s]' % ', '.join(items) + return fmt % ", ".join(items) + return "[%s]" % ", ".join(items) elif isinstance(data, dict): # if not rep: # if len(str(data)) <= max_length: @@ -400,34 +432,37 @@ def readable(data, max_length: int = 20, more: bool = True) -> str: for k, v in data.items(): total += len(str(k)) + len(str(v)) + 2 if total > max_length: - items.append(f'...({len(data)} items)' if more else '') + items.append(f"...({len(data)} items)" if more else "") break - items.append(repr(k) + ': ' + repr(v)) - return '{%s}' % ', '.join(items) + items.append(repr(k) + ": " + repr(v)) + return "{%s}" % ", ".join(items) elif is_number(data): return str(data) if not isinstance(data, str): from .py import represent + return represent(data) excess = len(data) - max_length if excess >= 0: - data = data[:max_length] + (f'...({excess} more chars)' if more else '') + data = data[:max_length] + (f"...({excess} more chars)" if more else "") result = repr(data) if _bytes: - result = 'b' + result + result = "b" + result return result def readable_size(size, depth=0): - unit = ['B', 'KB', 'MB', 'GB', 'TB'] + unit = ["B", "KB", "MB", "GB", "TB"] if size < 1 or not size: return f"0 {unit[depth]}" if 1 <= size < 1000: - return f'{str(size)} {unit[depth]}' - return readable_size(int(size/1024), depth+1) + return f"{str(size)} {unit[depth]}" + return readable_size(int(size / 1024), depth + 1) -def parse_list(data, merge=False, distinct_merge=None, merge_type: type = None) -> Union[list, tuple]: +def parse_list( + data, merge=False, distinct_merge=None, merge_type: type = None +) -> Union[list, tuple]: if multi(data): if merge: result = [] @@ -435,10 +470,16 @@ def parse_list(data, merge=False, distinct_merge=None, merge_type: type = None) for d in data: if multi(d): multi_occur = True - result += parse_list(d, merge_type=merge_type, merge=True, distinct_merge=distinct_merge) + result += parse_list( + d, + merge_type=merge_type, + merge=True, + distinct_merge=distinct_merge, + ) else: if merge_type: from utype import type_transform + d = type_transform(d, merge_type) result.append(d) if distinct_merge is None and multi_occur: @@ -453,18 +494,21 @@ def parse_list(data, merge=False, distinct_merge=None, merge_type: type = None) elif not data: return [] elif type(data) == str: + def start_end(value: str, start, end=None): if end is None: end = start return value.startswith(start) and value.endswith(end) data: str - maybe_list = start_end(data, '[', ']') - maybe_tuple = start_end(data, '(', ')') - maybe_json = start_end(data, '{', '}') - spliter = ';' if ';' in data else ',' + maybe_list = start_end(data, "[", "]") + maybe_tuple = start_end(data, "(", ")") + maybe_json = start_end(data, "{", "}") + spliter = ";" if ";" in data else "," if maybe_tuple: - return tuple([v.strip() for v in data.lstrip('(').rstrip(')').split(spliter)]) + return tuple( + [v.strip() for v in data.lstrip("(").rstrip(")").split(spliter)] + ) if maybe_list or maybe_json: try: return parse_list(json.loads(data, strict=False), merge=merge) @@ -473,7 +517,12 @@ def start_end(value: str, start, end=None): elif 
spliter in data: return [v.strip() for v in data.strip().split(spliter)] elif not isinstance(data, constant.COMMON_TYPES) and iterable(data): - return parse_list(list(data), merge=merge, distinct_merge=distinct_merge, merge_type=merge_type) + return parse_list( + list(data), + merge=merge, + distinct_merge=distinct_merge, + merge_type=merge_type, + ) return [data] @@ -481,11 +530,11 @@ def based_number(num: int, base: int = 10) -> str: num, base = int(num), int(base) n = abs(num) if base <= 1 or base > len(constant.ELEMENTS): - raise ValueError(f'number base should > 1 and <= {len(constant.ELEMENTS)}') + raise ValueError(f"number base should > 1 and <= {len(constant.ELEMENTS)}") if base == 10: return str(n) # values = [] - output = '' + output = "" elements = constant.ELEMENTS[0:base] while n: i = n % base @@ -501,8 +550,8 @@ def get_based_number(num: Union[str, int], from_base: int, to_base: int = 10) -> if from_base <= 36: return based_number(int(num, base=from_base), base=to_base) if from_base > len(constant.ELEMENTS): - raise ValueError(f'number base should > 1 and < {len(constant.ELEMENTS)}') - num = str(num).lstrip('-') + raise ValueError(f"number base should > 1 and < {len(constant.ELEMENTS)}") + num = str(num).lstrip("-") value = 0 for i, n in enumerate(num): value += constant.ELEMENTS.index(n) * from_base ** (len(num) - i - 1) @@ -570,7 +619,9 @@ def distinct_add(target: list, data): if not data: return target if not isinstance(target, list): - raise TypeError(f'Invalid distinct_add target type: {type(target)}, must be lsit') + raise TypeError( + f"Invalid distinct_add target type: {type(target)}, must be list" + ) # target = list(target) if not multi(data): if data not in target: @@ -586,8 +637,10 @@ def replace_null(data: dict, default=0): return {key: val or default for key, val in data.items()} -def make_hash(value: str, seed: str = '', mod: int = 2 ** 32): - return int(hashlib.md5((str(value) + str(seed or '')).encode()).hexdigest(), 16) % mod +def make_hash(value: str, seed: str = "", mod: int = 2**32): + return ( + int(hashlib.md5((str(value) + str(seed or "")).encode()).hexdigest(), 16) % mod + ) def restrict_keys(keys: Union[list, tuple, set], data: dict, default=None) -> dict: @@ -608,9 +661,11 @@ def merge_list(*lst, keys=None) -> List[dict]: return result -def convert_data_frame(data: List[dict], align: bool = False, depth: int = 1, keys: List[str] = ()) -> Dict[str, list]: +def convert_data_frame( + data: List[dict], align: bool = False, depth: int = 1, keys: List[str] = () +) -> Dict[str, list]: if not depth: - return data # noqa + return data # noqa if not iterable(data) or not data: return {key: [] for key in keys} result = {} @@ -632,35 +687,48 @@ def convert_data_frame(data: List[dict], align: bool = False, depth: int = 1, ke return result -def gen_key(digit=64, alnum=False, lower=False, excludes: List[str] = ('$', '\\')) -> str: +def gen_key( + digit=64, alnum=False, lower=False, excludes: List[str] = ("$", "\\") +) -> str: import secrets + sample = string.digits if alnum: sample += string.ascii_lowercase if lower else string.ascii_letters else: sample = string.printable[:94] for ex in excludes: - sample = sample.replace(ex, '') + sample = sample.replace(ex, "") while len(sample) < digit: sample += sample - return ''.join(secrets.choice(sample) for i in range(digit)) # noqa + return "".join(secrets.choice(sample) for i in range(digit)) # noqa # return ''.join(random.sample(sample, digit)) -def hide_secret_values(data, secret_names, secret_value=constant.SECRET, 
file_repr: Union[Callable, str] = ''): +def hide_secret_values( + data, + secret_names, + secret_value=constant.SECRET, + file_repr: Union[Callable, str] = "", +): if not secret_names: return data if data is None: return data from .py import file_like + if isinstance(data, dict): result = {} for k, v in data.items(): k: str if isinstance(v, list): - result[k] = hide_secret_values(v, secret_names, secret_value=secret_value, file_repr=file_repr) + result[k] = hide_secret_values( + v, secret_names, secret_value=secret_value, file_repr=file_repr + ) elif isinstance(v, dict): - result[k] = hide_secret_values(v, secret_names, secret_value=secret_value, file_repr=file_repr) + result[k] = hide_secret_values( + v, secret_names, secret_value=secret_value, file_repr=file_repr + ) elif file_like(v): result[k] = file_repr(v) if callable(file_repr) else file_repr else: @@ -676,7 +744,11 @@ def hide_secret_values(data, secret_names, secret_value=constant.SECRET, file_re if isinstance(data, list): result = [] for d in data: - result.append(hide_secret_values(d, secret_names, secret_value=secret_value, file_repr=file_repr)) + result.append( + hide_secret_values( + d, secret_names, secret_value=secret_value, file_repr=file_repr + ) + ) return result if file_like(data): return file_repr(data) if callable(file_repr) else file_repr diff --git a/utilmeta/utils/functional/orm.py b/utilmeta/utils/functional/orm.py index 5bd0eef..9312ec9 100644 --- a/utilmeta/utils/functional/orm.py +++ b/utilmeta/utils/functional/orm.py @@ -4,17 +4,18 @@ from functools import wraps __all__ = [ - 'get_sql_info', - 'print_queries', + "get_sql_info", + "print_queries", ] -def print_queries(alias='default'): +def print_queries(alias="default"): def deco(f, a=alias): @wraps(f) def wrapper(*args, **kwargs): from django.db import connections import time + start = time.time() connection = connections[a] qs = len(connection.queries) @@ -22,39 +23,48 @@ def wrapper(*args, **kwargs): end = time.time() t = round(end - start, 3) qc = len(connection.queries) - qs - print(f'function {f.__name__} cost {t} s with {qc} queries') + print(f"function {f.__name__} cost {t} s with {qc} queries") return r + return wrapper if callable(alias): - return deco(alias, 'default') + return deco(alias, "default") return deco -def get_sql_info(sql_str: str, table_min_length: int = 2, str_parse: bool = True) -> Tuple[str, List[str]]: +def get_sql_info( + sql_str: str, table_min_length: int = 2, str_parse: bool = True +) -> Tuple[str, List[str]]: if not sql_str: - return '', [] - sql_str = re.sub(re.compile('in \\((.+?)\\)', re.I | re.S), 'IN (0)', sql_str) - sql_str = re.sub(re.compile("'(.+?)'", re.I | re.S), "''", sql_str) # avoid string interfere the string parse + sql_str = re.sub(re.compile("in \\((.+?)\\)", re.I | re.S), "IN (0)", sql_str) + sql_str = re.sub( + re.compile("'(.+?)'", re.I | re.S), "''", sql_str + ) # avoid string literals interfering with the string parse # replace large in sector to improve performance if str_parse: # patterns = [' FROM "%s"', ' JOIN "%s"', ' INTO "%s"', 'UPDATE "%s" '] # types = ['select', 'update', 'insert', 'delete'] types_map = { - 'select': [' from ', ' join '], - 'update': ['update '], - 'delete': [' from '], - 'insert': ['insert into '] + "select": [" from ", " join "], + "update": ["update "], + "delete": [" from "], + "insert": ["insert into "], } opt = sql_str.split()[0] tokens = types_map.get(opt.lower()) if not tokens: - return '', [] + return "", [] tables = [] for token in tokens: - b, *s = 
(sql_str.split(token) if token in sql_str else sql_str.split(token.upper())) + b, *s = ( + sql_str.split(token) + if token in sql_str + else sql_str.split(token.upper()) + ) for statement in s: tb: str = statement.strip().split()[0].strip('"') if table_min_length and len(tb) <= table_min_length: @@ -66,11 +76,12 @@ def get_sql_info(sql_str: str, table_min_length: int = 2, str_parse: bool = True import sqlparse from sqlparse.sql import Identifier + sql = sqlparse.parse(sql_str) if not sql: - return '', [] + return "", [] tables = [] - preserves = ['CONFLICT', 'subquery'] + preserves = ["CONFLICT", "subquery"] for statement in sql: for token in statement.tokens: if isinstance(token, Identifier): diff --git a/utilmeta/utils/functional/py.py b/utilmeta/utils/functional/py.py index 44801fd..b2f3f4f 100644 --- a/utilmeta/utils/functional/py.py +++ b/utilmeta/utils/functional/py.py @@ -6,29 +6,31 @@ from functools import wraps __all__ = [ - 'return_type', - 'file_like', - 'print_time', - 'import_obj', - 'print_dict', - 'class_func', - 'get_generator_result', - 'represent', - 'valid_attr', - 'check_requirement', - 'get_root_base', - 'function_pass', - 'common_representable', - 'get_doc', - 'requires', - 'get_base_type', + "return_type", + "file_like", + "print_time", + "import_obj", + "print_dict", + "class_func", + "get_generator_result", + "represent", + "valid_attr", + "check_requirement", + "get_root_base", + "function_pass", + "common_representable", + "get_doc", + "requires", + "get_base_type", ] -def _f_pass_doc(): """""" +def _f_pass_doc(): + """""" -def _f_pass(): pass +def _f_pass(): + pass PASSED_CODES = ( @@ -40,15 +42,21 @@ def _f_pass(): pass def represent(val) -> str: if isinstance(val, type): if val is type(None): - return 'type(None)' + return "type(None)" return val.__name__ - if inspect.isfunction(val) or inspect.ismethod(val) or inspect.isclass(val) or inspect.isbuiltin(val): + if ( + inspect.isfunction(val) + or inspect.ismethod(val) + or inspect.isclass(val) + or inspect.isbuiltin(val) + ): return val.__name__ return repr(val) def common_representable(data) -> bool: from .data import multi + if multi(data): for val in data: if not common_representable(val): @@ -65,12 +73,16 @@ def common_representable(data) -> bool: def class_func(f): - return isinstance(f, (staticmethod, classmethod)) or inspect.ismethod(f) or inspect.isfunction(f) + return ( + isinstance(f, (staticmethod, classmethod)) + or inspect.ismethod(f) + or inspect.isfunction(f) + ) def function_pass(f): if not inspect.isfunction(f): - f = getattr(f, '__func__', None) + f = getattr(f, "__func__", None) if not f or not inspect.isfunction(f): return False return getattr(f, constant.Attr.CODE).co_code in PASSED_CODES @@ -78,6 +90,7 @@ def function_pass(f): def valid_attr(name: str): from keyword import iskeyword + return name.isidentifier() and not iskeyword(name) @@ -115,38 +128,44 @@ def print_time(f): @wraps(f) def wrapper(*args, **kwargs): import time + start = time.time() r = f(*args, **kwargs) end = time.time() t = round(end - start, 3) name = f.__name__ - print(f'function {name} cost {t} s') + print(f"function {name} cost {t} s") return r + return wrapper def print_dict(data): if isinstance(data, list): - print('[') + print("[") for d in data: print_dict(d) - print(',') - print(']\n') + print(",") + print("]\n") return - items = getattr(data, 'items', None) + items = getattr(data, "items", None) if callable(items): - print('{') + print("{") for key, val in items(): - print(f'\t{repr(key)}: {repr(val)},') - 
print('}') + print(f"\t{repr(key)}: {repr(val)},") + print("}") return print(data) def file_like(obj) -> bool: try: - return callable(getattr(obj, 'read')) and callable(getattr(obj, 'seek')) \ - and callable(getattr(obj, 'write')) and callable(getattr(obj, 'close')) + return ( + callable(getattr(obj, "read")) + and callable(getattr(obj, "seek")) + and callable(getattr(obj, "write")) + and callable(getattr(obj, "close")) + ) except AttributeError: return False @@ -156,11 +175,11 @@ def return_type(f, raw: bool = False): f = f.__func__ if not f: return None - _t = getattr(f, constant.Attr.ANNOTATES).get('return') + _t = getattr(f, constant.Attr.ANNOTATES).get("return") if raw: return _t try: - t = typing.get_type_hints(f).get('return') + t = typing.get_type_hints(f).get("return") except NameError: return _t if t is type(None): @@ -169,7 +188,7 @@ def return_type(f, raw: bool = False): def get_generator_result(result): - if hasattr(result, '__next__'): + if hasattr(result, "__next__"): # convert generator yield result into list # result = list(result) values = [] @@ -179,7 +198,7 @@ def get_generator_result(result): v = next(result) except StopIteration: break - if hasattr(v, '__next__'): + if hasattr(v, "__next__"): recursive = True result = v else: @@ -195,14 +214,16 @@ def get_generator_result(result): def get_doc(obj) -> str: if not obj: - return '' + return "" if isinstance(obj, str): return obj - return inspect.cleandoc(getattr(obj, constant.Attr.DOC, '') or '') + return inspect.cleandoc(getattr(obj, constant.Attr.DOC, "") or "") # return (getattr(obj, Attr.DOC, '') or '').replace('\t', '').strip('\n').strip() -def check_requirement(*pkgs: str, hint: str = None, check: bool = True, install_when_require: bool = False): +def check_requirement( + *pkgs: str, hint: str = None, check: bool = True, install_when_require: bool = False +): if len(pkgs) > 1: for pkg in pkgs: try: @@ -217,25 +238,30 @@ def check_requirement(*pkgs: str, hint: str = None, check: bool = True, install_ raise ImportError except (ModuleNotFoundError, ImportError): if install_when_require: - print(f'INFO: current service require <{pkg}> package, installing...') + print(f"INFO: current service require <{pkg}> package, installing...") try: import sys - os.system(f'{sys.executable} -m pip install {pkg}') + + os.system(f"{sys.executable} -m pip install {pkg}") except Exception as e: - print(f'install package failed with error: {e}, fallback to internal solution') - pip_main = import_obj('pip._internal:main') - pip_main(['install', pkg]) + print( + f"install package failed with error: {e}, fallback to internal solution" + ) + pip_main = import_obj("pip._internal:main") + pip_main(["install", pkg]) else: if hint: print(hint) - raise ImportError(f'package <{pkg}> is required for current settings, please install it ' - f'or set install-when-require=True at meta.ini to allow auto installation') + raise ImportError( + f"package <{pkg}> is required for current settings, please install it " + f"or set install-when-require=True at meta.ini to allow auto installation" + ) def requires(*names, **mp): if names: for name in names: - mp[name] = str(name).split('.')[0] + mp[name] = str(name).split(".")[0] for import_name, install_name in mp.items(): try: return import_obj(import_name) @@ -243,14 +269,17 @@ def requires(*names, **mp): except (ModuleNotFoundError, ImportError): pass for import_name, install_name in mp.items(): - print(f'INFO: current service require <{install_name}> package, installing...') + print(f"INFO: current service 
require <{install_name}> package, installing...") try: import sys - os.system(f'{sys.executable} -m pip install {install_name}') + + os.system(f"{sys.executable} -m pip install {install_name}") except Exception as e: - print(f'install package failed with error: {e}, fallback to internal solution') - pip_main = import_obj('pip._internal:main') - pip_main(['install', install_name]) + print( + f"install package failed with error: {e}, fallback to internal solution" + ) + pip_main = import_obj("pip._internal:main") + pip_main(["install", install_name]) try: return import_obj(import_name) except (ModuleNotFoundError, ImportError): @@ -263,8 +292,9 @@ def import_obj(dotted_path): Import a dotted module path and return the attribute/class designated by the last name in the path. Raise ImportError if the import failed. """ - if '/' in dotted_path and os.path.exists(dotted_path): + if "/" in dotted_path and os.path.exists(dotted_path): from importlib.util import spec_from_file_location + name = dotted_path.split(os.sep)[-1].rstrip(constant.PY) spec = spec_from_file_location(name, dotted_path) return spec.loader.load_module() @@ -274,14 +304,14 @@ def import_obj(dotted_path): except (ImportError, ModuleNotFoundError) as e: if dotted_path not in str(e): raise e - if ':' not in dotted_path and '.' not in dotted_path: + if ":" not in dotted_path and "." not in dotted_path: # module only return importlib.import_module(dotted_path) try: - if ':' in dotted_path: - module_path, class_name = dotted_path.split(':') + if ":" in dotted_path: + module_path, class_name = dotted_path.split(":") else: - module_path, class_name = dotted_path.rsplit('.', 1) + module_path, class_name = dotted_path.rsplit(".", 1) except ValueError as err: raise ImportError("%s doesn't look like a module path" % dotted_path) from err @@ -290,6 +320,7 @@ def import_obj(dotted_path): try: return getattr(module, class_name) except AttributeError as err: - raise ImportError('Module "%s" does not define a "%s" attribute/class' % ( - module_path, class_name) + raise ImportError( + 'Module "%s" does not define a "%s" attribute/class' + % (module_path, class_name) ) from err diff --git a/utilmeta/utils/functional/sys.py b/utilmeta/utils/functional/sys.py index ecb2a6e..5267fdd 100644 --- a/utilmeta/utils/functional/sys.py +++ b/utilmeta/utils/functional/sys.py @@ -5,40 +5,47 @@ from typing import Optional, List, Union, Tuple, Dict from .. 
import constant from ipaddress import ip_address, ip_network -posix_os = os.name == 'posix' + +posix_os = os.name == "posix" __all__ = [ - 'dir_getsize', 'file_num', - 'get_ip', - 'path_merge', - 'load_ini', 'search_file', - 'clear', 'kill', - 'write_config', 'path_join', - 'running', - 'posix_os', - 'port_using', - 'run', - 'get_real_file', - 'read_from_socket', - 'sys_user_exists', 'sys_user_add', 'find_port', - 'parse_socket', - 'get_processes', - 'get_code', - 'current_master', - 'get_system_fds', - 'read_from', - 'write_to', - 'get_server_ip', - 'get_real_ip', - 'remove_file', - 'get_max_socket_conn', - 'get_max_open_files', - 'ip_belong_networks', - 'get_system_open_files', - 'create_only_write', - 'get_recursive_dirs', - 'get_sys_net_connections_info', - 'get_mac_address', + "dir_getsize", + "file_num", + "get_ip", + "path_merge", + "load_ini", + "search_file", + "clear", + "kill", + "write_config", + "path_join", + "running", + "posix_os", + "port_using", + "run", + "get_real_file", + "read_from_socket", + "sys_user_exists", + "sys_user_add", + "find_port", + "parse_socket", + "get_processes", + "get_code", + "current_master", + "get_system_fds", + "read_from", + "write_to", + "get_server_ip", + "get_real_ip", + "remove_file", + "get_max_socket_conn", + "get_max_open_files", + "ip_belong_networks", + "get_system_open_files", + "create_only_write", + "get_recursive_dirs", + "get_sys_net_connections_info", + "get_mac_address", ] import uuid @@ -60,19 +67,21 @@ def remove_file(file: str, ignore_not_found: bool = True): except PermissionError as e: if not posix_os: raise e - return not os.system(f'sudo rm {file}') + return not os.system(f"sudo rm {file}") -def write_to(file: str, content: str, mode: str = 'w', encoding=None): +def write_to(file: str, content: str, mode: str = "w", encoding=None): try: with open(file, mode, encoding=encoding) as f: f.write(content) except (PermissionError, FileNotFoundError) as e: if not posix_os: raise e - note = '-a' if mode.startswith('a') else '' + note = "-a" if mode.startswith("a") else "" # use single quote here to apply $ escape - os.system(f"echo \'{content}\' | sudo tee {note} {file} >> /dev/null") # ignore file content output + os.system( + f"echo '{content}' | sudo tee {note} {file} >> /dev/null" + ) # ignore file content output def create_only_write(file: str, content: str, fail_silently: bool = False): @@ -90,27 +99,28 @@ def create_only_write(file: str, content: str, fail_silently: bool = False): if not fail_silently: raise else: # No exception, so the file must have been created successfully. - with os.fdopen(file_handle, 'w') as file_obj: + with os.fdopen(file_handle, "w") as file_obj: # Using `os.fdopen` converts the handle to an object that acts like a # regular Python file object, and the `with` context manager means the # file will be automatically closed when we're done with it. 
file_obj.write(content) -def read_from(file, mode: str = 'r') -> str: +def read_from(file, mode: str = "r") -> str: try: - with open(file, mode, errors='ignore') as f: + with open(file, mode, errors="ignore") as f: return f.read() except PermissionError: if posix_os: - return os.popen(f'sudo cat {file}').read() - return '' + return os.popen(f"sudo cat {file}").read() + return "" except FileNotFoundError: - return '' + return "" def get_system_fds(): import psutil + fds = 0 if psutil.POSIX: for proc in psutil.process_iter(): @@ -123,7 +133,7 @@ def get_system_fds(): def get_system_open_files(): try: - with open('/proc/sys/fs/file-nr') as file_nr: + with open("/proc/sys/fs/file-nr") as file_nr: stats = file_nr.read().split() total_fds = int(stats[0]) - int(stats[1]) return total_fds @@ -142,6 +152,7 @@ def get_system_open_files(): def get_network_ip(ifname: str): import struct + try: import fcntl except ModuleNotFoundError: @@ -150,7 +161,9 @@ def get_network_ip(ifname: str): ifname = ifname.encode() s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: - return socket.inet_ntoa(fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', ifname[:15]))[20:24]) # noqa + return socket.inet_ntoa( + fcntl.ioctl(s.fileno(), 0x8915, struct.pack("256s", ifname[:15]))[20:24] + ) # noqa except OSError: return None @@ -174,16 +187,16 @@ def get_server_ip(private_only: bool = False) -> Optional[str]: ips = set() for i in range(0, 3): - if_ip = get_network_ip(f'eth{i}') + if_ip = get_network_ip(f"eth{i}") if if_ip: - if not if_ip.startswith('127.'): + if not if_ip.startswith("127."): ips.add(if_ip) else: break if ip: - if ip.startswith('127.'): - s.connect(('8.8.8.8', 53)) + if ip.startswith("127."): + s.connect(("8.8.8.8", 53)) ip = str(s.getsockname()[0]) if ip: ips.add(ip) @@ -249,7 +262,7 @@ def get_real_file(path: str): except (OSError, FileNotFoundError): real_path = path if not os.path.exists(real_path): - raise FileNotFoundError(f'file: {repr(real_path)} not exists') + raise FileNotFoundError(f"file: {repr(real_path)} not exists") return real_path @@ -257,49 +270,61 @@ def load_ini(content: str, parse_key: bool = False) -> dict: ini = {} dic = {} for ln in content.splitlines(): - line = ln.replace(' ', '').replace('\t', '') + line = ln.replace(" ", "").replace("\t", "") if not line or not line.split(): continue - annotate = line.split()[0].startswith('#') or line.split()[0].startswith(';') + annotate = line.split()[0].startswith("#") or line.split()[0].startswith(";") if annotate: continue - if re.fullmatch(r'\[(.*?)\]', line): - key = line.strip('[]') + if re.fullmatch(r"\[(.*?)\]", line): + key = line.strip("[]") ini[key] = dic = {} else: - if '=' not in line: + if "=" not in line: continue from utype import TypeTransformer - key, val = line.split('=') + + key, val = line.split("=") if parse_key: - key = key.replace('_', '-').lower() + key = key.replace("_", "-").lower() if val.isdigit(): val = int(val) elif val in TypeTransformer.FALSE_VALUES: val = False dic[key] = val - return ini or dic # load no header conf file as well + return ini or dic # load no header conf file as well -def write_config(data: dict, path: str, append: bool = False, ini_syntax: bool = True) -> str: - content = '' +def write_config( + data: dict, path: str, append: bool = False, ini_syntax: bool = True +) -> str: + content = "" if ini_syntax: for key, val in data.items(): - content += f'[{key}]\n' - assert type(val) == dict, TypeError(f"write ini failed, syntax error: {val} should be dict") + content += f"[{key}]\n" + assert 
type(val) == dict, TypeError( + f"write ini failed, syntax error: {val} should be dict" + ) for k, v in val.items(): - content += f'{k} = {v}\n' - content += '\n' + content += f"{k} = {v}\n" + content += "\n" else: for k, v in data.items(): - content += f'{k} = {repr(v)}\n' - content += '\n' + content += f"{k} = {repr(v)}\n" + content += "\n" - write_to(path, content=content, mode='a' if append else 'w') + write_to(path, content=content, mode="a" if append else "w") return content -def path_join(base: str, path: str, *, dir: bool = False, create: bool = False, ignore: bool = False): +def path_join( + base: str, + path: str, + *, + dir: bool = False, + create: bool = False, + ignore: bool = False, +): if not path: return None if os.path.isabs(path): @@ -311,7 +336,7 @@ def path_join(base: str, path: str, *, dir: bool = False, create: bool = False, if dir: os.makedirs(p) else: - write_to(p, content='') + write_to(p, content="") elif not ignore: raise OSError(f"path {p} not exists") return p @@ -334,6 +359,7 @@ def clear(filepath): if os.path.isdir(cur_path): if fd == "__pycache__": import shutil + shutil.rmtree(cur_path) else: clear(cur_path) @@ -347,23 +373,25 @@ def run(cmd, *backup_commands): r = os.system(c) if not r: return - print(f'meta: running occur error, aborting..') + print(f"meta: running occur error, aborting..") exit(r) except KeyboardInterrupt: - print('aborting..') + print("aborting..") exit(0) except Exception as e: - print(f'meta: running occur error: {e}, aborting..') + print(f"meta: running occur error: {e}, aborting..") exit(1) def current_master(): import psutil + return bool(psutil.Process(os.getpid()).children()) def get_processes(name, contains: str = None): import psutil + assert name, name ls = [] for p in psutil.process_iter(): @@ -385,6 +413,7 @@ def get_processes(name, contains: str = None): def kill(name, contains: str = None): import psutil + killed = 0 for p in get_processes(name=name, contains=contains): try: @@ -412,6 +441,7 @@ def find(start: int = 8000, end: int = 10000): continue ports.append(p) return p + return find @@ -419,7 +449,7 @@ def get_max_socket_conn(): if not posix_os: return None try: - r = os.popen('cat /proc/sys/net/core/somaxconn').read().strip('\n') + r = os.popen("cat /proc/sys/net/core/somaxconn").read().strip("\n") if not r: return None return int(r) @@ -431,7 +461,7 @@ def get_max_open_files(): if not posix_os: return None try: - r = os.popen('ulimit -n').read().strip('\n') + r = os.popen("ulimit -n").read().strip("\n") if not r: return None return int(r) @@ -441,6 +471,7 @@ def get_max_open_files(): def running(pid): import psutil + try: return psutil.Process(pid).is_running() except psutil.Error: @@ -451,16 +482,18 @@ def parse_socket(sock: str, valid_path: bool = False): file = False if callable(sock): sock = sock() - assert sock, f'Invalid socket: {sock}' + assert sock, f"Invalid socket: {sock}" if isinstance(sock, int) or isinstance(sock, str) and sock.isdigit(): try: sock = int(sock) assert 1000 < sock < 65536 except (TypeError, ValueError, AssertionError) as e: - raise ValueError(f'socket must be a valid .sock file path or a int port ' - f'in (1000, 65536), got {sock} with error {e}') + raise ValueError( + f"socket must be a valid .sock file path or a int port " + f"in (1000, 65536), got {sock} with error {e}" + ) else: - sock = f'127.0.0.1:{sock}' + sock = f"127.0.0.1:{sock}" elif valid_path: if os.path.exists(sock): file = True @@ -476,9 +509,9 @@ def read_from_socket(value: Union[str, int], buf: int = 1024 * 4) -> bytes: 
s.connect(sock) else: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - host, port = sock.split(':') + host, port = sock.split(":") s.connect((host, int(port))) - data = b'' + data = b"" while 1: d = s.recv(buf) data += d @@ -500,48 +533,57 @@ def sys_user_exists(name: str, group: bool = False): return bool(os.popen(f'grep "{name}" /etc/passwd').read()) -def sys_user_add(name: str, home: str = None, group: str = None, login: bool = True, add_group: bool = False): +def sys_user_add( + name: str, + home: str = None, + group: str = None, + login: bool = True, + add_group: bool = False, +): if not posix_os: return if not name: return if add_group: - os.system(f'groupadd {name}') + os.system(f"groupadd {name}") return items = [] if not login: - items.append('-s /bin/false') + items.append("-s /bin/false") if home: - items.append(f'-d {home}') + items.append(f"-d {home}") if group: - items.append(f'-g {group}') - append_str = ' '.join(items) - os.system(f'useradd {append_str} {name}') + items.append(f"-g {group}") + append_str = " ".join(items) + os.system(f"useradd {append_str} {name}") def get_code(f) -> str: # not used in current version code = getattr(f, constant.Attr.CODE, None) if not code: - return '' + return "" fl = code.co_firstlineno file = code.co_filename content = read_from(file) ft = 0 el = fl for i, line in enumerate(content.splitlines()[fl:]): - tabs = line.count(' ' * 4) + tabs = line.count(" " * 4) if not i: ft = tabs continue if tabs <= ft: el = fl + i break - return '\n'.join(content.splitlines()[fl:el]) + return "\n".join(content.splitlines()[fl:el]) -def get_sys_net_connections_info() -> Tuple[int, int, Dict[str, int]]: # (total, active) +def get_sys_net_connections_info() -> Tuple[ + int, int, Dict[str, int] +]: # (total, active) import psutil + info = {} conns = [] try: @@ -559,7 +601,8 @@ def get_sys_net_connections_info() -> Tuple[int, int, Dict[str, int]]: # (tot break except psutil.Error as e: import warnings - warnings.warn(f'retrieve net connections failed with error: {e}') + + warnings.warn(f"retrieve net connections failed with error: {e}") return 0, 0, {} total = len(conns) active = 0 @@ -575,42 +618,54 @@ def get_sys_net_connections_info() -> Tuple[int, int, Dict[str, int]]: # (tot def path_merge(base: str, path: str): """ - the base param is a absolute dir (usually the os.getcwd()) - path can be regular path like dir1/file2 - or the double-dotted path ../../file1 - in every case the base and path will merge to a absolute new path + the base param is a absolute dir (usually the os.getcwd()) + path can be regular path like dir1/file2 + or the double-dotted path ../../file1 + in every case the base and path will merge to a absolute new path """ - path = path or '' - if path.startswith('./'): - path = path.strip('./') + path = path or "" + if path.startswith("./"): + path = path.strip("./") - if path.startswith('/') or path.startswith('~'): + if path.startswith("/") or path.startswith("~"): return path if not path: - return base or '' + return base or "" - if '..' in path: + if ".." 
in path: divider = os.sep if divider in path: dirs = path.split(divider) - elif '/' in path: - dirs = path.split('/') - elif '\\' in path: - dirs = path.split('\\') + elif "/" in path: + dirs = path.split("/") + elif "\\" in path: + dirs = path.split("\\") else: dirs = [path] - backs = dirs.count('..') + backs = dirs.count("..") while backs: base = os.path.dirname(base) backs -= 1 - path = divider.join([d for d in dirs if d != '..']) - return os.path.join(base, path).replace('/', os.sep).replace('\\', os.sep).rstrip(os.sep) - - -def get_recursive_dirs(path, exclude_suffixes: List[str] = None, include_suffixes: List[str] = None, - include_path: bool = False, file_stats: bool = False, dir_stats: bool = False, - exclude_seg: bool = False, exclude_dot: bool = False): + path = divider.join([d for d in dirs if d != ".."]) + return ( + os.path.join(base, path) + .replace("/", os.sep) + .replace("\\", os.sep) + .rstrip(os.sep) + ) + + +def get_recursive_dirs( + path, + exclude_suffixes: List[str] = None, + include_suffixes: List[str] = None, + include_path: bool = False, + file_stats: bool = False, + dir_stats: bool = False, + exclude_seg: bool = False, + exclude_dot: bool = False, +): try: ab_path, dirs, files = next(os.walk(path)) except (FileNotFoundError, StopIteration): @@ -634,7 +689,7 @@ def _trans_stats(p): for dir in dirs: if exclude_seg and dir.startswith(constant.SEG): continue - if exclude_dot and dir.startswith('.'): + if exclude_dot and dir.startswith("."): continue dir_path = os.path.join(path, dir) values = get_recursive_dirs( @@ -645,21 +700,20 @@ def _trans_stats(p): exclude_seg=exclude_seg, exclude_dot=exclude_dot, include_suffixes=include_suffixes, - exclude_suffixes=exclude_suffixes - ) - val = dict( - name=dir, - children=values + exclude_suffixes=exclude_suffixes, ) + val = dict(name=dir, children=values) if include_path: val.update(path=dir_path) if dir_stats: val.update(_trans_stats(dir_path)) result.append(val) for file in files: - if include_suffixes and not any([file.endswith(f'.{s}') for s in include_suffixes]): + if include_suffixes and not any( + [file.endswith(f".{s}") for s in include_suffixes] + ): continue - if exclude_suffixes and any([file.endswith(f'.{s}') for s in exclude_suffixes]): + if exclude_suffixes and any([file.endswith(f".{s}") for s in exclude_suffixes]): pass val = dict( name=file, @@ -675,11 +729,13 @@ def _trans_stats(p): def get_ip(host: str, ip_only: bool = False) -> Optional[str]: import ipaddress + try: return str(ipaddress.ip_address(host)) except ValueError: pass from urllib.parse import urlparse, ParseResult + res: ParseResult = urlparse(host) try: return str(ipaddress.ip_address(res.hostname)) @@ -690,6 +746,7 @@ def get_ip(host: str, ip_only: bool = False) -> Optional[str]: if ip_only: return None from .web import get_hostname + try: return socket.gethostbyname(get_hostname(host)) except (socket.error, OSError): @@ -698,6 +755,7 @@ def get_ip(host: str, ip_only: bool = False) -> Optional[str]: def get_real_ip(ip: str): from .web import localhost + if localhost(ip): return get_server_ip() return get_ip(ip) diff --git a/utilmeta/utils/functional/time.py b/utilmeta/utils/functional/time.py index 6c63e22..df7271c 100644 --- a/utilmeta/utils/functional/time.py +++ b/utilmeta/utils/functional/time.py @@ -1,16 +1,17 @@ from datetime import datetime, timedelta, tzinfo, timezone -from typing import Union, Optional +from typing import Union, Optional import decimal __all__ = [ - 'closest_hour', - 'local_time_offset', - 'get_interval', - 
'time_now', 'time_local', - 'convert_time', - 'utc_ms_ts', - 'wait_till', - 'get_timezone', + "closest_hour", + "local_time_offset", + "get_interval", + "time_now", + "time_local", + "convert_time", + "utc_ms_ts", + "wait_till", + "get_timezone", ] @@ -19,26 +20,32 @@ def utc_ms_ts() -> int: def get_timezone(timezone_name: str) -> tzinfo: - if timezone_name.lower() == 'utc': + if timezone_name.lower() == "utc": return timezone.utc try: import zoneinfo + return zoneinfo.ZoneInfo(timezone_name) except ModuleNotFoundError: try: from backports import zoneinfo + return zoneinfo.ZoneInfo(timezone_name) # django > 4.0 except ModuleNotFoundError: try: - import pytz # noqa + import pytz # noqa + return pytz.timezone(timezone_name) except ModuleNotFoundError: - raise ModuleNotFoundError('You should install zoneinfo or pytz to use timezone feature') + raise ModuleNotFoundError( + "You should install zoneinfo or pytz to use timezone feature" + ) def wait_till(ts: Union[int, float, datetime], tick: float = None): import time + if isinstance(ts, datetime): ts = ts.timestamp() if tick is None: @@ -59,26 +66,33 @@ def wait_till(ts: Union[int, float, datetime], tick: float = None): def time_now(relative: datetime = None) -> datetime: from utilmeta.conf.time import Time + return (Time.config() or Time()).time_now(relative) def time_local(dt: datetime = None) -> datetime: from utilmeta.conf.time import Time + return (Time.config() or Time()).time_local(dt) def convert_time(dt: datetime) -> datetime: from utilmeta.conf.time import Time + return (Time.config() or Time()).convert_time(dt) -def get_interval(interval: Union[int, float, decimal.Decimal, timedelta], null: bool = False, - ge: Optional[Union[int, float, decimal.Decimal, timedelta]] = 0, silent: bool = False, - le: Optional[Union[int, float, decimal.Decimal, timedelta]] = None) -> Optional[float]: +def get_interval( + interval: Union[int, float, decimal.Decimal, timedelta], + null: bool = False, + ge: Optional[Union[int, float, decimal.Decimal, timedelta]] = 0, + silent: bool = False, + le: Optional[Union[int, float, decimal.Decimal, timedelta]] = None, +) -> Optional[float]: if interval is None: if null: return interval - raise TypeError(f'interval is not null') + raise TypeError(f"interval is not null") if isinstance(interval, (int, float, decimal.Decimal)): interval = float(interval) elif isinstance(interval, timedelta): @@ -86,27 +100,30 @@ def get_interval(interval: Union[int, float, decimal.Decimal, timedelta], null: else: if silent: return 0 - raise TypeError(f'Invalid interval: {interval}, must be int, float or timedelta object') + raise TypeError( + f"Invalid interval: {interval}, must be int, float or timedelta object" + ) if ge is not None: if not isinstance(ge, (int, float)): ge = get_interval(ge) if interval < ge: if silent: return ge - raise ValueError(f'Invalid interval: {interval}, must greater than {ge}') + raise ValueError(f"Invalid interval: {interval}, must greater than {ge}") if le is not None: if not isinstance(le, (int, float)): le = get_interval(le) if interval > le: if silent: return le - raise ValueError(f'Invalid interval: {interval}, must less than {le}') + raise ValueError(f"Invalid interval: {interval}, must less than {le}") return interval def local_time_offset(t=None): """Return offset of local zone from GMT, either at present or at time t.""" import time + if t is None: t = time.time() if time.localtime(t).tm_isdst and time.daylight: @@ -117,18 +134,10 @@ def local_time_offset(t=None): def closest_hour(dt: datetime) -> 
datetime: lo = datetime( - year=dt.year, - month=dt.month, - day=dt.day, - hour=dt.hour, - tzinfo=dt.tzinfo + year=dt.year, month=dt.month, day=dt.day, hour=dt.hour, tzinfo=dt.tzinfo ) hi = datetime( - year=dt.year, - month=dt.month, - day=dt.day, - hour=dt.hour, - tzinfo=dt.tzinfo + year=dt.year, month=dt.month, day=dt.day, hour=dt.hour, tzinfo=dt.tzinfo ) + timedelta(hours=1) if dt - lo > hi - dt: return hi diff --git a/utilmeta/utils/functional/web.py b/utilmeta/utils/functional/web.py index 17550d4..3bc0340 100644 --- a/utilmeta/utils/functional/web.py +++ b/utilmeta/utils/functional/web.py @@ -6,29 +6,57 @@ from datetime import datetime, timezone from ipaddress import ip_address from typing import TypeVar, List, Dict, Tuple, Union, Optional -from utilmeta.utils.constant import COMMON_ERRORS, RequestType, DateFormat, \ - LOCAL, Scheme, SCHEME, HTTP, HTTPS, SCHEMES, UTF_8, AgentDevice, LOCAL_IP +from utilmeta.utils.constant import ( + COMMON_ERRORS, + RequestType, + DateFormat, + LOCAL, + Scheme, + SCHEME, + HTTP, + HTTPS, + SCHEMES, + UTF_8, + AgentDevice, + LOCAL_IP, +) from urllib.parse import urlparse, ParseResult from .data import based_number, multi, is_number, get_number from .py import file_like from utype import type_transform -T = TypeVar('T') +T = TypeVar("T") __all__ = [ - 'http_header', - 'parse_query_dict', - 'make_header', 'http_time', - 'retrieve_path', 'get_domain', 'parse_user_agents', - 'get_origin', 'get_request_ip', - 'normalize', 'url_join', 'localhost', 'etag', - 'dumps', 'loads', 'process_url', - 'get_content_tag', 'handle_json_float', - 'parse_raw_url', 'json_dumps', 'get_hostname', 'parse_query_string', - 'encode_multipart_form', - 'valid_url', 'encode_query', 'is_hop_by_hop', - 'guess_mime_type', 'fast_digest', + "http_header", + "parse_query_dict", + "make_header", + "http_time", + "retrieve_path", + "get_domain", + "parse_user_agents", + "get_origin", + "get_request_ip", + "normalize", + "url_join", + "localhost", + "etag", + "dumps", + "loads", + "process_url", + "get_content_tag", + "handle_json_float", + "parse_raw_url", + "json_dumps", + "get_hostname", + "parse_query_string", + "encode_multipart_form", + "valid_url", + "encode_query", + "is_hop_by_hop", + "guess_mime_type", + "fast_digest", ] @@ -45,25 +73,31 @@ def http_time(dt: datetime, to_utc: bool = True): def guess_mime_type(path: str, strict: bool = False): if not path: return None, None - dots = path.split('.') + dots = path.split(".") if dots: suffix = dots[-1] type_map = { - 'js': 'text/javascript', - 'woff2': 'font/woff2', - 'ts': 'text/typescript', + "js": "text/javascript", + "woff2": "font/woff2", + "ts": "text/typescript", } if suffix in type_map: return type_map[suffix], None import mimetypes + return mimetypes.guess_type(path, strict=strict) def is_hop_by_hop(header_name): return header_name.lower() in { - 'connection', 'keep-alive', 'proxy-authenticate', - 'proxy-authorization', 'te', 'trailers', 'transfer-encoding', - 'upgrade' + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", } @@ -75,15 +109,20 @@ def handle_json_float(data): elif isinstance(data, float): if math.isnan(data): return None - if data == float('inf'): + if data == float("inf"): return "Infinity" - if data == float('-inf'): - return '-Infinity' + if data == float("-inf"): + return "-Infinity" return data -def fast_digest(value, compress: Union[bool, int] = False, case_insensitive: bool = False, - consistent: bool = True, mod: int = 2 
** 64): +def fast_digest( + value, + compress: Union[bool, int] = False, + case_insensitive: bool = False, + consistent: bool = True, + mod: int = 2**64, +): if consistent: encoded = value if isinstance(value, bytes) else str(value).encode() dig_mod = int(hashlib.md5(encoded).hexdigest(), 16) % mod @@ -98,7 +137,7 @@ def fast_digest(value, compress: Union[bool, int] = False, case_insensitive: boo elif isinstance(compress, int): base = compress else: - raise TypeError(f'Invalid compress: {compress}') + raise TypeError(f"Invalid compress: {compress}") return based_number(dig_mod, base) return str(dig_mod) @@ -115,97 +154,107 @@ def etag(data, weak: bool = False) -> str: comp = fast_digest(data, compress=36, consistent=True).lower() quoted = f'"{comp}"' if weak: - quoted = f'W/{quoted}' + quoted = f"W/{quoted}" return quoted def localhost(host: str) -> bool: if not isinstance(host, str): return False - if '://' not in host: + if "://" not in host: # can be http/https/unix/redis... - host = 'http://' + host + host = "http://" + host hostname = urlparse(host).hostname - return hostname in ['127.0.0.1', 'localhost'] + return hostname in ["127.0.0.1", "localhost"] -def encode_query(query: dict, exclude_null: bool = True, - multi_bracket_suffix: bool = False, - multi_comma_join: bool = False) -> str: +def encode_query( + query: dict, + exclude_null: bool = True, + multi_bracket_suffix: bool = False, + multi_comma_join: bool = False, +) -> str: if not query: - return '' + return "" from urllib.parse import urlencode + encoded = [] for key, val in query.items(): if val is None and exclude_null: continue if multi(val): if multi_comma_join: - val = ','.join(val) + val = ",".join(val) else: arg = key if multi_bracket_suffix: - arg += '[]' - encoded.extend([f'{arg}={v}' for v in val]) + arg += "[]" + encoded.extend([f"{arg}={v}" for v in val]) continue encoded.append(urlencode({key: val})) - return '&'.join(encoded) + return "&".join(encoded) def retrieve_path(url): parse: ParseResult = urlparse(url) if parse.scheme: - return url[len(f'{parse.scheme}{SCHEME}{parse.netloc}'):] - if url.startswith('/'): + return url[len(f"{parse.scheme}{SCHEME}{parse.netloc}") :] + if url.startswith("/"): return url - return f'/{url}' + return f"/{url}" def valid_url(url: str, raise_err: bool = True, http_only: bool = True) -> str: - url = url.strip('/').replace(' ', '') + url = url.strip("/").replace(" ", "") res = urlparse(url) if not res.scheme: if raise_err: - raise ValueError(f'Invalid url syntax: {url}') - return '' + raise ValueError(f"Invalid url syntax: {url}") + return "" if http_only and res.scheme not in (Scheme.HTTP, Scheme.HTTPS): if raise_err: - raise ValueError(f'Invalid scheme: {res.scheme}') - return '' + raise ValueError(f"Invalid scheme: {res.scheme}") + return "" if not res.netloc: if raise_err: - raise ValueError(f'empty net loc') - return '' + raise ValueError(f"empty net loc") + return "" return url def encode_multipart_form(form: dict, boundary: str = None) -> Tuple[bytes, str]: import binascii + boundary = boundary or binascii.hexlify(os.urandom(16)) if isinstance(boundary, str): - boundary = boundary.encode('ascii') + boundary = boundary.encode("ascii") items = [] for field, value in form.items(): key = str(field).encode() - beg = b"--%s\r\nContent-Disposition: form-data; name=\"%s\"" % (boundary, key) + beg = b'--%s\r\nContent-Disposition: form-data; name="%s"' % (boundary, key) files = value if multi(value) else [value] for i, val in enumerate(files): if file_like(val): content = val.read() if 
isinstance(content, str): content = content.encode() - filename = str(getattr(val, 'filename', None) or getattr(val, 'name', None) or '') + filename = str( + getattr(val, "filename", None) or getattr(val, "name", None) or "" + ) if filename: - if '/' in filename or '\\' in filename: + if "/" in filename or "\\" in filename: filename = os.path.basename(filename) else: - filename = f'{field}-file-{i}' - content_type = str(getattr(val, 'content_type', None) or '') + filename = f"{field}-file-{i}" + content_type = str(getattr(val, "content_type", None) or "") if not content_type: content_type, encoding = guess_mime_type(filename) if not content_type: content_type = RequestType.OCTET_STREAM - prep = b'; filename=\"%s\"\r\nContent-Type: %s' % (filename.encode(), content_type.encode()) + prep = b'; filename="%s"\r\nContent-Type: %s' % ( + filename.encode(), + content_type.encode(), + ) else: if isinstance(val, bytes): content = val @@ -213,10 +262,10 @@ def encode_multipart_form(form: dict, boundary: str = None) -> Tuple[bytes, str] content = json_dumps(val).encode() else: content = str(val).encode() - prep = b'' + prep = b"" items.append(b"%s%s\r\n\r\n%s\r\n" % (beg, prep, content)) body = b"".join(items) + b"--%s--\r\n" % boundary - content_type = "multipart/form-data; boundary=%s" % boundary.decode('ascii') + content_type = "multipart/form-data; boundary=%s" % boundary.decode("ascii") return body, content_type @@ -236,10 +285,16 @@ def encode_multipart_form(form: dict, boundary: str = None) -> Tuple[bytes, str] # return False -def get_origin(url: str, with_scheme: bool = True, remove_www_prefix: bool = False, convert_port: bool = False, - default_scheme: str = Scheme.HTTP, trans_local: bool = True): +def get_origin( + url: str, + with_scheme: bool = True, + remove_www_prefix: bool = False, + convert_port: bool = False, + default_scheme: str = Scheme.HTTP, + trans_local: bool = True, +): if not url: - return '' + return "" default_scheme = str(default_scheme).lower() if default_scheme not in SCHEMES: default_scheme = Scheme.HTTP @@ -250,36 +305,38 @@ def get_origin(url: str, with_scheme: bool = True, remove_www_prefix: bool = Fal host: str = result.netloc port = result.port if convert_port: - if port == '80': + if port == "80": port = None if not scheme: scheme = Scheme.HTTP elif scheme != Scheme.HTTP: - raise ValueError(f'Invalid scheme port combination: {scheme} {port}') - if port == '443': + raise ValueError(f"Invalid scheme port combination: {scheme} {port}") + if port == "443": port = None if not scheme: scheme = Scheme.HTTPS elif scheme != Scheme.HTTPS: - raise ValueError(f'Invalid scheme port combination: {scheme} {port}') + raise ValueError(f"Invalid scheme port combination: {scheme} {port}") if remove_www_prefix: - host = host.lstrip('www.') + host = host.lstrip("www.") if trans_local and host.startswith(LOCAL): - host = '127.0.0.1' + host = "127.0.0.1" if port: - host = f'{host}:{port}' + host = f"{host}:{port}" if not with_scheme: return host if not scheme: scheme = default_scheme - return f'{scheme}{SCHEME}{host}' + return f"{scheme}{SCHEME}{host}" def normalize(data, _json: bool = False): from utype.utils.compat import ATOM_TYPES + if isinstance(data, ATOM_TYPES): return data import pickle + try: if _json: raise ValueError @@ -291,10 +348,11 @@ def normalize(data, _json: bool = False): def json_dumps(data, **kwargs) -> str: from utype import JSONEncoder + if data is None: - return '' - kwargs.setdefault('cls', JSONEncoder) - kwargs.setdefault('ensure_ascii', False) + return "" + 
kwargs.setdefault("cls", JSONEncoder) + kwargs.setdefault("ensure_ascii", False) return json.dumps(data, **kwargs) @@ -302,12 +360,15 @@ def dumps(data, exclude_types: Tuple[type, ...] = (), bulk_data: bool = False): if data is None: return None if bulk_data and isinstance(data, dict): - return {key: dumps(val, exclude_types=exclude_types) for key, val in data.items()} + return { + key: dumps(val, exclude_types=exclude_types) for key, val in data.items() + } if isinstance(data, exclude_types): # False / True isinstance of int, so isinstance is not accurate here # for incrby / decrby / incrbyfloat work fine at lua script number typed data will not be dump return str(data).encode() import pickle + return pickle.dumps(normalize(data)) @@ -326,6 +387,7 @@ def loads(data, exclude_types: Tuple[type, ...] = (), bulk_data: bool = False): values[key] = loads(val, exclude_types=exclude_types, bulk_data=bulk_data) return values import pickle + try: return pickle.loads(data) except (*COMMON_ERRORS, pickle.PickleError): @@ -346,7 +408,7 @@ def loads(data, exclude_types: Tuple[type, ...] = (), bulk_data: bool = False): def get_content_tag(body): if not body: - return '' + return "" tag = body if isinstance(body, (dict, list)): tag = json_dumps(body).encode(UTF_8) @@ -355,26 +417,31 @@ def get_content_tag(body): return hashlib.sha256(tag).hexdigest() -def url_join(base: str, *routes: str, with_scheme: bool = True, - prepend_slash: bool = False, append_slash: bool = None): +def url_join( + base: str, + *routes: str, + with_scheme: bool = True, + prepend_slash: bool = False, + append_slash: bool = None, +): if not base: # force convert to str - base = '' + base = "" route_list = [] if not any(routes): if append_slash: - if not base.endswith('/'): - base = base + '/' + if not base.endswith("/"): + base = base + "/" elif append_slash is False: - base = base.rstrip('/') + base = base.rstrip("/") return base if not isinstance(base, str): - base = str(base) if base else '' + base = str(base) if base else "" final_route = base for route in routes: if not route: continue - url = str(route).strip('/') + url = str(route).strip("/") if not url: continue final_route = str(route) @@ -382,43 +449,45 @@ def url_join(base: str, *routes: str, with_scheme: bool = True, if url_res.scheme: return url route_list.append(url) - end_slash = final_route.endswith('/') # last route + end_slash = final_route.endswith("/") # last route res = urlparse(base) if with_scheme and not res.scheme: - raise ValueError('base url must specify a valid scheme') - result = '/'.join([base.strip('/'), *route_list]) + raise ValueError("base url must specify a valid scheme") + result = "/".join([base.strip("/"), *route_list]) if not res.scheme: if prepend_slash: - if not result.startswith('/'): - result = '/' + result + if not result.startswith("/"): + result = "/" + result else: - result = result.lstrip('/') + result = result.lstrip("/") if append_slash is not None: if append_slash: - if not result.endswith('/'): - result = result + '/' + if not result.endswith("/"): + result = result + "/" else: - result = result.rstrip('/') + result = result.rstrip("/") elif end_slash: - result = result + '/' + result = result + "/" return result def process_url(url: Union[str, List[str]]): if multi(url): return [process_url(u) for u in url if u] - url = url.strip('/') - return f'/{url}/' if url else '/' + url = url.strip("/") + return f"/{url}/" if url else "/" def parse_query_string(qs: str) -> dict: from urllib.parse import parse_qs + return 
parse_query_dict(parse_qs(qs)) def parse_raw_url(url: str) -> Tuple[str, dict]: res = urlparse(url) from django.http.request import QueryDict + return res.path, parse_query_dict(QueryDict(res.query)) @@ -426,20 +495,20 @@ def parse_query_dict(qd: Dict[str, List[str]]) -> dict: data = {} for key, val in dict(qd).items(): if not multi(val): - data[key] = str(val or '') + data[key] = str(val or "") continue - if key.endswith('[]'): - data[key.rstrip('[]')] = val + if key.endswith("[]"): + data[key.rstrip("[]")] = val continue if len(val) > 1: data[key] = val continue v = val[0] - if v.startswith('='): - if key.endswith('>') or key.endswith('<'): - data[key + '='] = v[1:] + if v.startswith("="): + if key.endswith(">") or key.endswith("<"): + data[key + "="] = v[1:] continue - data[key] = v or '' + data[key] = v or "" return data @@ -466,22 +535,22 @@ def parse_user_agents(ua_string: str) -> Optional[dict]: elif ua.is_email_client: device = AgentDevice.email return dict( - browser=f'{ua.browser.family} {ua.browser.version_string}'.strip(' '), - os=f'{ua.os.family} {ua.os.version_string}'.strip(' '), + browser=f"{ua.browser.family} {ua.browser.version_string}".strip(" "), + os=f"{ua.os.family} {ua.os.version_string}".strip(" "), mobile=ua.is_mobile, bot=ua.is_bot, - device=device + device=device, ) def http_header(header: str) -> str: - h = header.upper().replace('-', '_') - if h in {'CONTENT_TYPE', 'CONTENT_LENGTH'}: + h = header.upper().replace("-", "_") + if h in {"CONTENT_TYPE", "CONTENT_LENGTH"}: return h - if h.startswith('HTTP_'): + if h.startswith("HTTP_"): # make idempotent return h - return 'HTTP_' + h + return "HTTP_" + h def make_header(h: T) -> T: @@ -492,13 +561,13 @@ def make_header(h: T) -> T: return {make_header(key): val for key, val in h.items()} elif multi(h): return [make_header(v) for v in h] - return '-'.join([s.capitalize() for s in str(h).lower().split('_')]) + return "-".join([s.capitalize() for s in str(h).lower().split("_")]) def get_netloc(url: str) -> str: # contains port if not url: - return '' + return "" res = urlparse(url) if res.netloc: return res.netloc @@ -509,7 +578,7 @@ def get_netloc(url: str) -> str: def get_hostname(url: str) -> str: # does not contains port if not url: - return '' + return "" res = urlparse(url) if res.hostname: return res.hostname @@ -524,13 +593,16 @@ def get_domain(url: str) -> Optional[str]: ip_address(hostname) return None except ValueError: - return '.'.join(hostname.split('.')[-2:]) + return ".".join(hostname.split(".")[-2:]) def get_request_ip(headers: dict): - headers = {str(k).lower().replace('_', '-'): v for k, v in headers.items()} - ips = [*headers.get('x-forwarded-for', '').replace(' ', '').split(','), - headers.get('remote-addr'), headers.get('x-real-ip')] + headers = {str(k).lower().replace("_", "-"): v for k, v in headers.items()} + ips = [ + *headers.get("x-forwarded-for", "").replace(" ", "").split(","), + headers.get("remote-addr"), + headers.get("x-real-ip"), + ] for ip in ips: if not ip or ip == LOCAL_IP: continue diff --git a/utilmeta/utils/logical.py b/utilmeta/utils/logical.py index a25b5b6..4521108 100644 --- a/utilmeta/utils/logical.py +++ b/utilmeta/utils/logical.py @@ -4,7 +4,7 @@ from typing import Dict, Any, List, Optional, Union -__all__ = ['LogicUtil'] +__all__ = ["LogicUtil"] class LogicUtil(Util, metaclass=Meta): @@ -65,8 +65,11 @@ def apply_logic(self, *args, **kwargs): if xor is None: xor = con else: - errors.append(self.XOR_ERROR_CLS( - f'More than 1 conditions ({xor}, {con}) is True in XOR 
conditions')) + errors.append( + self.XOR_ERROR_CLS( + f"More than 1 conditions ({xor}, {con}) is True in XOR conditions" + ) + ) xor = None break except Exception as e: @@ -79,13 +82,16 @@ def apply_logic(self, *args, **kwargs): try: con = self._conditions[0] result = con(*args, **kwargs) - errors.append(self.XOR_ERROR_CLS(f'Negate condition: {con} is violated')) + errors.append( + self.XOR_ERROR_CLS(f"Negate condition: {con} is violated") + ) except Exception as e: # use error as result result = self._get_error_result(e, *args, **kwargs) if errors: # apply negate from .error import Error + err = exc.CombinedError(*errors) if len(errors) > 1 else errors[0] raise Error(err).throw() return result @@ -98,8 +104,9 @@ def _get_error_result(self, err, *args, **kwargs): # noqa def _combine(self, other, operator): name = self.__class__.__name__ - assert isinstance(other, self.__class__), \ - f"{name} instance must combine with other {name} instance, got {other}" + assert isinstance( + other, self.__class__ + ), f"{name} instance must combine with other {name} instance, got {other}" util = self.__class__() util._operator = operator if self._operator == operator: @@ -114,7 +121,12 @@ def _combine(self, other, operator): def _repr(self, params: List[str] = None, excludes: List[str] = None): if self._logic_applied: - return self._operator.join([(f'({str(c)})' if c._logic_applied else str(c)) for c in self._conditions]) + return self._operator.join( + [ + (f"({str(c)})" if c._logic_applied else str(c)) + for c in self._conditions + ] + ) return f'{Logic.NOT if self._negate else ""}{super()._repr(params=params, excludes=excludes)}' def __copy__(self): @@ -125,7 +137,7 @@ def __copy__(self): util._conditions = self._copy(self._conditions) return util - def __eq__(self, other: 'LogicUtil'): + def __eq__(self, other: "LogicUtil"): if not isinstance(other, self.__class__): return False if self._operator: @@ -136,13 +148,13 @@ def __eq__(self, other: 'LogicUtil'): return False return super(LogicUtil, self).__eq__(other) - def __or__(self, other: 'LogicUtil'): + def __or__(self, other: "LogicUtil"): return self._combine(other, Logic.OR) - def __xor__(self, other: 'LogicUtil'): + def __xor__(self, other: "LogicUtil"): return self._combine(other, Logic.XOR) - def __and__(self, other: 'LogicUtil'): + def __and__(self, other: "LogicUtil"): return self._combine(other, Logic.AND) def __invert__(self): diff --git a/utilmeta/utils/plugin.py b/utilmeta/utils/plugin.py index 176ed79..3c4ddbd 100644 --- a/utilmeta/utils/plugin.py +++ b/utilmeta/utils/plugin.py @@ -5,6 +5,7 @@ from typing import Type, Dict, List, Callable, Iterator, Union, Tuple from functools import partial, wraps from utype.parser.func import FunctionParser + # from .context import Property from collections import OrderedDict @@ -32,13 +33,16 @@ def omit_unsupported_params(f, asynchronous: bool = None): return f if inspect.iscoroutinefunction(f) or asynchronous: + @wraps(f) async def wrapper(*args, **kwargs): r = f(*args[:args_num], **{k: v for k, v in kwargs.items() if k in keys}) if inspect.isawaitable(r): return await r return r + else: + @wraps(f) def wrapper(*args, **kwargs): return f(*args[:args_num], **{k: v for k, v in kwargs.items() if k in keys}) @@ -50,7 +54,11 @@ class PluginBase(Util): def __new__(cls, _kw=None, *args, **kwargs): instance = super().__new__(cls) if not args and not kwargs: - if isinstance(_kw, type) and issubclass(_kw, PluginTarget) or isinstance(_kw, PluginTarget): + if ( + isinstance(_kw, type) + and issubclass(_kw, 
PluginTarget) + or isinstance(_kw, PluginTarget) + ): # @plugin # without init # def APIClass(API): # pass @@ -66,21 +74,21 @@ def __init__(self, _kw=None, *args, **kwargs): super().__init__(_kw or kwargs) if inspect.isclass(self): try: - self.__ref__ = f'{self.__module__}.{self.__qualname__}' + self.__ref__ = f"{self.__module__}.{self.__qualname__}" except AttributeError: - self.__ref__ = f'{self.__module__}.{self.__name__}' + self.__ref__ = f"{self.__module__}.{self.__name__}" else: - self.__name__ = f'{self.__class__.__name__}(...)' + self.__name__ = f"{self.__class__.__name__}(...)" self.__ref__ = None self.__args__ = args @classmethod - def apply_for(cls, target: 'PluginTarget') -> 'PluginBase': + def apply_for(cls, target: "PluginTarget") -> "PluginBase": return cls() def __call__(self, func, *args, **kwargs): if inspect.isfunction(func): - if getattr(func, 'plugins', None): + if getattr(func, "plugins", None): if self not in func.plugins: # the later the plugin was decorated # the earlier it will be applied @@ -89,13 +97,18 @@ def __call__(self, func, *args, **kwargs): func.plugins.append(self) else: func.plugins = [self] - elif inspect.isclass(func) and issubclass(func, PluginTarget) or isinstance(func, PluginTarget): + elif ( + inspect.isclass(func) + and issubclass(func, PluginTarget) + or isinstance(func, PluginTarget) + ): func._add_plugins(self) return func @classmethod - def initialize(cls, params: dict = None, - default_value=None, ignore_required: bool = False): + def initialize( + cls, params: dict = None, default_value=None, ignore_required: bool = False + ): args = [] kwargs = {} extras = {} @@ -108,19 +121,19 @@ def get_value(_name: str, _class: Type = cls): return attr_value.initialize( params=params.get(key), default_value=default_value, - ignore_required=ignore_required + ignore_required=ignore_required, ) elif issubclass(attr_value, dict): # like Schema attr_dict = {} for n in list(attr_value.__dict__): - if n.startswith('_'): + if n.startswith("_"): continue if hasattr(dict, n): continue attr_dict[n] = get_value(n, _class=attr_value) - inst = attr_value(**attr_dict) # use dict to initialize + inst = attr_value(**attr_dict) # use dict to initialize for k, v in attr_dict.items(): - setattr(inst, k, v) # set to the new attribute + setattr(inst, k, v) # set to the new attribute return inst elif inspect.isdatadescriptor(attr_value): # has __set__ or __delete__ # delete the attribute before initialize @@ -138,7 +151,7 @@ def get_value(_name: str, _class: Type = cls): if ignore_required: args.append(default_value) continue - raise TypeError(f'{cls} required arg: {repr(key)} not defined') + raise TypeError(f"{cls} required arg: {repr(key)} not defined") args.append(get_value(key)) for key in cls._kw_keys: @@ -152,7 +165,7 @@ def get_value(_name: str, _class: Type = cls): if ignore_required: kwargs[key] = default_value continue - raise TypeError(f'{cls} required arg: {repr(key)} not defined') + raise TypeError(f"{cls} required arg: {repr(key)} not defined") if key in kwargs: continue kwargs[key] = get_value(key) @@ -171,7 +184,7 @@ def get_value(_name: str, _class: Type = cls): kwargs.update(extras) if cls._pos_var: - ext_args = params.get('@args') + ext_args = params.get("@args") if ext_args: args.extend(ext_args) @@ -187,9 +200,13 @@ def get_value(_name: str, _class: Type = cls): class PluginEvent: function_parser_cls = FunctionParser - def __init__(self, name: str, streamline_result: bool = False, - synchronous_only: bool = False, - asynchronous_only: bool = False): + 
def __init__( + self, + name: str, + streamline_result: bool = False, + synchronous_only: bool = False, + asynchronous_only: bool = False, + ): self.name = name self.streamline_result = streamline_result self.synchronous_only = synchronous_only @@ -198,7 +215,9 @@ def __init__(self, name: str, streamline_result: bool = False, self._hooks: Dict[Type, List[tuple]] = {} self._callback_hooks = {} - def __call__(self, inst: Union['PluginTarget', Type['PluginTarget']], *args, **kwargs): + def __call__( + self, inst: Union["PluginTarget", Type["PluginTarget"]], *args, **kwargs + ): # inst can be PluginTarget instance or class result = None if self.streamline_result: @@ -218,7 +237,9 @@ def __call__(self, inst: Union['PluginTarget', Type['PluginTarget']], *args, **k return result @awaitable(__call__) - async def __call__(self, inst: Union['PluginTarget', Type['PluginTarget']], *args, **kwargs): + async def __call__( + self, inst: Union["PluginTarget", Type["PluginTarget"]], *args, **kwargs + ): # inst can be PluginTarget instance or class result = None if self.streamline_result: @@ -254,20 +275,23 @@ def get_hooks(self, target): cls_hooks.extend(self._hooks.get(target_cls)) return cls_hooks - def get(self, plugin: PluginBase, target: 'PluginTarget' = None, asynchronous=None): + def get(self, plugin: PluginBase, target: "PluginTarget" = None, asynchronous=None): handler = getattr(plugin, self.name, None) if callable(handler) and not function_pass(handler): handler = omit_unsupported_params(handler, asynchronous=asynchronous) if target: + @wraps(handler) def target_handler(*args, **kwargs): return handler(*args, target, **kwargs) + return target_handler return handler return None - def iter(self, *targets: 'PluginTarget', - asynchronous: bool = None, reverse: bool = False) -> Iterator[Callable]: + def iter( + self, *targets: "PluginTarget", asynchronous: bool = None, reverse: bool = False + ) -> Iterator[Callable]: # accept iterate over more than 1 target (eg. 
API/Client + Endpoint) _classes = set() for target in reversed(targets) if reverse else targets: @@ -278,7 +302,9 @@ def iter(self, *targets: 'PluginTarget', if not plugins or not isinstance(plugins, dict): continue hooks = self.get_hooks(target) - for plugin_cls, plugin in reversed(plugins.items()) if reverse else plugins.items(): + for plugin_cls, plugin in ( + reversed(plugins.items()) if reverse else plugins.items() + ): if plugin_cls in _classes: # in case for more than 1 plugin target continue @@ -291,7 +317,8 @@ def iter(self, *targets: 'PluginTarget', _classes.add(plugin_cls) yield partial( omit_unsupported_params(func, asynchronous=asynchronous), - plugin, **partial_kw + plugin, + **partial_kw, ) if hooked: continue @@ -306,7 +333,7 @@ def iter(self, *targets: 'PluginTarget', def register(self, target_class): if not inspect.isclass(target_class): - raise TypeError(f'Invalid register class: {target_class}, must be a class') + raise TypeError(f"Invalid register class: {target_class}, must be a class") if target_class not in self._hooks: self._hooks.setdefault(target_class, []) @@ -317,9 +344,13 @@ def unregister(self, target_class): def make_callable(self, func, target_class): func = self.function_parser_cls.apply_for(func) if self.synchronous_only and func.is_asynchronous: - raise TypeError(f'PluginEvent: {self.name} is synchronous only, got async function: {func}') + raise TypeError( + f"PluginEvent: {self.name} is synchronous only, got async function: {func}" + ) if self.asynchronous_only and not func.is_asynchronous: - raise TypeError(f'PluginEvent: {self.name} is asynchronous only, got sync function: {func}') + raise TypeError( + f"PluginEvent: {self.name} is asynchronous only, got sync function: {func}" + ) target_arg = None for key, field in func.fields.items(): annotate = field.type @@ -334,40 +365,52 @@ def make_callable(self, func, target_class): # target_arg = _arg return func, target_arg - def add_callback_hook(self, func, target_class, priority=0, registered_only: bool = False): - if not inspect.isclass(target_class): # or not issubclass(target_class, PluginTarget): - raise ValueError(f'{self.name}.hook target_class: {target_class} must be subclass of PluginTarget') + def add_callback_hook( + self, func, target_class, priority=0, registered_only: bool = False + ): + if not inspect.isclass( + target_class + ): # or not issubclass(target_class, PluginTarget): + raise ValueError( + f"{self.name}.hook target_class: {target_class} must be subclass of PluginTarget" + ) if registered_only and target_class not in self._hooks: - raise ValueError(f'{self.name}.hook target_class: {target_class} not registered') + raise ValueError( + f"{self.name}.hook target_class: {target_class} not registered" + ) if target_class not in self._hooks: self.register(target_class) func, target_arg = self.make_callable(func, target_class=target_class) - item = ( - func, - target_arg, - priority - ) + item = (func, target_arg, priority) if item not in self._callback_hooks[target_class]: self._callback_hooks[target_class].append(item) self._callback_hooks[target_class].sort(key=lambda tup: -tup[-1]) - def add_plugin_hook(self, func, target_class, plugin_class, priority=0, registered_only: bool = False): - if not inspect.isclass(target_class): # or not issubclass(target_class, PluginTarget): - raise ValueError(f'{self.name}.hook target_class: {target_class} must be subclass of PluginTarget') + def add_plugin_hook( + self, + func, + target_class, + plugin_class, + priority=0, + registered_only: bool = 
False, + ): + if not inspect.isclass( + target_class + ): # or not issubclass(target_class, PluginTarget): + raise ValueError( + f"{self.name}.hook target_class: {target_class} must be subclass of PluginTarget" + ) if registered_only and target_class not in self._hooks: - raise ValueError(f'{self.name}.hook target_class: {target_class} not registered') + raise ValueError( + f"{self.name}.hook target_class: {target_class} not registered" + ) if target_class not in self._hooks: self.register(target_class) func, target_arg = self.make_callable(func, target_class=target_class) - item = ( - plugin_class, - func, - target_arg, - priority - ) + item = (plugin_class, func, target_arg, priority) if item not in self._hooks[target_class]: self._hooks[target_class].append(item) self._hooks[target_class].sort(key=lambda tup: -tup[-1]) @@ -380,9 +423,17 @@ def hook_callback(self, target_class, priority=0): def wrapper(f): self.add_callback_hook(f, target_class=target_class, priority=priority) return f + return wrapper - def hook(self, target_class, plugin_class=None, *, priority=0, registered_only: bool = False): + def hook( + self, + target_class, + plugin_class=None, + *, + priority=0, + registered_only: bool = False, + ): def wrapper(f): if function_pass(f): return f @@ -395,7 +446,7 @@ def wrapper(f): plugin = v.annotation else: # like Type[Class] - _origin = getattr(plugin, '__origin__', None) + _origin = getattr(plugin, "__origin__", None) if _origin == type: _arg = plugin.__args__[0] if inspect.isclass(_arg): @@ -403,11 +454,19 @@ def wrapper(f): else: break if not plugin: - raise ValueError(f'{self.name}.hook does not specify plugin_class (either by param or annotation)') - self.add_plugin_hook(f, target_class, plugin_class=plugin, - priority=priority, registered_only=registered_only) + raise ValueError( + f"{self.name}.hook does not specify plugin_class (either by param or annotation)" + ) + self.add_plugin_hook( + f, + target_class, + plugin_class=plugin, + priority=priority, + registered_only=registered_only, + ) return f # target_class._add_plugin_hook(self, f, plugin_class, priority=priority) + return wrapper @@ -427,15 +486,15 @@ class PluginTarget: _plugins: OrderedDict = OrderedDict() def __init_subclass__(cls, **kwargs): - cls.__ref__ = f'{cls.__module__}.{cls.__qualname__}' + cls.__ref__ = f"{cls.__module__}.{cls.__qualname__}" for key, val in cls.__annotations__.items(): - if inspect.isclass(val) and issubclass(val, PluginBase): # fixed plugins + if inspect.isclass(val) and issubclass(val, PluginBase): # fixed plugins cls._fixed_plugins[key] = val plugins = OrderedDict(cls._plugins) for slot in list(cls.__dict__): - if slot.startswith('_'): + if slot.startswith("_"): continue util = cls.__dict__[slot] @@ -449,18 +508,22 @@ def __init_subclass__(cls, **kwargs): if slot in cls._fixed_plugins: plugin_cls = cls._fixed_plugins.get(slot) if not isinstance(util, plugin_cls): - raise TypeError(f'{cls}.{slot} must be a {plugin_cls} instance, got {util}') + raise TypeError( + f"{cls}.{slot} must be a {plugin_cls} instance, got {util}" + ) # else: # if isinstance(util, tuple(cls._fixed_plugins.values())): # # if a util other than # continue if isinstance(util, PluginBase): - path = f'{cls.__ref__}.{slot}' + path = f"{cls.__ref__}.{slot}" if util.__ref__: if util.__ref__ != path: - warnings.warn(f'{cls} same util: {util} mount to different ' - f'path: {repr(path)}, {repr(util.__ref__)}') + warnings.warn( + f"{cls} same util: {util} mount to different " + f"path: {repr(path)}, 
{repr(util.__ref__)}" + ) else: util.__ref__ = path else: @@ -486,7 +549,9 @@ def _add_plugins(cls, *plugins): elif isinstance(plugin, PluginBase): plugin_dict[plugin.__class__] = plugin continue - warnings.warn(f'{cls}: add invalid plugin: {plugin}, must be a {PluginBase} subclass of instance') + warnings.warn( + f"{cls}: add invalid plugin: {plugin}, must be a {PluginBase} subclass of instance" + ) cls._plugins.update(plugin_dict) def _plugin(self, plugin, setdefault=False): @@ -503,7 +568,9 @@ def _plugin(self, plugin, setdefault=False): else: self._plugins[plugin.__class__] = plugin return - warnings.warn(f'{self}: add invalid plugin: {plugin}, must be a {PluginBase} subclass of instance') + warnings.warn( + f"{self}: add invalid plugin: {plugin}, must be a {PluginBase} subclass of instance" + ) @classmethod def _get_plugin(cls, plugin_class): diff --git a/utilmeta/utils/schema/backends/attrs.py b/utilmeta/utils/schema/backends/attrs.py index 1320c00..4f0f609 100644 --- a/utilmeta/utils/schema/backends/attrs.py +++ b/utilmeta/utils/schema/backends/attrs.py @@ -2,7 +2,7 @@ from collections.abc import Mapping -@register_transformer(attr='__attrs_attrs__') +@register_transformer(attr="__attrs_attrs__") def transform_attrs(transformer: TypeTransformer, data, cls): if not transformer.no_explicit_cast and not isinstance(data, (dict, Mapping)): data = transformer.to_dict(data) @@ -11,7 +11,7 @@ def transform_attrs(transformer: TypeTransformer, data, cls): return cls(**data) -@register_encoder(attr='__attrs_attrs__') +@register_encoder(attr="__attrs_attrs__") def encode_attrs(encoder, data): values = {} for v in data.__attrs_attrs__: diff --git a/utilmeta/utils/schema/backends/dataclass.py b/utilmeta/utils/schema/backends/dataclass.py index 6cce9e5..134479d 100644 --- a/utilmeta/utils/schema/backends/dataclass.py +++ b/utilmeta/utils/schema/backends/dataclass.py @@ -2,7 +2,7 @@ from collections.abc import Mapping -@register_transformer(attr='__dataclass_fields__') +@register_transformer(attr="__dataclass_fields__") def transform_attrs(transformer: TypeTransformer, data, cls): if not transformer.no_explicit_cast and not isinstance(data, (dict, Mapping)): data = transformer(data, dict) @@ -10,7 +10,7 @@ def transform_attrs(transformer: TypeTransformer, data, cls): return cls(**data) -@register_encoder(attr='__dataclass_fields__') +@register_encoder(attr="__dataclass_fields__") def transform_attrs(encoder, data): values = {} for k in data.__dataclass_fields__: diff --git a/utilmeta/utils/schema/base.py b/utilmeta/utils/schema/base.py index f1f5bdc..9a77e68 100644 --- a/utilmeta/utils/schema/base.py +++ b/utilmeta/utils/schema/base.py @@ -1,4 +1,3 @@ - class SchemaAdaptor: def get_properties(self): pass