From 97262e1efc7ae220bdef52c8ef0e5af756f77699 Mon Sep 17 00:00:00 2001 From: Ollie <69084614+olijeffers0n@users.noreply.github.com> Date: Wed, 11 Jan 2023 15:32:06 +0000 Subject: [PATCH] Per RustSocket eventloops --- rustplus/api/base_rust_api.py | 11 +++++------ rustplus/api/remote/events/event_handler.py | 8 ++++---- .../api/remote/events/event_loop_manager.py | 18 +++++++++++++----- .../api/remote/events/map_event_listener.py | 4 ++-- rustplus/api/remote/rust_remote_interface.py | 2 +- rustplus/api/remote/rustws.py | 2 +- rustplus/commands/command_handler.py | 7 ++++--- rustplus/conversation/conversation.py | 2 +- 8 files changed, 31 insertions(+), 23 deletions(-) diff --git a/rustplus/api/base_rust_api.py b/rustplus/api/base_rust_api.py index 8c61452..599ce66 100644 --- a/rustplus/api/base_rust_api.py +++ b/rustplus/api/base_rust_api.py @@ -114,10 +114,9 @@ async def connect( :return: None """ - if self.event_loop is None: - EventLoopManager._loop = asyncio.get_event_loop() - else: - EventLoopManager._loop = self.event_loop + EventLoopManager.set_loop(self.event_loop if self.event_loop is not None else asyncio.get_event_loop(), + self.server_id) + try: if self.remote.ws is None: await self.remote.connect( @@ -215,7 +214,7 @@ async def switch_server( self.remote.command_handler.command_options = command_options else: self.remote.use_commands = True - self.remote.command_handler = CommandHandler(self.command_options) + self.remote.command_handler = CommandHandler(self.command_options, self) self.raise_ratelimit_exception = raise_ratelimit_exception @@ -354,7 +353,7 @@ def entity_event_callback(future_inner: Future): ), self.server_id ) - future = asyncio.run_coroutine_threadsafe(get_entity(self, eid), EventLoopManager.get_loop()) + future = asyncio.run_coroutine_threadsafe(get_entity(self, eid), EventLoopManager.get_loop(self.server_id)) future.add_done_callback(entity_event_callback) return RegisteredListener(eid, coro) diff --git a/rustplus/api/remote/events/event_handler.py b/rustplus/api/remote/events/event_handler.py index e54a06f..579e58a 100644 --- a/rustplus/api/remote/events/event_handler.py +++ b/rustplus/api/remote/events/event_handler.py @@ -28,7 +28,7 @@ def run_entity_event(self, name, app_message, server_id) -> None: for handler in handlers.copy(): coro, event_type = handler.data - self._schedule_event(EventLoopManager.get_loop(), coro, EntityEvent(app_message, event_type)) + self._schedule_event(EventLoopManager.get_loop(server_id), coro, EntityEvent(app_message, event_type)) def run_team_event(self, app_message, server_id) -> None: @@ -36,7 +36,7 @@ def run_team_event(self, app_message, server_id) -> None: for handler in handlers.copy(): coro = handler.data - self._schedule_event(EventLoopManager.get_loop(), coro, TeamEvent(app_message)) + self._schedule_event(EventLoopManager.get_loop(server_id), coro, TeamEvent(app_message)) def run_chat_event(self, app_message, server_id) -> None: @@ -44,7 +44,7 @@ def run_chat_event(self, app_message, server_id) -> None: for handler in handlers.copy(): coro = handler.data - self._schedule_event(EventLoopManager.get_loop(), coro, ChatEvent(app_message)) + self._schedule_event(EventLoopManager.get_loop(server_id), coro, ChatEvent(app_message)) def run_proto_event(self, byte_data: bytes, server_id) -> None: @@ -52,4 +52,4 @@ def run_proto_event(self, byte_data: bytes, server_id) -> None: for handler in handlers.copy(): coro = handler.data - self._schedule_event(EventLoopManager.get_loop(), coro, ProtobufEvent(byte_data)) + self._schedule_event(EventLoopManager.get_loop(server_id), coro, ProtobufEvent(byte_data)) diff --git a/rustplus/api/remote/events/event_loop_manager.py b/rustplus/api/remote/events/event_loop_manager.py index d48951a..2353d00 100644 --- a/rustplus/api/remote/events/event_loop_manager.py +++ b/rustplus/api/remote/events/event_loop_manager.py @@ -1,13 +1,21 @@ +import asyncio +from ....utils import ServerID + + class EventLoopManager: - _loop = None + _loop = {} @staticmethod - def get_loop(): - if EventLoopManager._loop is None: + def get_loop(server_id: ServerID) -> asyncio.AbstractEventLoop: + if EventLoopManager._loop is None or EventLoopManager._loop.get(server_id) is None: raise RuntimeError("Event loop is not set") - if EventLoopManager._loop.is_closed(): + if EventLoopManager._loop.get(server_id).is_closed(): raise RuntimeError("Event loop is not running") - return EventLoopManager._loop + return EventLoopManager._loop.get(server_id) + + @staticmethod + def set_loop(loop: asyncio.AbstractEventLoop, server_id: ServerID) -> None: + EventLoopManager._loop[server_id] = loop diff --git a/rustplus/api/remote/events/map_event_listener.py b/rustplus/api/remote/events/map_event_listener.py index 630616c..b4615b4 100644 --- a/rustplus/api/remote/events/map_event_listener.py +++ b/rustplus/api/remote/events/map_event_listener.py @@ -42,7 +42,7 @@ def _run(self) -> None: try: future = asyncio.run_coroutine_threadsafe( - self.api.get_markers(), EventLoopManager.get_loop() + self.api.get_markers(), EventLoopManager.get_loop(self.api.server_id) ) new_highest_id = 0 for marker in future.result(): @@ -80,7 +80,7 @@ def _run(self) -> None: def call_event(self, marker, is_new) -> None: for listener in self.listeners: asyncio.run_coroutine_threadsafe( - listener.get_coro()(MarkerEvent(marker, is_new)), EventLoopManager.get_loop() + listener.get_coro()(MarkerEvent(marker, is_new)), EventLoopManager.get_loop(self.api.server_id) ).result() diff --git a/rustplus/api/remote/rust_remote_interface.py b/rustplus/api/remote/rust_remote_interface.py index c9342bf..8854488 100644 --- a/rustplus/api/remote/rust_remote_interface.py +++ b/rustplus/api/remote/rust_remote_interface.py @@ -50,7 +50,7 @@ def __init__( self.use_commands = False else: self.use_commands = True - self.command_handler = CommandHandler(self.command_options) + self.command_handler = CommandHandler(self.command_options, api) self.event_handler = EventHandler() diff --git a/rustplus/api/remote/rustws.py b/rustplus/api/remote/rustws.py index 3f5b7bd..3c5417b 100644 --- a/rustplus/api/remote/rustws.py +++ b/rustplus/api/remote/rustws.py @@ -163,7 +163,7 @@ def run(self) -> None: f"{datetime.now().strftime('%d/%m/%Y %H:%M:%S')} [RustPlus.py] Connection interrupted, Retrying" ) asyncio.run_coroutine_threadsafe( - self.connect(ignore_open_value=True), EventLoopManager.get_loop() + self.connect(ignore_open_value=True), EventLoopManager.get_loop(self.server_id) ).result() continue return diff --git a/rustplus/commands/command_handler.py b/rustplus/commands/command_handler.py index 7c99c57..6643134 100644 --- a/rustplus/commands/command_handler.py +++ b/rustplus/commands/command_handler.py @@ -10,9 +10,10 @@ class CommandHandler: - def __init__(self, command_options: CommandOptions) -> None: + def __init__(self, command_options: CommandOptions, api) -> None: self.command_options = command_options self.commands = {} + self.api = api def register_command(self, data: CommandData) -> None: @@ -40,7 +41,7 @@ def run_command(self, message: RustChatMessage, prefix) -> None: data = self.commands[command] self._schedule_event( - EventLoopManager.get_loop(), + EventLoopManager.get_loop(self.api.server_id), data.coro, Command( message.name, @@ -57,7 +58,7 @@ def run_command(self, message: RustChatMessage, prefix) -> None: if command in data.aliases or data.callable_func(command): self._schedule_event( - EventLoopManager.get_loop(), + EventLoopManager.get_loop(self.api.server_id), data.coro, Command( message.name, diff --git a/rustplus/conversation/conversation.py b/rustplus/conversation/conversation.py index 9985e27..ada28d6 100644 --- a/rustplus/conversation/conversation.py +++ b/rustplus/conversation/conversation.py @@ -55,7 +55,7 @@ async def start(self) -> None: await self.send_prompt(await self._prompts[0].prompt()) def run_coro(self, coro, args): - return asyncio.run_coroutine_threadsafe(coro(*args), EventLoopManager.get_loop()).result() + return asyncio.run_coroutine_threadsafe(coro(*args), EventLoopManager.get_loop(self._api.server_id)).result() def get_answers(self) -> List[str]: return self._answers