diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml
index cbd6c8f8a..98ad59f18 100644
--- a/.github/workflows/build-docs.yml
+++ b/.github/workflows/build-docs.yml
@@ -9,11 +9,11 @@ jobs:
name: build-doc
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Set up Python
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v5
with:
python-version: '3.11'
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 53d1e330f..62728659c 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -19,13 +19,16 @@ permissions:
jobs:
deploy:
-
runs-on: ubuntu-latest
+ environment: release
+ permissions:
+ id-token: write
+
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- name: Set up Python
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install dependencies
@@ -38,8 +41,5 @@ jobs:
- name: Build package
run: hatch build
- name: Publish package
- env:
- HATCH_INDEX_USER: __token__
- HATCH_INDEX_AUTH: ${{ secrets.PYPI_API_TOKEN }}
- run: |
- hatch publish
+ uses: pypa/gh-action-pypi-publish@release/v1
+
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 6ad8fe3c8..012d2f1c0 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -13,13 +13,13 @@ jobs:
strategy:
matrix:
os: [ubuntu-22.04, macos-12]
- python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
+ python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
diff --git a/pyproject.toml b/pyproject.toml
index 098feed1f..55138ca6d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
diff --git a/pyrogram/__init__.py b/pyrogram/__init__.py
index 5231a113d..f23bb17c6 100644
--- a/pyrogram/__init__.py
+++ b/pyrogram/__init__.py
@@ -17,7 +17,7 @@
# along with Pyrogram. If not, see .
__fork_name__ = "pyrotgfork"
-__version__ = "2.1.33.8"
+__version__ = "2.1.33.9"
__license__ = "GNU Lesser General Public License v3.0 (LGPL-3.0)"
__copyright__ = "Copyright (C) 2017-present Dan "
diff --git a/pyrogram/client.py b/pyrogram/client.py
index 535e99d17..59a671349 100644
--- a/pyrogram/client.py
+++ b/pyrogram/client.py
@@ -55,7 +55,7 @@
from pyrogram.types import User, TermsOfService
from pyrogram.utils import ainput
from .connection import Connection
-from .connection.transport import TCP, TCPAbridged
+from .connection.transport import TCP, TCPAbridged, TCPFull
from .dispatcher import Dispatcher
from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types
@@ -339,7 +339,7 @@ def __init__(
self.storage = FileStorage(self.name, self.WORKDIR)
self.connection_factory = Connection
- self.protocol_factory = TCPAbridged
+ self.protocol_factory = TCPFull
self.dispatcher = Dispatcher(self)
diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py
index 8db9724fb..618c92a51 100644
--- a/pyrogram/connection/connection.py
+++ b/pyrogram/connection/connection.py
@@ -18,64 +18,69 @@
import asyncio
import logging
-from typing import Optional, Type
+from typing import Optional
-from .transport import TCP, TCPAbridged
+from .transport import *
from ..session.internals import DataCenter
log = logging.getLogger(__name__)
class Connection:
- MAX_CONNECTION_ATTEMPTS = 3
+ MAX_RETRIES = 3
- def __init__(
- self,
- dc_id: int,
- test_mode: bool,
- ipv6: bool,
- proxy: dict,
- media: bool = False,
- protocol_factory: Type[TCP] = TCPAbridged
- ) -> None:
+ MODES = {
+ 0: TCPFull,
+ 1: TCPAbridged,
+ 2: TCPIntermediate,
+ 3: TCPAbridgedO,
+ 4: TCPIntermediateO
+ }
+
+ def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3):
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = ipv6
self.proxy = proxy
self.media = media
- self.protocol_factory = protocol_factory
-
self.address = DataCenter(dc_id, test_mode, ipv6, media)
- self.protocol: Optional[TCP] = None
+ self.mode = self.MODES.get(mode, TCPAbridged)
+
+ self.protocol = None # type: TCP
- async def connect(self) -> None:
- for i in range(Connection.MAX_CONNECTION_ATTEMPTS):
- self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
+ async def connect(self):
+ for i in range(Connection.MAX_RETRIES):
+ self.protocol = self.mode(self.ipv6, self.proxy)
try:
log.info("Connecting...")
await self.protocol.connect(self.address)
except OSError as e:
- log.warning("Unable to connect due to network issues: %s", e)
- await self.protocol.close()
+ log.warning(f"Unable to connect due to network issues: {e}")
+ self.protocol.close()
await asyncio.sleep(1)
else:
- log.info("Connected! %s DC%s%s - IPv%s",
- "Test" if self.test_mode else "Production",
- self.dc_id,
- " (media)" if self.media else "",
- "6" if self.ipv6 else "4")
+ log.info("Connected! {} DC{}{} - IPv{} - {}".format(
+ "Test" if self.test_mode else "Production",
+ self.dc_id,
+ " (media)" if self.media else "",
+ "6" if self.ipv6 else "4",
+ self.mode.__name__,
+ ))
break
else:
log.warning("Connection failed! Trying again...")
- raise ConnectionError
+ raise TimeoutError
- async def close(self) -> None:
- await self.protocol.close()
+ def close(self):
+ self.protocol.close()
log.info("Disconnected")
- async def send(self, data: bytes) -> None:
- await self.protocol.send(data)
+ async def send(self, data: bytes):
+ try:
+ await self.protocol.send(data)
+ except Exception as e:
+ raise OSError(e)
async def recv(self) -> Optional[bytes]:
return await self.protocol.recv()
diff --git a/pyrogram/connection/transport/tcp/__init__.py b/pyrogram/connection/transport/tcp/__init__.py
index bae35e882..3e23a8837 100644
--- a/pyrogram/connection/transport/tcp/__init__.py
+++ b/pyrogram/connection/transport/tcp/__init__.py
@@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
-from .tcp import TCP, Proxy
+from .tcp import TCP
from .tcp_abridged import TCPAbridged
from .tcp_abridged_o import TCPAbridgedO
from .tcp_full import TCPFull
diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py
index 9994fb822..c0efb625a 100644
--- a/pyrogram/connection/transport/tcp/tcp.py
+++ b/pyrogram/connection/transport/tcp/tcp.py
@@ -20,136 +20,92 @@
import ipaddress
import logging
import socket
+import time
from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple, Dict, TypedDict, Optional
-import socks
+try:
+ import socks
+except ImportError as e:
+ e.msg = (
+ "PySocks is missing and Pyrogram can't run without. "
+ "Please install it using \"pip3 install pysocks\"."
+ )
-log = logging.getLogger(__name__)
-
-proxy_type_by_scheme: Dict[str, int] = {
- "SOCKS4": socks.SOCKS4,
- "SOCKS5": socks.SOCKS5,
- "HTTP": socks.HTTP,
-}
+ raise e
-
-class Proxy(TypedDict):
- scheme: str
- hostname: str
- port: int
- username: Optional[str]
- password: Optional[str]
+log = logging.getLogger(__name__)
class TCP:
TIMEOUT = 10
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
- self.ipv6 = ipv6
- self.proxy = proxy
+ def __init__(self, ipv6: bool, proxy: dict):
+ self.socket = None
- self.reader: Optional[asyncio.StreamReader] = None
- self.writer: Optional[asyncio.StreamWriter] = None
+ self.reader = None # type: asyncio.StreamReader
+ self.writer = None # type: asyncio.StreamWriter
self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop()
- async def _connect_via_proxy(
- self,
- destination: Tuple[str, int]
- ) -> None:
- scheme = self.proxy.get("scheme")
- if scheme is None:
- raise ValueError("No scheme specified")
+ if proxy:
+ hostname = proxy.get("hostname")
- proxy_type = proxy_type_by_scheme.get(scheme.upper())
- if proxy_type is None:
- raise ValueError(f"Unknown proxy type {scheme}")
+ try:
+ ip_address = ipaddress.ip_address(hostname)
+ except ValueError:
+ self.socket = socks.socksocket(socket.AF_INET)
+ else:
+ if isinstance(ip_address, ipaddress.IPv6Address):
+ self.socket = socks.socksocket(socket.AF_INET6)
+ else:
+ self.socket = socks.socksocket(socket.AF_INET)
- hostname = self.proxy.get("hostname")
- port = self.proxy.get("port")
- username = self.proxy.get("username")
- password = self.proxy.get("password")
+ self.socket.set_proxy(
+ proxy_type=getattr(socks, proxy.get("scheme").upper()),
+ addr=hostname,
+ port=proxy.get("port", None),
+ username=proxy.get("username", None),
+ password=proxy.get("password", None)
+ )
- try:
- ip_address = ipaddress.ip_address(hostname)
- except ValueError:
- is_proxy_ipv6 = False
+ log.info(f"Using proxy {hostname}")
else:
- is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address)
-
- proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET
- sock = socks.socksocket(proxy_family)
-
- sock.set_proxy(
- proxy_type=proxy_type,
- addr=hostname,
- port=port,
- username=username,
- password=password
- )
- sock.settimeout(TCP.TIMEOUT)
-
- await self.loop.sock_connect(
- sock=sock,
- address=destination
- )
-
- sock.setblocking(False)
-
- self.reader, self.writer = await asyncio.open_connection(
- sock=sock
- )
-
- async def _connect_via_direct(
- self,
- destination: Tuple[str, int]
- ) -> None:
- host, port = destination
- family = socket.AF_INET6 if self.ipv6 else socket.AF_INET
- self.reader, self.writer = await asyncio.open_connection(
- host=host,
- port=port,
- family=family
- )
-
- async def _connect(self, destination: Tuple[str, int]) -> None:
- if self.proxy:
- await self._connect_via_proxy(destination)
- else:
- await self._connect_via_direct(destination)
+ self.socket = socks.socksocket(
+ socket.AF_INET6 if ipv6
+ else socket.AF_INET
+ )
- async def connect(self, address: Tuple[str, int]) -> None:
- try:
- await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
- except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
- raise TimeoutError("Connection timed out")
+ self.socket.settimeout(TCP.TIMEOUT)
- async def close(self) -> None:
- if self.writer is None:
- return None
+ async def connect(self, address: tuple):
+ # The socket used by the whole logic is blocking and thus it blocks when connecting.
+ # Offload the task to a thread executor to avoid blocking the main event loop.
+ with ThreadPoolExecutor(1) as executor:
+ await self.loop.run_in_executor(executor, self.socket.connect, address)
+ self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
+
+ def close(self):
try:
self.writer.close()
- await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
- except Exception as e:
- log.info("Close exception: %s %s", type(e).__name__, e)
-
- async def send(self, data: bytes) -> None:
- if self.writer is None:
- return None
-
- async with self.lock:
+ except AttributeError:
try:
- self.writer.write(data)
- await self.writer.drain()
- except Exception as e:
- # error coming somewhere here
- log.exception("Send exception: %s %s", type(e).__name__, e)
- raise OSError(e)
-
- async def recv(self, length: int = 0) -> Optional[bytes]:
+ self.socket.shutdown(socket.SHUT_RDWR)
+ except OSError:
+ pass
+ finally:
+ # A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout.
+ # This is a workaround that seems to fix the occasional delayed stop of a client.
+ time.sleep(0.001)
+ self.socket.close()
+
+ async def send(self, data: bytes):
+ async with self.lock:
+ self.writer.write(data)
+ await self.writer.drain()
+
+ async def recv(self, length: int = 0):
data = b""
while len(data) < length:
diff --git a/pyrogram/connection/transport/tcp/tcp_abridged.py b/pyrogram/connection/transport/tcp/tcp_abridged.py
index 4cb4c1b2a..77d44cf41 100644
--- a/pyrogram/connection/transport/tcp/tcp_abridged.py
+++ b/pyrogram/connection/transport/tcp/tcp_abridged.py
@@ -17,22 +17,22 @@
# along with Pyrogram. If not, see .
import logging
-from typing import Optional, Tuple
+from typing import Optional
-from .tcp import TCP, Proxy
+from .tcp import TCP
log = logging.getLogger(__name__)
class TCPAbridged(TCP):
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
+ def __init__(self, ipv6: bool, proxy: dict):
super().__init__(ipv6, proxy)
- async def connect(self, address: Tuple[str, int]) -> None:
+ async def connect(self, address: tuple):
await super().connect(address)
await super().send(b"\xef")
- async def send(self, data: bytes, *args) -> None:
+ async def send(self, data: bytes, *args):
length = len(data) // 4
await super().send(
diff --git a/pyrogram/connection/transport/tcp/tcp_abridged_o.py b/pyrogram/connection/transport/tcp/tcp_abridged_o.py
index 20efb5ec3..12a832f6a 100644
--- a/pyrogram/connection/transport/tcp/tcp_abridged_o.py
+++ b/pyrogram/connection/transport/tcp/tcp_abridged_o.py
@@ -18,11 +18,11 @@
import logging
import os
-from typing import Optional, Tuple
+from typing import Optional
import pyrogram
from pyrogram.crypto import aes
-from .tcp import TCP, Proxy
+from .tcp import TCP
log = logging.getLogger(__name__)
@@ -30,19 +30,19 @@
class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
+ def __init__(self, ipv6: bool, proxy: dict):
super().__init__(ipv6, proxy)
self.encrypt = None
self.decrypt = None
- async def connect(self, address: Tuple[str, int]) -> None:
+ async def connect(self, address: tuple):
await super().connect(address)
while True:
nonce = bytearray(os.urandom(64))
- if bytes([nonce[0]]) != b"\xef" and nonce[:4] not in self.RESERVED and nonce[4:8] != b"\x00" * 4:
+ if nonce[0] != b"\xef" and nonce[:4] not in self.RESERVED and nonce[4:4] != b"\x00" * 4:
nonce[56] = nonce[57] = nonce[58] = nonce[59] = 0xef
break
@@ -55,7 +55,7 @@ async def connect(self, address: Tuple[str, int]) -> None:
await super().send(nonce)
- async def send(self, data: bytes, *args) -> None:
+ async def send(self, data: bytes, *args):
length = len(data) // 4
data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data
payload = await self.loop.run_in_executor(pyrogram.crypto_executor, aes.ctr256_encrypt, data, *self.encrypt)
diff --git a/pyrogram/connection/transport/tcp/tcp_full.py b/pyrogram/connection/transport/tcp/tcp_full.py
index ad9d98171..8bd89000c 100644
--- a/pyrogram/connection/transport/tcp/tcp_full.py
+++ b/pyrogram/connection/transport/tcp/tcp_full.py
@@ -19,24 +19,24 @@
import logging
from binascii import crc32
from struct import pack, unpack
-from typing import Optional, Tuple
+from typing import Optional
-from .tcp import TCP, Proxy
+from .tcp import TCP
log = logging.getLogger(__name__)
class TCPFull(TCP):
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
+ def __init__(self, ipv6: bool, proxy: dict):
super().__init__(ipv6, proxy)
- self.seq_no: Optional[int] = None
+ self.seq_no = None
- async def connect(self, address: Tuple[str, int]) -> None:
+ async def connect(self, address: tuple):
await super().connect(address)
self.seq_no = 0
- async def send(self, data: bytes, *args) -> None:
+ async def send(self, data: bytes, *args):
data = pack(" None:
+ def __init__(self, ipv6: bool, proxy: dict):
super().__init__(ipv6, proxy)
- async def connect(self, address: Tuple[str, int]) -> None:
+ async def connect(self, address: tuple):
await super().connect(address)
await super().send(b"\xee" * 4)
- async def send(self, data: bytes, *args) -> None:
+ async def send(self, data: bytes, *args):
await super().send(pack(" Optional[bytes]:
diff --git a/pyrogram/connection/transport/tcp/tcp_intermediate_o.py b/pyrogram/connection/transport/tcp/tcp_intermediate_o.py
index 3f47bdfe0..5b267661d 100644
--- a/pyrogram/connection/transport/tcp/tcp_intermediate_o.py
+++ b/pyrogram/connection/transport/tcp/tcp_intermediate_o.py
@@ -19,10 +19,10 @@
import logging
import os
from struct import pack, unpack
-from typing import Optional, Tuple
+from typing import Optional
from pyrogram.crypto import aes
-from .tcp import TCP, Proxy
+from .tcp import TCP
log = logging.getLogger(__name__)
@@ -30,19 +30,19 @@
class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
+ def __init__(self, ipv6: bool, proxy: dict):
super().__init__(ipv6, proxy)
self.encrypt = None
self.decrypt = None
- async def connect(self, address: Tuple[str, int]) -> None:
+ async def connect(self, address: tuple):
await super().connect(address)
while True:
nonce = bytearray(os.urandom(64))
- if bytes([nonce[0]]) != b"\xef" and nonce[:4] not in self.RESERVED and nonce[4:8] != b"\x00" * 4:
+ if nonce[0] != b"\xef" and nonce[:4] not in self.RESERVED and nonce[4:4] != b"\x00" * 4:
nonce[56] = nonce[57] = nonce[58] = nonce[59] = 0xee
break
@@ -55,7 +55,7 @@ async def connect(self, address: Tuple[str, int]) -> None:
await super().send(nonce)
- async def send(self, data: bytes, *args) -> None:
+ async def send(self, data: bytes, *args):
await super().send(
aes.ctr256_encrypt(
pack(" bytes:
@@ -84,44 +81,37 @@ async def create(self):
# The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times.
while True:
- self.connection = self.connection_factory(
- dc_id=self.dc_id,
- test_mode=self.test_mode,
- ipv6=self.ipv6,
- proxy=self.proxy,
- media=False,
- protocol_factory=self.protocol_factory
- )
+ self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy)
try:
- log.info("Start creating a new auth key on DC%s", self.dc_id)
+ log.info(f"Start creating a new auth key on DC{self.dc_id}")
await self.connection.connect()
# Step 1; Step 2
nonce = int.from_bytes(urandom(16), "little", signed=True)
- log.debug("Send req_pq: %s", nonce)
+ log.debug(f"Send req_pq: {nonce}")
res_pq = await self.invoke(raw.functions.ReqPqMulti(nonce=nonce))
- log.debug("Got ResPq: %s", res_pq.server_nonce)
- log.debug("Server public key fingerprints: %s", res_pq.server_public_key_fingerprints)
+ log.debug(f"Got ResPq: {res_pq.server_nonce}")
+ log.debug(f"Server public key fingerprints: {res_pq.server_public_key_fingerprints}")
for i in res_pq.server_public_key_fingerprints:
if i in rsa.server_public_keys:
- log.debug("Using fingerprint: %s", i)
+ log.debug(f"Using fingerprint: {i}")
public_key_fingerprint = i
break
else:
- log.debug("Fingerprint unknown: %s", i)
+ log.debug(f"Fingerprint unknown: {i}")
else:
raise Exception("Public key not found")
# Step 3
pq = int.from_bytes(res_pq.pq, "big")
- log.debug("Start PQ factorization: %s", pq)
+ log.debug(f"Start PQ factorization: {pq}")
start = time.time()
g = prime.decompose(pq)
p, q = sorted((g, pq // g)) # p < q
- log.debug("Done PQ factorization (%ss): %s %s", round(time.time() - start, 3), p, q)
+ log.debug(f"Done PQ factorization ({round(time.time() - start, 3)}s): {p} {q}")
# Step 4
server_nonce = res_pq.server_nonce
@@ -183,7 +173,7 @@ async def create(self):
dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big")
delta_time = server_dh_inner_data.server_time - time.time()
- log.debug("Delta time: %s", round(delta_time, 3))
+ log.debug(f"Delta time: {round(delta_time, 3)}")
# Step 6
g = server_dh_inner_data.g
@@ -277,9 +267,9 @@ async def create(self):
# Step 9
server_salt = aes.xor(new_nonce[:8], server_nonce[:8])
- log.debug("Server salt: %s", int.from_bytes(server_salt, "little"))
+ log.debug(f"Server salt: {int.from_bytes(server_salt, 'little')}")
- log.info("Done auth key exchange: %s", set_client_dh_params_answer.__class__.__name__)
+ log.info(f"Done auth key exchange: {set_client_dh_params_answer.__class__.__name__}")
except Exception as e:
log.info("Retrying due to %s: %s", type(e).__name__, e)
@@ -293,4 +283,4 @@ async def create(self):
else:
return auth_key
finally:
- await self.connection.close()
+ self.connection.close()
diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py
index 491d81078..120fca439 100644
--- a/pyrogram/session/session.py
+++ b/pyrogram/session/session.py
@@ -22,7 +22,6 @@
import os
from hashlib import sha1
from io import BytesIO
-from typing import Optional
import pyrogram
from pyrogram import raw
@@ -33,7 +32,6 @@
FloodWait, FloodPremiumWait,
ServiceUnavailable, BadMsgNotification,
SecurityCheckMismatch,
- Unauthorized
)
from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
@@ -49,11 +47,11 @@ def __init__(self):
class Session:
- START_TIMEOUT = 5
+ START_TIMEOUT = 1
WAIT_TIMEOUT = 15
SLEEP_THRESHOLD = 10
- MAX_RETRIES = 10
- ACKS_THRESHOLD = 10
+ MAX_RETRIES = 5
+ ACKS_THRESHOLD = 8
PING_INTERVAL = 5
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
@@ -85,7 +83,7 @@ def __init__(
self.is_media = is_media
self.is_cdn = is_cdn
- self.connection: Optional[Connection] = None
+ self.connection = None
self.auth_key_id = sha1(auth_key).digest()[-8:]
@@ -103,28 +101,26 @@ def __init__(
self.ping_task = None
self.ping_task_event = asyncio.Event()
- self.recv_task = None
+ self.network_task = None
- self.is_started = asyncio.Event()
- self.restart_event = asyncio.Event()
+ self.is_connected = asyncio.Event()
self.loop = asyncio.get_event_loop()
async def start(self):
while True:
- self.connection = self.client.connection_factory(
- dc_id=self.dc_id,
- test_mode=self.test_mode,
- ipv6=self.client.ipv6,
- proxy=self.client.proxy,
- media=self.is_media,
- protocol_factory=self.client.protocol_factory
+ self.connection = Connection(
+ self.dc_id,
+ self.test_mode,
+ self.client.ipv6,
+ self.client.proxy,
+ self.is_media
)
try:
await self.connection.connect()
- self.recv_task = self.loop.create_task(self.recv_worker())
+ self.network_task = self.loop.create_task(self.network_worker())
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
@@ -153,13 +149,14 @@ async def start(self):
self.ping_task = self.loop.create_task(self.ping_worker())
- log.info("Session initialized: Layer %s", layer)
- log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
- log.info("System: %s (%s)", self.client.system_version, self.client.lang_code)
+ log.info(f"Session initialized: Layer {layer}")
+ log.info(f"Device: {self.client.device_model} - {self.client.app_version}")
+ log.info(f"System: {self.client.system_version} ({self.client.lang_code.upper()})")
+
except AuthKeyDuplicated as e:
await self.stop()
raise e
- except (OSError, RPCError):
+ except (OSError, TimeoutError, RPCError):
await self.stop()
except Exception as e:
await self.stop()
@@ -167,14 +164,12 @@ async def start(self):
else:
break
- self.is_started.set()
+ self.is_connected.set()
log.info("Session started")
async def stop(self):
- self.is_started.clear()
-
- self.stored_msg_ids.clear()
+ self.is_connected.clear()
self.ping_task_event.set()
@@ -183,24 +178,25 @@ async def stop(self):
self.ping_task_event.clear()
- await self.connection.close()
+ self.connection.close()
- if self.recv_task:
- await self.recv_task
+ if self.network_task:
+ await self.network_task
+
+ for i in self.results.values():
+ i.event.set()
if not self.is_media and callable(self.client.disconnect_handler):
try:
await self.client.disconnect_handler(self.client)
except Exception as e:
- log.exception(e)
+ log.error(e, exc_info=True)
log.info("Session stopped")
async def restart(self):
- self.restart_event.set()
await self.stop()
await self.start()
- self.restart_event.clear()
async def handle_packet(self, packet):
try:
@@ -210,11 +206,10 @@ async def handle_packet(self, packet):
BytesIO(packet),
self.session_id,
self.auth_key,
- self.auth_key_id
+ self.auth_key_id,
+ # self.stored_msg_ids
)
- except ValueError as e:
- log.debug(e)
- self.loop.create_task(self.restart())
+ except SecurityCheckMismatch:
return
messages = (
@@ -223,9 +218,15 @@ async def handle_packet(self, packet):
else [data]
)
- log.debug("Received: %s", data)
+ # Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
+ # will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
+ log.debug("Received:")
+ log.debug(data)
for msg in messages:
+ # if msg.seq_no == 0:
+ # MsgId.set_server_time(msg.msg_id / (2 ** 32))
+
if msg.seq_no % 2 != 0:
if msg.msg_id in self.pending_acks:
continue
@@ -283,11 +284,11 @@ async def handle_packet(self, packet):
self.results[msg_id].event.set()
if len(self.pending_acks) >= self.ACKS_THRESHOLD:
- log.debug("Sending %s acks", len(self.pending_acks))
+ log.debug(f"Send {len(self.pending_acks)} acks")
try:
await self.send(raw.types.MsgsAck(msg_ids=list(self.pending_acks)), False)
- except OSError:
+ except (OSError, TimeoutError):
pass
else:
self.pending_acks.clear()
@@ -309,15 +310,12 @@ async def ping_worker(self):
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
), False
)
- except OSError:
- self.loop.create_task(self.restart())
- break
- except RPCError:
+ except (OSError, TimeoutError, RPCError):
pass
log.info("PingTask stopped")
- async def recv_worker(self):
+ async def network_worker(self):
log.info("NetworkTask started")
while True:
@@ -325,19 +323,9 @@ async def recv_worker(self):
if packet is None or len(packet) == 4:
if packet:
- error_code = -Int.read(BytesIO(packet))
-
- # if error_code == 404:
- # raise Unauthorized(
- # "Auth key not found in the system. You must delete your session file "
- # "and log in again with your phone number or bot token."
- # )
- log.warning(
- "Server sent transport error: %s (%s)",
- error_code, Session.TRANSPORT_ERRORS.get(error_code, "unknown error")
- )
+ log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
- if self.is_started.is_set():
+ if self.is_connected.is_set():
self.loop.create_task(self.restart())
break
@@ -358,7 +346,10 @@ async def send(
if wait_response:
self.results[msg_id] = Result()
- log.debug("Sent: %s", message)
+ # Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
+ # will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
+ log.debug(f"Sent:")
+ log.debug(message)
payload = await self.loop.run_in_executor(
pyrogram.crypto_executor,
@@ -381,26 +372,23 @@ async def send(
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
except asyncio.TimeoutError:
pass
-
- result = self.results.pop(msg_id).value
+ finally:
+ result = self.results.pop(msg_id).value
if result is None:
- raise TimeoutError("Request timed out")
-
- if isinstance(result, raw.types.RpcError):
+ raise TimeoutError
+ elif isinstance(result, raw.types.RpcError):
if isinstance(data, Session.CUR_ALWD_INNR_QRYS):
data = data.query
RPCError.raise_it(result, type(data))
-
- if isinstance(result, raw.types.BadMsgNotification):
- log.warning("%s: %s", BadMsgNotification.__name__, BadMsgNotification(result.error_code))
-
- if isinstance(result, raw.types.BadServerSalt):
+ elif isinstance(result, raw.types.BadMsgNotification):
+ raise BadMsgNotification(result.error_code)
+ elif isinstance(result, raw.types.BadServerSalt):
self.salt = result.new_server_salt
return await self.send(data, wait_response, timeout)
-
- return result
+ else:
+ return result
async def invoke(
self,
@@ -410,7 +398,7 @@ async def invoke(
sleep_threshold: float = SLEEP_THRESHOLD
):
try:
- await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT)
+ await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT)
except asyncio.TimeoutError:
pass
@@ -430,46 +418,17 @@ async def invoke(
if amount > sleep_threshold >= 0:
raise
- log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")',
- self.client.name, amount, query_name)
+ log.warning(f'[{self.client.name}] Waiting for {amount} seconds before continuing '
+ f'(required by "{query_name}")')
await asyncio.sleep(amount)
- except (
- OSError,
- RuntimeError,
- InternalServerError,
- ServiceUnavailable,
- TimeoutError,
- ) as e:
- retries -= 1
- if (
- retries == 0 or
- (
- isinstance(e, InternalServerError)
- and getattr(e, "code", 0) == 500
- and (e.ID or e.NAME) in [
- "HISTORY_GET_FAILED"
- ]
- )
- ):
+ except (OSError, TimeoutError, InternalServerError, ServiceUnavailable) as e:
+ if retries == 0:
raise e from None
(log.warning if retries < 2 else log.info)(
- '[%s] Retrying "%s" due to: %s',
- Session.MAX_RETRIES - retries + 1,
- query_name, str(e) or repr(e)
- )
+ f'[{Session.MAX_RETRIES - retries + 1}] Retrying "{query_name}" due to {str(e) or repr(e)}')
- # restart was never being called after Exception block
- if not self.restart_event.is_set():
- self.loop.create_task(self.restart())
- else:
- # multiple Exceptions can be raised in a row, so we need to wait for the restart to finish
- try:
- await asyncio.wait_for(self.restart_event.wait(), self.WAIT_TIMEOUT)
- except asyncio.TimeoutError:
- pass
-
await asyncio.sleep(0.5)
return await self.invoke(query, retries - 1, timeout)