From 52b3f09cc83be8de7a4bd34b62a66dc626a62a06 Mon Sep 17 00:00:00 2001 From: Shrimadhav U K Date: Sat, 12 Oct 2024 12:38:32 +0530 Subject: [PATCH] Update Pyrogram to v2.1.33.9 *Experimental Changes (#98) * (1): Experimental Change * (2): Experimental Change * Revert some of the changes by deus-developer. This reverts commit 20f4f3c8 and 983d396 * Add Python 3.13 support Co-authored-by: Artem Ukolov <43943664+deus-developer@users.noreply.github.com> Co-authored-by: GautamKumar Co-authored-by: wulan17 --- .github/workflows/build-docs.yml | 4 +- .github/workflows/publish.yml | 16 +- .github/workflows/python.yml | 6 +- pyproject.toml | 1 + pyrogram/__init__.py | 2 +- pyrogram/client.py | 4 +- pyrogram/connection/connection.py | 65 +++---- pyrogram/connection/transport/tcp/__init__.py | 2 +- pyrogram/connection/transport/tcp/tcp.py | 168 +++++++----------- .../connection/transport/tcp/tcp_abridged.py | 10 +- .../transport/tcp/tcp_abridged_o.py | 12 +- pyrogram/connection/transport/tcp/tcp_full.py | 12 +- .../transport/tcp/tcp_intermediate.py | 10 +- .../transport/tcp/tcp_intermediate_o.py | 12 +- pyrogram/session/auth.py | 38 ++-- pyrogram/session/session.py | 165 +++++++---------- 16 files changed, 219 insertions(+), 308 deletions(-) diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index cbd6c8f8a..98ad59f18 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -9,11 +9,11 @@ jobs: name: build-doc runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 1 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: '3.11' diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 53d1e330f..62728659c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -19,13 +19,16 @@ permissions: jobs: deploy: - runs-on: ubuntu-latest + environment: release + permissions: + id-token: write + steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: '3.10' - name: Install dependencies @@ -38,8 +41,5 @@ jobs: - name: Build package run: hatch build - name: Publish package - env: - HATCH_INDEX_USER: __token__ - HATCH_INDEX_AUTH: ${{ secrets.PYPI_API_TOKEN }} - run: | - hatch publish + uses: pypa/gh-action-pypi-publish@release/v1 + diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 6ad8fe3c8..012d2f1c0 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -13,13 +13,13 @@ jobs: strategy: matrix: os: [ubuntu-22.04, macos-12] - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/pyproject.toml b/pyproject.toml index 098feed1f..55138ca6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", diff --git a/pyrogram/__init__.py b/pyrogram/__init__.py index 5231a113d..f23bb17c6 100644 --- a/pyrogram/__init__.py +++ b/pyrogram/__init__.py @@ -17,7 +17,7 @@ # along with Pyrogram. If not, see . __fork_name__ = "pyrotgfork" -__version__ = "2.1.33.8" +__version__ = "2.1.33.9" __license__ = "GNU Lesser General Public License v3.0 (LGPL-3.0)" __copyright__ = "Copyright (C) 2017-present Dan " diff --git a/pyrogram/client.py b/pyrogram/client.py index 535e99d17..59a671349 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -55,7 +55,7 @@ from pyrogram.types import User, TermsOfService from pyrogram.utils import ainput from .connection import Connection -from .connection.transport import TCP, TCPAbridged +from .connection.transport import TCP, TCPAbridged, TCPFull from .dispatcher import Dispatcher from .file_id import FileId, FileType, ThumbnailSource from .mime_types import mime_types @@ -339,7 +339,7 @@ def __init__( self.storage = FileStorage(self.name, self.WORKDIR) self.connection_factory = Connection - self.protocol_factory = TCPAbridged + self.protocol_factory = TCPFull self.dispatcher = Dispatcher(self) diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 8db9724fb..618c92a51 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -18,64 +18,69 @@ import asyncio import logging -from typing import Optional, Type +from typing import Optional -from .transport import TCP, TCPAbridged +from .transport import * from ..session.internals import DataCenter log = logging.getLogger(__name__) class Connection: - MAX_CONNECTION_ATTEMPTS = 3 + MAX_RETRIES = 3 - def __init__( - self, - dc_id: int, - test_mode: bool, - ipv6: bool, - proxy: dict, - media: bool = False, - protocol_factory: Type[TCP] = TCPAbridged - ) -> None: + MODES = { + 0: TCPFull, + 1: TCPAbridged, + 2: TCPIntermediate, + 3: TCPAbridgedO, + 4: TCPIntermediateO + } + + def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3): self.dc_id = dc_id self.test_mode = test_mode self.ipv6 = ipv6 self.proxy = proxy self.media = media - self.protocol_factory = protocol_factory - self.address = DataCenter(dc_id, test_mode, ipv6, media) - self.protocol: Optional[TCP] = None + self.mode = self.MODES.get(mode, TCPAbridged) + + self.protocol = None # type: TCP - async def connect(self) -> None: - for i in range(Connection.MAX_CONNECTION_ATTEMPTS): - self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy) + async def connect(self): + for i in range(Connection.MAX_RETRIES): + self.protocol = self.mode(self.ipv6, self.proxy) try: log.info("Connecting...") await self.protocol.connect(self.address) except OSError as e: - log.warning("Unable to connect due to network issues: %s", e) - await self.protocol.close() + log.warning(f"Unable to connect due to network issues: {e}") + self.protocol.close() await asyncio.sleep(1) else: - log.info("Connected! %s DC%s%s - IPv%s", - "Test" if self.test_mode else "Production", - self.dc_id, - " (media)" if self.media else "", - "6" if self.ipv6 else "4") + log.info("Connected! {} DC{}{} - IPv{} - {}".format( + "Test" if self.test_mode else "Production", + self.dc_id, + " (media)" if self.media else "", + "6" if self.ipv6 else "4", + self.mode.__name__, + )) break else: log.warning("Connection failed! Trying again...") - raise ConnectionError + raise TimeoutError - async def close(self) -> None: - await self.protocol.close() + def close(self): + self.protocol.close() log.info("Disconnected") - async def send(self, data: bytes) -> None: - await self.protocol.send(data) + async def send(self, data: bytes): + try: + await self.protocol.send(data) + except Exception as e: + raise OSError(e) async def recv(self) -> Optional[bytes]: return await self.protocol.recv() diff --git a/pyrogram/connection/transport/tcp/__init__.py b/pyrogram/connection/transport/tcp/__init__.py index bae35e882..3e23a8837 100644 --- a/pyrogram/connection/transport/tcp/__init__.py +++ b/pyrogram/connection/transport/tcp/__init__.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from .tcp import TCP, Proxy +from .tcp import TCP from .tcp_abridged import TCPAbridged from .tcp_abridged_o import TCPAbridgedO from .tcp_full import TCPFull diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 9994fb822..c0efb625a 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -20,136 +20,92 @@ import ipaddress import logging import socket +import time from concurrent.futures import ThreadPoolExecutor -from typing import Tuple, Dict, TypedDict, Optional -import socks +try: + import socks +except ImportError as e: + e.msg = ( + "PySocks is missing and Pyrogram can't run without. " + "Please install it using \"pip3 install pysocks\"." + ) -log = logging.getLogger(__name__) - -proxy_type_by_scheme: Dict[str, int] = { - "SOCKS4": socks.SOCKS4, - "SOCKS5": socks.SOCKS5, - "HTTP": socks.HTTP, -} + raise e - -class Proxy(TypedDict): - scheme: str - hostname: str - port: int - username: Optional[str] - password: Optional[str] +log = logging.getLogger(__name__) class TCP: TIMEOUT = 10 - def __init__(self, ipv6: bool, proxy: Proxy) -> None: - self.ipv6 = ipv6 - self.proxy = proxy + def __init__(self, ipv6: bool, proxy: dict): + self.socket = None - self.reader: Optional[asyncio.StreamReader] = None - self.writer: Optional[asyncio.StreamWriter] = None + self.reader = None # type: asyncio.StreamReader + self.writer = None # type: asyncio.StreamWriter self.lock = asyncio.Lock() self.loop = asyncio.get_event_loop() - async def _connect_via_proxy( - self, - destination: Tuple[str, int] - ) -> None: - scheme = self.proxy.get("scheme") - if scheme is None: - raise ValueError("No scheme specified") + if proxy: + hostname = proxy.get("hostname") - proxy_type = proxy_type_by_scheme.get(scheme.upper()) - if proxy_type is None: - raise ValueError(f"Unknown proxy type {scheme}") + try: + ip_address = ipaddress.ip_address(hostname) + except ValueError: + self.socket = socks.socksocket(socket.AF_INET) + else: + if isinstance(ip_address, ipaddress.IPv6Address): + self.socket = socks.socksocket(socket.AF_INET6) + else: + self.socket = socks.socksocket(socket.AF_INET) - hostname = self.proxy.get("hostname") - port = self.proxy.get("port") - username = self.proxy.get("username") - password = self.proxy.get("password") + self.socket.set_proxy( + proxy_type=getattr(socks, proxy.get("scheme").upper()), + addr=hostname, + port=proxy.get("port", None), + username=proxy.get("username", None), + password=proxy.get("password", None) + ) - try: - ip_address = ipaddress.ip_address(hostname) - except ValueError: - is_proxy_ipv6 = False + log.info(f"Using proxy {hostname}") else: - is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address) - - proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET - sock = socks.socksocket(proxy_family) - - sock.set_proxy( - proxy_type=proxy_type, - addr=hostname, - port=port, - username=username, - password=password - ) - sock.settimeout(TCP.TIMEOUT) - - await self.loop.sock_connect( - sock=sock, - address=destination - ) - - sock.setblocking(False) - - self.reader, self.writer = await asyncio.open_connection( - sock=sock - ) - - async def _connect_via_direct( - self, - destination: Tuple[str, int] - ) -> None: - host, port = destination - family = socket.AF_INET6 if self.ipv6 else socket.AF_INET - self.reader, self.writer = await asyncio.open_connection( - host=host, - port=port, - family=family - ) - - async def _connect(self, destination: Tuple[str, int]) -> None: - if self.proxy: - await self._connect_via_proxy(destination) - else: - await self._connect_via_direct(destination) + self.socket = socks.socksocket( + socket.AF_INET6 if ipv6 + else socket.AF_INET + ) - async def connect(self, address: Tuple[str, int]) -> None: - try: - await asyncio.wait_for(self._connect(address), TCP.TIMEOUT) - except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 - raise TimeoutError("Connection timed out") + self.socket.settimeout(TCP.TIMEOUT) - async def close(self) -> None: - if self.writer is None: - return None + async def connect(self, address: tuple): + # The socket used by the whole logic is blocking and thus it blocks when connecting. + # Offload the task to a thread executor to avoid blocking the main event loop. + with ThreadPoolExecutor(1) as executor: + await self.loop.run_in_executor(executor, self.socket.connect, address) + self.reader, self.writer = await asyncio.open_connection(sock=self.socket) + + def close(self): try: self.writer.close() - await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) - except Exception as e: - log.info("Close exception: %s %s", type(e).__name__, e) - - async def send(self, data: bytes) -> None: - if self.writer is None: - return None - - async with self.lock: + except AttributeError: try: - self.writer.write(data) - await self.writer.drain() - except Exception as e: - # error coming somewhere here - log.exception("Send exception: %s %s", type(e).__name__, e) - raise OSError(e) - - async def recv(self, length: int = 0) -> Optional[bytes]: + self.socket.shutdown(socket.SHUT_RDWR) + except OSError: + pass + finally: + # A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout. + # This is a workaround that seems to fix the occasional delayed stop of a client. + time.sleep(0.001) + self.socket.close() + + async def send(self, data: bytes): + async with self.lock: + self.writer.write(data) + await self.writer.drain() + + async def recv(self, length: int = 0): data = b"" while len(data) < length: diff --git a/pyrogram/connection/transport/tcp/tcp_abridged.py b/pyrogram/connection/transport/tcp/tcp_abridged.py index 4cb4c1b2a..77d44cf41 100644 --- a/pyrogram/connection/transport/tcp/tcp_abridged.py +++ b/pyrogram/connection/transport/tcp/tcp_abridged.py @@ -17,22 +17,22 @@ # along with Pyrogram. If not, see . import logging -from typing import Optional, Tuple +from typing import Optional -from .tcp import TCP, Proxy +from .tcp import TCP log = logging.getLogger(__name__) class TCPAbridged(TCP): - def __init__(self, ipv6: bool, proxy: Proxy) -> None: + def __init__(self, ipv6: bool, proxy: dict): super().__init__(ipv6, proxy) - async def connect(self, address: Tuple[str, int]) -> None: + async def connect(self, address: tuple): await super().connect(address) await super().send(b"\xef") - async def send(self, data: bytes, *args) -> None: + async def send(self, data: bytes, *args): length = len(data) // 4 await super().send( diff --git a/pyrogram/connection/transport/tcp/tcp_abridged_o.py b/pyrogram/connection/transport/tcp/tcp_abridged_o.py index 20efb5ec3..12a832f6a 100644 --- a/pyrogram/connection/transport/tcp/tcp_abridged_o.py +++ b/pyrogram/connection/transport/tcp/tcp_abridged_o.py @@ -18,11 +18,11 @@ import logging import os -from typing import Optional, Tuple +from typing import Optional import pyrogram from pyrogram.crypto import aes -from .tcp import TCP, Proxy +from .tcp import TCP log = logging.getLogger(__name__) @@ -30,19 +30,19 @@ class TCPAbridgedO(TCP): RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) - def __init__(self, ipv6: bool, proxy: Proxy) -> None: + def __init__(self, ipv6: bool, proxy: dict): super().__init__(ipv6, proxy) self.encrypt = None self.decrypt = None - async def connect(self, address: Tuple[str, int]) -> None: + async def connect(self, address: tuple): await super().connect(address) while True: nonce = bytearray(os.urandom(64)) - if bytes([nonce[0]]) != b"\xef" and nonce[:4] not in self.RESERVED and nonce[4:8] != b"\x00" * 4: + if nonce[0] != b"\xef" and nonce[:4] not in self.RESERVED and nonce[4:4] != b"\x00" * 4: nonce[56] = nonce[57] = nonce[58] = nonce[59] = 0xef break @@ -55,7 +55,7 @@ async def connect(self, address: Tuple[str, int]) -> None: await super().send(nonce) - async def send(self, data: bytes, *args) -> None: + async def send(self, data: bytes, *args): length = len(data) // 4 data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data payload = await self.loop.run_in_executor(pyrogram.crypto_executor, aes.ctr256_encrypt, data, *self.encrypt) diff --git a/pyrogram/connection/transport/tcp/tcp_full.py b/pyrogram/connection/transport/tcp/tcp_full.py index ad9d98171..8bd89000c 100644 --- a/pyrogram/connection/transport/tcp/tcp_full.py +++ b/pyrogram/connection/transport/tcp/tcp_full.py @@ -19,24 +19,24 @@ import logging from binascii import crc32 from struct import pack, unpack -from typing import Optional, Tuple +from typing import Optional -from .tcp import TCP, Proxy +from .tcp import TCP log = logging.getLogger(__name__) class TCPFull(TCP): - def __init__(self, ipv6: bool, proxy: Proxy) -> None: + def __init__(self, ipv6: bool, proxy: dict): super().__init__(ipv6, proxy) - self.seq_no: Optional[int] = None + self.seq_no = None - async def connect(self, address: Tuple[str, int]) -> None: + async def connect(self, address: tuple): await super().connect(address) self.seq_no = 0 - async def send(self, data: bytes, *args) -> None: + async def send(self, data: bytes, *args): data = pack(" None: + def __init__(self, ipv6: bool, proxy: dict): super().__init__(ipv6, proxy) - async def connect(self, address: Tuple[str, int]) -> None: + async def connect(self, address: tuple): await super().connect(address) await super().send(b"\xee" * 4) - async def send(self, data: bytes, *args) -> None: + async def send(self, data: bytes, *args): await super().send(pack(" Optional[bytes]: diff --git a/pyrogram/connection/transport/tcp/tcp_intermediate_o.py b/pyrogram/connection/transport/tcp/tcp_intermediate_o.py index 3f47bdfe0..5b267661d 100644 --- a/pyrogram/connection/transport/tcp/tcp_intermediate_o.py +++ b/pyrogram/connection/transport/tcp/tcp_intermediate_o.py @@ -19,10 +19,10 @@ import logging import os from struct import pack, unpack -from typing import Optional, Tuple +from typing import Optional from pyrogram.crypto import aes -from .tcp import TCP, Proxy +from .tcp import TCP log = logging.getLogger(__name__) @@ -30,19 +30,19 @@ class TCPIntermediateO(TCP): RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) - def __init__(self, ipv6: bool, proxy: Proxy) -> None: + def __init__(self, ipv6: bool, proxy: dict): super().__init__(ipv6, proxy) self.encrypt = None self.decrypt = None - async def connect(self, address: Tuple[str, int]) -> None: + async def connect(self, address: tuple): await super().connect(address) while True: nonce = bytearray(os.urandom(64)) - if bytes([nonce[0]]) != b"\xef" and nonce[:4] not in self.RESERVED and nonce[4:8] != b"\x00" * 4: + if nonce[0] != b"\xef" and nonce[:4] not in self.RESERVED and nonce[4:4] != b"\x00" * 4: nonce[56] = nonce[57] = nonce[58] = nonce[59] = 0xee break @@ -55,7 +55,7 @@ async def connect(self, address: Tuple[str, int]) -> None: await super().send(nonce) - async def send(self, data: bytes, *args) -> None: + async def send(self, data: bytes, *args): await super().send( aes.ctr256_encrypt( pack(" bytes: @@ -84,44 +81,37 @@ async def create(self): # The server may close the connection at any time, causing the auth key creation to fail. # If that happens, just try again up to MAX_RETRIES times. while True: - self.connection = self.connection_factory( - dc_id=self.dc_id, - test_mode=self.test_mode, - ipv6=self.ipv6, - proxy=self.proxy, - media=False, - protocol_factory=self.protocol_factory - ) + self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy) try: - log.info("Start creating a new auth key on DC%s", self.dc_id) + log.info(f"Start creating a new auth key on DC{self.dc_id}") await self.connection.connect() # Step 1; Step 2 nonce = int.from_bytes(urandom(16), "little", signed=True) - log.debug("Send req_pq: %s", nonce) + log.debug(f"Send req_pq: {nonce}") res_pq = await self.invoke(raw.functions.ReqPqMulti(nonce=nonce)) - log.debug("Got ResPq: %s", res_pq.server_nonce) - log.debug("Server public key fingerprints: %s", res_pq.server_public_key_fingerprints) + log.debug(f"Got ResPq: {res_pq.server_nonce}") + log.debug(f"Server public key fingerprints: {res_pq.server_public_key_fingerprints}") for i in res_pq.server_public_key_fingerprints: if i in rsa.server_public_keys: - log.debug("Using fingerprint: %s", i) + log.debug(f"Using fingerprint: {i}") public_key_fingerprint = i break else: - log.debug("Fingerprint unknown: %s", i) + log.debug(f"Fingerprint unknown: {i}") else: raise Exception("Public key not found") # Step 3 pq = int.from_bytes(res_pq.pq, "big") - log.debug("Start PQ factorization: %s", pq) + log.debug(f"Start PQ factorization: {pq}") start = time.time() g = prime.decompose(pq) p, q = sorted((g, pq // g)) # p < q - log.debug("Done PQ factorization (%ss): %s %s", round(time.time() - start, 3), p, q) + log.debug(f"Done PQ factorization ({round(time.time() - start, 3)}s): {p} {q}") # Step 4 server_nonce = res_pq.server_nonce @@ -183,7 +173,7 @@ async def create(self): dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big") delta_time = server_dh_inner_data.server_time - time.time() - log.debug("Delta time: %s", round(delta_time, 3)) + log.debug(f"Delta time: {round(delta_time, 3)}") # Step 6 g = server_dh_inner_data.g @@ -277,9 +267,9 @@ async def create(self): # Step 9 server_salt = aes.xor(new_nonce[:8], server_nonce[:8]) - log.debug("Server salt: %s", int.from_bytes(server_salt, "little")) + log.debug(f"Server salt: {int.from_bytes(server_salt, 'little')}") - log.info("Done auth key exchange: %s", set_client_dh_params_answer.__class__.__name__) + log.info(f"Done auth key exchange: {set_client_dh_params_answer.__class__.__name__}") except Exception as e: log.info("Retrying due to %s: %s", type(e).__name__, e) @@ -293,4 +283,4 @@ async def create(self): else: return auth_key finally: - await self.connection.close() + self.connection.close() diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 491d81078..120fca439 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -22,7 +22,6 @@ import os from hashlib import sha1 from io import BytesIO -from typing import Optional import pyrogram from pyrogram import raw @@ -33,7 +32,6 @@ FloodWait, FloodPremiumWait, ServiceUnavailable, BadMsgNotification, SecurityCheckMismatch, - Unauthorized ) from pyrogram.raw.all import layer from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts @@ -49,11 +47,11 @@ def __init__(self): class Session: - START_TIMEOUT = 5 + START_TIMEOUT = 1 WAIT_TIMEOUT = 15 SLEEP_THRESHOLD = 10 - MAX_RETRIES = 10 - ACKS_THRESHOLD = 10 + MAX_RETRIES = 5 + ACKS_THRESHOLD = 8 PING_INTERVAL = 5 STORED_MSG_IDS_MAX_SIZE = 1000 * 2 @@ -85,7 +83,7 @@ def __init__( self.is_media = is_media self.is_cdn = is_cdn - self.connection: Optional[Connection] = None + self.connection = None self.auth_key_id = sha1(auth_key).digest()[-8:] @@ -103,28 +101,26 @@ def __init__( self.ping_task = None self.ping_task_event = asyncio.Event() - self.recv_task = None + self.network_task = None - self.is_started = asyncio.Event() - self.restart_event = asyncio.Event() + self.is_connected = asyncio.Event() self.loop = asyncio.get_event_loop() async def start(self): while True: - self.connection = self.client.connection_factory( - dc_id=self.dc_id, - test_mode=self.test_mode, - ipv6=self.client.ipv6, - proxy=self.client.proxy, - media=self.is_media, - protocol_factory=self.client.protocol_factory + self.connection = Connection( + self.dc_id, + self.test_mode, + self.client.ipv6, + self.client.proxy, + self.is_media ) try: await self.connection.connect() - self.recv_task = self.loop.create_task(self.recv_worker()) + self.network_task = self.loop.create_task(self.network_worker()) await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT) @@ -153,13 +149,14 @@ async def start(self): self.ping_task = self.loop.create_task(self.ping_worker()) - log.info("Session initialized: Layer %s", layer) - log.info("Device: %s - %s", self.client.device_model, self.client.app_version) - log.info("System: %s (%s)", self.client.system_version, self.client.lang_code) + log.info(f"Session initialized: Layer {layer}") + log.info(f"Device: {self.client.device_model} - {self.client.app_version}") + log.info(f"System: {self.client.system_version} ({self.client.lang_code.upper()})") + except AuthKeyDuplicated as e: await self.stop() raise e - except (OSError, RPCError): + except (OSError, TimeoutError, RPCError): await self.stop() except Exception as e: await self.stop() @@ -167,14 +164,12 @@ async def start(self): else: break - self.is_started.set() + self.is_connected.set() log.info("Session started") async def stop(self): - self.is_started.clear() - - self.stored_msg_ids.clear() + self.is_connected.clear() self.ping_task_event.set() @@ -183,24 +178,25 @@ async def stop(self): self.ping_task_event.clear() - await self.connection.close() + self.connection.close() - if self.recv_task: - await self.recv_task + if self.network_task: + await self.network_task + + for i in self.results.values(): + i.event.set() if not self.is_media and callable(self.client.disconnect_handler): try: await self.client.disconnect_handler(self.client) except Exception as e: - log.exception(e) + log.error(e, exc_info=True) log.info("Session stopped") async def restart(self): - self.restart_event.set() await self.stop() await self.start() - self.restart_event.clear() async def handle_packet(self, packet): try: @@ -210,11 +206,10 @@ async def handle_packet(self, packet): BytesIO(packet), self.session_id, self.auth_key, - self.auth_key_id + self.auth_key_id, + # self.stored_msg_ids ) - except ValueError as e: - log.debug(e) - self.loop.create_task(self.restart()) + except SecurityCheckMismatch: return messages = ( @@ -223,9 +218,15 @@ async def handle_packet(self, packet): else [data] ) - log.debug("Received: %s", data) + # Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}") + # will cause "data" to be evaluated as string every time instead of only when debug is actually enabled. + log.debug("Received:") + log.debug(data) for msg in messages: + # if msg.seq_no == 0: + # MsgId.set_server_time(msg.msg_id / (2 ** 32)) + if msg.seq_no % 2 != 0: if msg.msg_id in self.pending_acks: continue @@ -283,11 +284,11 @@ async def handle_packet(self, packet): self.results[msg_id].event.set() if len(self.pending_acks) >= self.ACKS_THRESHOLD: - log.debug("Sending %s acks", len(self.pending_acks)) + log.debug(f"Send {len(self.pending_acks)} acks") try: await self.send(raw.types.MsgsAck(msg_ids=list(self.pending_acks)), False) - except OSError: + except (OSError, TimeoutError): pass else: self.pending_acks.clear() @@ -309,15 +310,12 @@ async def ping_worker(self): ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10 ), False ) - except OSError: - self.loop.create_task(self.restart()) - break - except RPCError: + except (OSError, TimeoutError, RPCError): pass log.info("PingTask stopped") - async def recv_worker(self): + async def network_worker(self): log.info("NetworkTask started") while True: @@ -325,19 +323,9 @@ async def recv_worker(self): if packet is None or len(packet) == 4: if packet: - error_code = -Int.read(BytesIO(packet)) - - # if error_code == 404: - # raise Unauthorized( - # "Auth key not found in the system. You must delete your session file " - # "and log in again with your phone number or bot token." - # ) - log.warning( - "Server sent transport error: %s (%s)", - error_code, Session.TRANSPORT_ERRORS.get(error_code, "unknown error") - ) + log.warning(f'Server sent "{Int.read(BytesIO(packet))}"') - if self.is_started.is_set(): + if self.is_connected.is_set(): self.loop.create_task(self.restart()) break @@ -358,7 +346,10 @@ async def send( if wait_response: self.results[msg_id] = Result() - log.debug("Sent: %s", message) + # Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}") + # will cause "data" to be evaluated as string every time instead of only when debug is actually enabled. + log.debug(f"Sent:") + log.debug(message) payload = await self.loop.run_in_executor( pyrogram.crypto_executor, @@ -381,26 +372,23 @@ async def send( await asyncio.wait_for(self.results[msg_id].event.wait(), timeout) except asyncio.TimeoutError: pass - - result = self.results.pop(msg_id).value + finally: + result = self.results.pop(msg_id).value if result is None: - raise TimeoutError("Request timed out") - - if isinstance(result, raw.types.RpcError): + raise TimeoutError + elif isinstance(result, raw.types.RpcError): if isinstance(data, Session.CUR_ALWD_INNR_QRYS): data = data.query RPCError.raise_it(result, type(data)) - - if isinstance(result, raw.types.BadMsgNotification): - log.warning("%s: %s", BadMsgNotification.__name__, BadMsgNotification(result.error_code)) - - if isinstance(result, raw.types.BadServerSalt): + elif isinstance(result, raw.types.BadMsgNotification): + raise BadMsgNotification(result.error_code) + elif isinstance(result, raw.types.BadServerSalt): self.salt = result.new_server_salt return await self.send(data, wait_response, timeout) - - return result + else: + return result async def invoke( self, @@ -410,7 +398,7 @@ async def invoke( sleep_threshold: float = SLEEP_THRESHOLD ): try: - await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT) + await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT) except asyncio.TimeoutError: pass @@ -430,46 +418,17 @@ async def invoke( if amount > sleep_threshold >= 0: raise - log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")', - self.client.name, amount, query_name) + log.warning(f'[{self.client.name}] Waiting for {amount} seconds before continuing ' + f'(required by "{query_name}")') await asyncio.sleep(amount) - except ( - OSError, - RuntimeError, - InternalServerError, - ServiceUnavailable, - TimeoutError, - ) as e: - retries -= 1 - if ( - retries == 0 or - ( - isinstance(e, InternalServerError) - and getattr(e, "code", 0) == 500 - and (e.ID or e.NAME) in [ - "HISTORY_GET_FAILED" - ] - ) - ): + except (OSError, TimeoutError, InternalServerError, ServiceUnavailable) as e: + if retries == 0: raise e from None (log.warning if retries < 2 else log.info)( - '[%s] Retrying "%s" due to: %s', - Session.MAX_RETRIES - retries + 1, - query_name, str(e) or repr(e) - ) + f'[{Session.MAX_RETRIES - retries + 1}] Retrying "{query_name}" due to {str(e) or repr(e)}') - # restart was never being called after Exception block - if not self.restart_event.is_set(): - self.loop.create_task(self.restart()) - else: - # multiple Exceptions can be raised in a row, so we need to wait for the restart to finish - try: - await asyncio.wait_for(self.restart_event.wait(), self.WAIT_TIMEOUT) - except asyncio.TimeoutError: - pass - await asyncio.sleep(0.5) return await self.invoke(query, retries - 1, timeout)