diff --git a/tgram/__init__.py b/tgram/__init__.py index 5b4ea05..de7d681 100644 --- a/tgram/__init__.py +++ b/tgram/__init__.py @@ -1,4 +1,4 @@ -__all__ = ["types", "TgBot", "handlers", "filters"] +__all__ = ["types", "TgBot", "handlers", "filters", "compose", "StopPropagation"] __version__ = "1.10.7" __author__ = "Zaid" @@ -7,3 +7,6 @@ from . import types, handlers, filters from .tgbot import TgBot +from .sync import compose + +from .errors import StopPropagation diff --git a/tgram/errors.py b/tgram/errors.py index 4216dea..8f579cf 100644 --- a/tgram/errors.py +++ b/tgram/errors.py @@ -1,2 +1,6 @@ class APIException(Exception): pass + + +class StopPropagation(Exception): + pass diff --git a/tgram/sync.py b/tgram/sync.py index 0393c34..9b173f1 100644 --- a/tgram/sync.py +++ b/tgram/sync.py @@ -4,6 +4,8 @@ import threading import logging +from tgram import utils + logger = logging.getLogger(__name__) @@ -107,3 +109,7 @@ def wrap(source): method ): async_to_sync(source, name) + + +async_to_sync(utils, "compose") +compose = utils.compose diff --git a/tgram/tgbot.py b/tgram/tgbot.py index 4c2b0bb..1084d47 100644 --- a/tgram/tgbot.py +++ b/tgram/tgbot.py @@ -29,10 +29,6 @@ class Dispatcher: - _is_running = False - _handlers: List["tgram.handlers.Handler"] = [] - _listen_handlers: List["tgram.types.Listener"] = [] - async def run_for_updates(self: "TgBot", skip_updates: bool = True) -> None: if self.plugins: self.load_plugins() @@ -41,10 +37,10 @@ async def run_for_updates(self: "TgBot", skip_updates: bool = True) -> None: self.allowed_updates, 100, ) - self._is_running = True + self.is_running = True self.me = await self.get_me() - while self._is_running: + while self.is_running: try: updates = await self.get_updates( offset=offset, @@ -56,13 +52,19 @@ async def run_for_updates(self: "TgBot", skip_updates: bool = True) -> None: offset = update.update_id + 1 await self._check_update(update) except (asyncio.CancelledError, KeyboardInterrupt): - self._is_running = False + self.is_running = False + except tgram.StopPropagation: + pass except Exception as e: logger.exception(e) session = await self._get_session() await session.close() + async def stop(self) -> Literal[True]: + self.is_running = False + return True + async def _check_cancel(self: "TgBot", callback: Callable, update: Any) -> bool: logger.debug("Checking listener in %s func", callback.__name__) try: @@ -121,11 +123,6 @@ async def _process_update(self: "TgBot", update: Any, callback: Callable) -> Non class TgBot(TelegramBotMethods, Decorators, Dispatcher): - me: "tgram.types.User" = None - _session: "aiohttp.ClientSession" = None - _api_url: str = None - _custom_types: dict = {} - def __init__( self, bot_token: str, @@ -150,10 +147,18 @@ def __init__( self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handlers") self.loop = asyncio.get_event_loop() + self.is_running: bool = None + self.me: "tgram.types.User" = None + + self._listen_handlers: List["tgram.types.Listener"] = [] + self._handlers: List["tgram.handlers.Handler"] = [] + self._custom_types: dict = {} + self._session: "aiohttp.ClientSession" = None + if not api_url.endswith("/"): api_url += "/" - self._api_url = f"{api_url}bot{bot_token}/" + self._api_url: str = f"{api_url}bot{bot_token}/" def add_handler(self, handler: "tgram.handlers.Handler") -> None: if handler.type == "all": @@ -226,7 +231,7 @@ async def _send_request(self, method: str, **kwargs) -> Any: ), ) - if not self._is_running: + if not self.is_running: await session.close() response_json = await response.json() diff --git a/tgram/utils.py b/tgram/utils.py index 60f80ab..c18d5a1 100644 --- a/tgram/utils.py +++ b/tgram/utils.py @@ -2,6 +2,7 @@ import tgram import re import html +import asyncio from pathlib import Path from typing import List, Union @@ -315,3 +316,9 @@ def recursive(entity_i: int) -> int: last_offset = offset return remove_surrogates(text) + + +async def compose(bots: List["tgram.TgBot"]): + tasks = [asyncio.create_task(bot.run_for_updates()) for bot in bots] + + return await asyncio.wait(tasks)