Skip to content

Commit

Permalink
Refactor loop
Browse files Browse the repository at this point in the history
  • Loading branch information
KurimuzonAkuma committed Mar 5, 2025
1 parent 7d1e536 commit aae9bfc
Show file tree
Hide file tree
Showing 14 changed files with 89 additions and 50 deletions.
38 changes: 24 additions & 14 deletions pyrogram/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ class Client(Methods):
Pass True to automatically fetch stories if they are missing.
Defaults to True.
loop (:py:class:`asyncio.AbstractEventLoop`, *optional*):
Event loop.
init_connection_params (:obj:`~pyrogram.raw.base.JSONValue`, *optional*):
Additional initConnection parameters.
For now, only the tz_offset field is supported, for specifying timezone offset in seconds.
Expand Down Expand Up @@ -285,7 +288,8 @@ def __init__(
fetch_stories: Optional[bool] = True,
init_connection_params: Optional["raw.base.JSONValue"] = None,
connection_factory: Type[Connection] = Connection,
protocol_factory: Type[TCP] = TCPAbridged
protocol_factory: Type[TCP] = TCPAbridged,
loop: Optional[asyncio.AbstractEventLoop] = None
):
super().__init__()

Expand Down Expand Up @@ -376,7 +380,13 @@ def __init__(
self.updates_watchdog_event = asyncio.Event()
self.last_update_time = datetime.now()

self.loop = asyncio.get_event_loop()
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()

def __enter__(self):
return self.start()
Expand Down Expand Up @@ -420,12 +430,12 @@ async def authorize(self) -> User:
try:
if not self.phone_number:
while True:
value = await ainput("Enter phone number or bot token: ")
value = await ainput("Enter phone number or bot token: ", loop=self.loop)

if not value:
continue

confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()

if confirm == "y":
break
Expand Down Expand Up @@ -457,7 +467,7 @@ async def authorize(self) -> User:

while True:
if not self.phone_code:
self.phone_code = await ainput("Enter confirmation code: ")
self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop)

try:
signed_in = await self.sign_in(self.phone_number, sent_code.phone_code_hash, self.phone_code)
Expand All @@ -471,18 +481,18 @@ async def authorize(self) -> User:
print("Password hint: {}".format(await self.get_password_hint()))

if not self.password:
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password)
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password, loop=self.loop)

try:
if not self.password:
confirm = await ainput("Confirm password recovery (y/n): ")
confirm = await ainput("Confirm password recovery (y/n): ", loop=self.loop)

if confirm == "y":
email_pattern = await self.send_recovery_code()
print(f"The recovery code has been sent to {email_pattern}")

while True:
recovery_code = await ainput("Enter recovery code: ")
recovery_code = await ainput("Enter recovery code: ", loop=self.loop)

try:
return await self.recover_password(recovery_code)
Expand All @@ -505,8 +515,8 @@ async def authorize(self) -> User:
return signed_in

while True:
first_name = await ainput("Enter first name: ")
last_name = await ainput("Enter last name (empty to skip): ")
first_name = await ainput("Enter first name: ", loop=self.loop)
last_name = await ainput("Enter last name (empty to skip): ", loop=self.loop)

try:
signed_up = await self.sign_up(
Expand Down Expand Up @@ -555,7 +565,7 @@ async def authorize_qr(self, except_ids: List[int] = []) -> User:
except SessionPasswordNeeded:
print(f"Password hint: {await self.get_password_hint()}")
return await self.check_password(
await ainput("Enter 2FA password: ", hide=self.hide_password)
await ainput("Enter 2FA password: ", hide=self.hide_password, loop=self.loop)
)

def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]):
Expand Down Expand Up @@ -873,13 +883,13 @@ async def load_session(self):
else:
while True:
try:
value = int(await ainput("Enter the api_id part of the API key: "))
value = int(await ainput("Enter the api_id part of the API key: ", loop=self.loop))

if value <= 0:
print("Invalid value")
continue

confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()

if confirm == "y":
await self.storage.api_id(value)
Expand Down Expand Up @@ -1261,7 +1271,7 @@ async def get_file(
def guess_mime_type(self, filename: Union[str, BytesIO]) -> Optional[str]:
if isinstance(filename, BytesIO):
return self.mimetypes.guess_type(filename.name)[0]

return self.mimetypes.guess_type(filename)[0]

def guess_extension(self, mime_type: str) -> Optional[str]:
Expand Down
13 changes: 11 additions & 2 deletions pyrogram/connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(
ipv6: bool,
proxy: dict,
media: bool = False,
protocol_factory: Type[TCP] = TCPAbridged
protocol_factory: Type[TCP] = TCPAbridged,
loop: Optional[asyncio.AbstractEventLoop] = None
) -> None:
self.dc_id = dc_id
self.test_mode = test_mode
Expand All @@ -48,9 +49,17 @@ def __init__(
self.address = DataCenter(dc_id, test_mode, ipv6, media)
self.protocol: Optional[TCP] = None

if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()

async def connect(self) -> None:
for i in range(Connection.MAX_CONNECTION_ATTEMPTS):
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy, loop=self.loop)

try:
log.info("Connecting...")
Expand Down
11 changes: 9 additions & 2 deletions pyrogram/connection/transport/tcp/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,22 @@ class Proxy(TypedDict):
class TCP:
TIMEOUT = 10

def __init__(self, ipv6: bool, proxy: Proxy) -> None:
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self.ipv6 = ipv6
self.proxy = proxy

self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None

self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop()

if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()

async def _connect_via_proxy(
self,
Expand Down
5 changes: 3 additions & 2 deletions pyrogram/connection/transport/tcp/tcp_abridged.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import logging
from typing import Optional, Tuple

Expand All @@ -25,8 +26,8 @@


class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)

async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
Expand Down
5 changes: 3 additions & 2 deletions pyrogram/connection/transport/tcp/tcp_abridged_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import logging
import os
from typing import Optional, Tuple
Expand All @@ -30,8 +31,8 @@
class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)

def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)

self.encrypt = None
self.decrypt = None
Expand Down
5 changes: 3 additions & 2 deletions pyrogram/connection/transport/tcp/tcp_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import logging
from binascii import crc32
from struct import pack, unpack
Expand All @@ -27,8 +28,8 @@


class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)

self.seq_no: Optional[int] = None

Expand Down
5 changes: 3 additions & 2 deletions pyrogram/connection/transport/tcp/tcp_intermediate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import logging
from struct import pack, unpack
from typing import Optional, Tuple
Expand All @@ -26,8 +27,8 @@


class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)

async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
Expand Down
5 changes: 3 additions & 2 deletions pyrogram/connection/transport/tcp/tcp_intermediate_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import logging
import os
from struct import pack, unpack
Expand All @@ -30,8 +31,8 @@
class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)

def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)

self.encrypt = None
self.decrypt = None
Expand Down
9 changes: 4 additions & 5 deletions pyrogram/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ class Dispatcher:

def __init__(self, client: "pyrogram.Client"):
self.client = client
self.loop = asyncio.get_event_loop()

self.handler_worker_tasks = []
self.locks_list = []
Expand Down Expand Up @@ -218,7 +217,7 @@ async def start(self):
self.locks_list.append(asyncio.Lock())

self.handler_worker_tasks.append(
self.loop.create_task(self.handler_worker(self.locks_list[-1]))
self.client.loop.create_task(self.handler_worker(self.locks_list[-1]))
)

log.info("Started %s HandlerTasks", self.client.workers)
Expand Down Expand Up @@ -255,7 +254,7 @@ async def fn():
for lock in self.locks_list:
lock.release()

self.loop.create_task(fn())
self.client.loop.create_task(fn())

def remove_handler(self, handler, group: int):
async def fn():
Expand All @@ -271,7 +270,7 @@ async def fn():
for lock in self.locks_list:
lock.release()

self.loop.create_task(fn())
self.client.loop.create_task(fn())

async def handler_worker(self, lock):
while True:
Expand Down Expand Up @@ -318,7 +317,7 @@ async def handler_worker(self, lock):
if inspect.iscoroutinefunction(handler.callback):
await handler.callback(self.client, *args)
else:
await self.loop.run_in_executor(
await self.client.loop.run_in_executor(
self.client.executor,
handler.callback,
self.client,
Expand Down
2 changes: 1 addition & 1 deletion pyrogram/methods/messages/download_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,4 +284,4 @@ async def progress(current, total):
if block:
return await downloader
else:
asyncio.get_event_loop().create_task(downloader)
self.loop.create_task(downloader)
2 changes: 1 addition & 1 deletion pyrogram/methods/stories/get_stories.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class GetStories:
async def get_stories(
self: "pyrogram.Client",
chat_id: Optional[Union[int, str]] = None,
story_ids: Optional[Union[int, Iterable[int]]] = None,
story_ids: Optional[Union[int, Iterable[int], str]] = None,
) -> Optional[Union["types.Story", List["types.Story"]]]:
"""Get one or more stories from a chat by using stories identifiers.
Expand Down
4 changes: 3 additions & 1 deletion pyrogram/session/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
self.proxy = client.proxy
self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory
self.loop = client.loop

self.connection: Optional[Connection] = None

Expand Down Expand Up @@ -90,7 +91,8 @@ async def create(self):
ipv6=self.ipv6,
proxy=self.proxy,
media=False,
protocol_factory=self.protocol_factory
protocol_factory=self.protocol_factory,
loop=self.loop
)

try:
Expand Down
Loading

0 comments on commit aae9bfc

Please sign in to comment.