diff --git a/requirements.txt b/requirements.txt index 28192b4..083b1f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -websocket_client +websockets Pillow asyncio rustPlusPushReceiver==0.4.1 diff --git a/rustplus/__init__.py b/rustplus/__init__.py index dbee633..2bb7fb7 100644 --- a/rustplus/__init__.py +++ b/rustplus/__init__.py @@ -22,5 +22,5 @@ __name__ = "rustplus" __author__ = "olijeffers0n" -__version__ = "5.5.13" +__version__ = "5.6.0" __support__ = "Discord: https://discord.gg/nQqJe8qvP8" diff --git a/rustplus/api/base_rust_api.py b/rustplus/api/base_rust_api.py index 8bdd120..76ba734 100644 --- a/rustplus/api/base_rust_api.py +++ b/rustplus/api/base_rust_api.py @@ -39,7 +39,6 @@ def __init__( event_loop: asyncio.AbstractEventLoop = None, rate_limiter: RateLimiter = None, ) -> None: - if ip is None: raise ValueError("Ip cannot be None") if steam_id is None: @@ -80,16 +79,15 @@ async def _handle_ratelimit(self, amount=1) -> None: :return: None """ while True: - - if self.remote.ratelimiter.can_consume(self.server_id, amount): - self.remote.ratelimiter.consume(self.server_id, amount) + if await self.remote.ratelimiter.can_consume(self.server_id, amount): + await self.remote.ratelimiter.consume(self.server_id, amount) break if self.raise_ratelimit_exception: raise RateLimitError("Out of tokens") await asyncio.sleep( - self.remote.ratelimiter.get_estimated_delay_time(self.server_id, amount) + await self.remote.ratelimiter.get_estimated_delay_time(self.server_id, amount) ) self.heartbeat.reset_rhythm() @@ -151,7 +149,7 @@ async def close_connection(self) -> None: :return: None """ - self.remote.close() + await self.remote.close() async def disconnect(self) -> None: """ @@ -173,7 +171,7 @@ async def send_wakeup_request(self) -> None: app_request.get_time = AppEmpty() app_request.get_time._serialized_on_wire = True - self.remote.ignored_responses.append(app_request.seq) + await self.remote.add_ignored_response(app_request.seq) await self.remote.send_message(app_request) @@ -238,7 +236,8 @@ async def switch_server( self.remote.server_id = ServerID(ip, port, steam_id, player_token) # reset ratelimiter - self.remote.ratelimiter.remove(self.server_id) + self.remote.use_proxy = use_proxy + await self.remote.ratelimiter.remove(self.server_id) self.remote.ratelimiter.add_socket( self.server_id, self.ratelimit_limit, @@ -284,7 +283,6 @@ def command( return RegisteredListener(coro.__name__, cmd_data.coro) def wrap_func(coro): - if self.command_options is None: raise CommandsNotEnabledError("Not enabled") @@ -341,7 +339,6 @@ def entity_event(self, eid): """ def wrap_func(coro) -> RegisteredListener: - if isinstance(coro, RegisteredListener): coro = coro.get_coro() @@ -588,14 +585,14 @@ async def get_tc_storage_contents( """ raise NotImplementedError("Not Implemented") - async def get_camera_manager(self, id: str) -> CameraManager: + async def get_camera_manager(self, cam_id: str) -> CameraManager: """ Gets a camera manager for a given camera ID NOTE: This will override the current camera manager if one exists for the given ID so you cannot have multiple - :param id: The ID of the camera + :param cam_id: The ID of the camera :return CameraManager: The camera manager - :raises RequestError: If the camera is not found or you cannot access it. See reason for more info + :raises RequestError: If the camera is not found, or you cannot access it. See reason for more info """ raise NotImplementedError("Not Implemented") diff --git a/rustplus/api/remote/camera/camera_manager.py b/rustplus/api/remote/camera/camera_manager.py index 265cf1a..bc70ffb 100644 --- a/rustplus/api/remote/camera/camera_manager.py +++ b/rustplus/api/remote/camera/camera_manager.py @@ -1,10 +1,9 @@ import time -from typing import Iterable, Union, List, Coroutine, TypeVar, Set +from typing import Iterable, Union, List, Coroutine, TypeVar, Set, Callable from PIL import Image from .camera_parser import Parser -from ..events import EventLoopManager, EventHandler from ..rustplus_proto import ( AppCameraInput, Vector2, @@ -31,31 +30,29 @@ def __init__( self._cam_info_message.width, self._cam_info_message.height ) self.time_since_last_subscribe: float = time.time() - self.frame_callbacks: Set[Coroutine] = set() + self.frame_callbacks: Set[Callable[[Image.Image], Coroutine]] = set() - def add_packet(self, packet) -> None: + async def add_packet(self, packet) -> None: self._last_packets.add(packet) if len(self.frame_callbacks) == 0: return - frame = self._create_frame() + frame = await self._create_frame() for callback in self.frame_callbacks: - EventHandler.schedule_event( - EventLoopManager.get_loop(self.rust_socket.server_id), - callback, - frame, - ) + await callback(frame) - def on_frame_received(self, coro: Coroutine) -> Coroutine: + def on_frame_received( + self, coro: Callable[[Image.Image], Coroutine] + ) -> Callable[[Image.Image], Coroutine]: self.frame_callbacks.add(coro) return coro def has_frame_data(self) -> bool: return self._last_packets is not None and len(self._last_packets) > 0 - def _create_frame( + async def _create_frame( self, render_entities: bool = True, entity_render_distance: float = float("inf"), @@ -96,7 +93,7 @@ async def get_frame( entity_render_distance: float = float("inf"), max_entity_amount: int = float("inf"), ) -> Union[Image.Image, None]: - return self._create_frame( + return await self._create_frame( render_entities, entity_render_distance, max_entity_amount ) @@ -115,7 +112,6 @@ async def send_mouse_movement(self, mouse_delta: Vector) -> None: async def send_combined_movement( self, movements: Iterable[int] = None, joystick_vector: Vector = None ) -> None: - if joystick_vector is None: joystick_vector = Vector() @@ -138,7 +134,7 @@ async def send_combined_movement( app_request.camera_input = cam_input await self.rust_socket.remote.send_message(app_request) - self.rust_socket.remote.ignored_responses.append(app_request.seq) + await self.rust_socket.remote.add_ignored_response(app_request.seq) async def exit_camera(self) -> None: await self.rust_socket._handle_ratelimit() @@ -147,7 +143,7 @@ async def exit_camera(self) -> None: app_request.camera_unsubscribe._serialized_on_wire = True await self.rust_socket.remote.send_message(app_request) - self.rust_socket.remote.ignored_responses.append(app_request.seq) + await self.rust_socket.remote.add_ignored_response(app_request.seq) self._open = False self._last_packets.clear() diff --git a/rustplus/api/remote/camera/camera_parser.py b/rustplus/api/remote/camera/camera_parser.py index ea353aa..c65d39e 100644 --- a/rustplus/api/remote/camera/camera_parser.py +++ b/rustplus/api/remote/camera/camera_parser.py @@ -54,7 +54,6 @@ def reset_output(self) -> None: ) def handle_camera_ray_data(self, data) -> None: - if data is None: return @@ -66,7 +65,6 @@ def handle_camera_ray_data(self, data) -> None: self._ray_lookback = [[0 for _ in range(3)] for _ in range(64)] def step(self) -> None: - if self._rays is None: return @@ -80,7 +78,6 @@ def process_rays_batch(self) -> bool: return True for h in range(100): - if self.data_pointer >= len(self._rays.ray_data) - 1: return True @@ -107,7 +104,6 @@ def process_rays_batch(self) -> bool: not (distance == 1 and alignment == 0 and material == 0) and material != 7 ): - self.colour_output[ x : x + self.scale_factor, y : y + self.scale_factor ] = MathUtils._convert_colour( @@ -153,7 +149,6 @@ def next_ray(self, ray_data) -> List[Union[float, int]]: self._ray_lookback[u][2] = i else: - c = 192 & byte if c == 0: @@ -217,7 +212,6 @@ def handle_entities( entity_render_distance: float, max_entity_amount: int, ) -> Any: - image_data = np.array(image_data) players = [player for player in entities if player.type == 2] @@ -265,7 +259,6 @@ def handle_entities( text = set() for entity in entities: - if entity.position.z > entity_render_distance and entity.type == 1: continue @@ -307,7 +300,6 @@ def handle_entity( aspect_ratio, text, ) -> None: - entity.size.x = min(entity.size.x, 5) entity.size.y = min(entity.size.y, 5) entity.size.z = min(entity.size.z, 5) @@ -418,7 +410,6 @@ def render( entity_render_distance: float, max_entity_amount: int, ) -> Image.Image: - # We have the output array filled with RayData objects # We can get the material at each pixel and use that to get the colour # We can then use the alignment to get the alpha value @@ -556,7 +547,6 @@ def _convert_colour( cls, colour: Tuple[float, float, float, float], ) -> Tuple[int, int, int]: - if colour in cls.COLOUR_CACHE: return cls.COLOUR_CACHE[colour] @@ -589,7 +579,6 @@ def solve_quadratic(a: float, b: float, c: float, larger: bool) -> float: @classmethod def get_tree_vertices(cls, size) -> np.ndarray: - if size in cls.VERTEX_CACHE: return cls.VERTEX_CACHE[size] @@ -599,7 +588,6 @@ def get_tree_vertices(cls, size) -> np.ndarray: vertex_list = [] for x_value in [size.y / 8, -size.y / 8]: - for i in range(number_of_segments): angle = segment_angle * i @@ -616,7 +604,6 @@ def get_tree_vertices(cls, size) -> np.ndarray: @classmethod def get_player_vertices(cls, size) -> np.ndarray: - if size in cls.VERTEX_CACHE: return cls.VERTEX_CACHE[size] @@ -633,9 +620,7 @@ def get_player_vertices(cls, size) -> np.ndarray: x = 0 while x <= width: - for offset in range(-1, 2, 2): - x_value = x * offset # Use the quadratic formula to find the y values of the top and bottom of the pill diff --git a/rustplus/api/remote/events/event_handler.py b/rustplus/api/remote/events/event_handler.py index 1088437..7d65fe7 100644 --- a/rustplus/api/remote/events/event_handler.py +++ b/rustplus/api/remote/events/event_handler.py @@ -1,31 +1,16 @@ -import asyncio -import logging -from asyncio.futures import Future -from typing import Set, Coroutine, Any +from typing import Set, Union from ....utils import ServerID from .events import EntityEvent, TeamEvent, ChatEvent, ProtobufEvent from .registered_listener import RegisteredListener -from .event_loop_manager import EventLoopManager from ..rustplus_proto import AppMessage class EventHandler: @staticmethod - def schedule_event( - loop: asyncio.AbstractEventLoop, coro: Coroutine, arg: Any + async def run_entity_event( + name: Union[str, int], app_message: AppMessage, server_id: ServerID ) -> None: - def callback(inner_future: Future): - if inner_future.exception() is not None: - logging.getLogger("rustplus.py").exception(inner_future.exception()) - - future: Future = asyncio.run_coroutine_threadsafe(coro(arg), loop) - future.add_done_callback(callback) - - def run_entity_event( - self, name: str, app_message: AppMessage, server_id: ServerID - ) -> None: - handlers: Set[RegisteredListener] = EntityEvent.handlers.get_handlers( server_id ).get(str(name)) @@ -36,40 +21,30 @@ def run_entity_event( for handler in handlers.copy(): coro, event_type = handler.data - self.schedule_event( - EventLoopManager.get_loop(server_id), - coro, - EntityEvent(app_message, event_type), - ) - - def run_team_event(self, app_message: AppMessage, server_id: ServerID) -> None: + await coro(EntityEvent(app_message, event_type)) + @staticmethod + async def run_team_event(app_message: AppMessage, server_id: ServerID) -> None: handlers: Set[RegisteredListener] = TeamEvent.handlers.get_handlers(server_id) for handler in handlers.copy(): coro = handler.data - self.schedule_event( - EventLoopManager.get_loop(server_id), coro, TeamEvent(app_message) - ) - - def run_chat_event(self, app_message: AppMessage, server_id: ServerID) -> None: + await coro(TeamEvent(app_message)) + @staticmethod + async def run_chat_event(app_message: AppMessage, server_id: ServerID) -> None: handlers: Set[RegisteredListener] = ChatEvent.handlers.get_handlers(server_id) for handler in handlers.copy(): coro = handler.data - self.schedule_event( - EventLoopManager.get_loop(server_id), coro, ChatEvent(app_message) - ) - - def run_proto_event(self, byte_data: bytes, server_id: ServerID) -> None: + await coro(ChatEvent(app_message)) + @staticmethod + async def run_proto_event(byte_data: bytes, server_id: ServerID) -> None: handlers: Set[RegisteredListener] = ProtobufEvent.handlers.get_handlers( server_id ) for handler in handlers.copy(): coro = handler.data - self.schedule_event( - EventLoopManager.get_loop(server_id), coro, ProtobufEvent(byte_data) - ) + await coro(ProtobufEvent(byte_data)) diff --git a/rustplus/api/remote/events/event_loop_manager.py b/rustplus/api/remote/events/event_loop_manager.py index 63e3cb1..16392cd 100644 --- a/rustplus/api/remote/events/event_loop_manager.py +++ b/rustplus/api/remote/events/event_loop_manager.py @@ -5,7 +5,6 @@ class EventLoopManager: - _loop: Dict[ServerID, asyncio.AbstractEventLoop] = {} @staticmethod diff --git a/rustplus/api/remote/events/events.py b/rustplus/api/remote/events/events.py index 76b4052..6a6afaa 100644 --- a/rustplus/api/remote/events/events.py +++ b/rustplus/api/remote/events/events.py @@ -27,7 +27,6 @@ def item_is_blueprint(self) -> bool: class TeamEvent: - handlers = HandlerList() def __init__(self, app_message: AppMessage) -> None: @@ -44,7 +43,6 @@ def team_info(self) -> RustTeamInfo: class ChatEvent: - handlers = HandlerList() def __init__(self, app_message: AppMessage) -> None: @@ -56,7 +54,6 @@ def message(self) -> RustChatMessage: class EntityEvent: - handlers = EntityHandlerList() def __init__(self, app_message: AppMessage, entity_type) -> None: @@ -119,7 +116,6 @@ def is_new(self) -> bool: class ProtobufEvent: - handlers = HandlerList() def __init__(self, byte_data) -> None: diff --git a/rustplus/api/remote/events/handler_list.py b/rustplus/api/remote/events/handler_list.py index dfd9257..eb5235d 100644 --- a/rustplus/api/remote/events/handler_list.py +++ b/rustplus/api/remote/events/handler_list.py @@ -37,7 +37,6 @@ def unregister(self, listener: RegisteredListener, server_id: ServerID) -> None: self._handlers.get(server_id).get(listener.listener_id).remove(listener) def register(self, listener: RegisteredListener, server_id: ServerID) -> None: - if server_id not in self._handlers: self._handlers[server_id] = defaultdict(set) diff --git a/rustplus/api/remote/events/map_event_listener.py b/rustplus/api/remote/events/map_event_listener.py index 5166c5d..412ff52 100644 --- a/rustplus/api/remote/events/map_event_listener.py +++ b/rustplus/api/remote/events/map_event_listener.py @@ -36,18 +36,14 @@ def start(self, delay) -> None: self.gc.start() def _run(self) -> None: - while True: - try: - future = asyncio.run_coroutine_threadsafe( self.api.get_markers(), EventLoopManager.get_loop(self.api.server_id), ) new_highest_id = 0 for marker in future.result(): - new = False if marker.id in self.persistent_ids: diff --git a/rustplus/api/remote/events/registered_listener.py b/rustplus/api/remote/events/registered_listener.py index 8e7ce17..c3682b0 100644 --- a/rustplus/api/remote/events/registered_listener.py +++ b/rustplus/api/remote/events/registered_listener.py @@ -1,5 +1,8 @@ +from typing import Union + + class RegisteredListener: - def __init__(self, listener_id: str, data) -> None: + def __init__(self, listener_id: Union[str, int], data) -> None: self.listener_id = str(listener_id) self.data = data @@ -10,7 +13,6 @@ def get_coro(self): def __eq__(self, other) -> bool: if isinstance(other, RegisteredListener): - coro = self.data if isinstance(self.data, tuple): coro = self.data[0] diff --git a/rustplus/api/remote/expo_bundle_handler.py b/rustplus/api/remote/expo_bundle_handler.py index 0eac80d..3d35d5a 100644 --- a/rustplus/api/remote/expo_bundle_handler.py +++ b/rustplus/api/remote/expo_bundle_handler.py @@ -7,7 +7,6 @@ class MagicValueGrabber: @staticmethod def get_magic_value() -> int: - try: data = requests.get( "https://exp.host/@facepunch/RustCompanion", diff --git a/rustplus/api/remote/fcm_listener.py b/rustplus/api/remote/fcm_listener.py index dcded55..efed9d1 100644 --- a/rustplus/api/remote/fcm_listener.py +++ b/rustplus/api/remote/fcm_listener.py @@ -15,7 +15,6 @@ def start(self, daemon=False) -> None: self.thread = Thread(target=self.__fcm_listen, daemon=daemon).start() def __fcm_listen(self) -> None: - if self.data is None: raise ValueError("Data is None") diff --git a/rustplus/api/remote/heartbeat.py b/rustplus/api/remote/heartbeat.py index 9e3118c..bedf410 100644 --- a/rustplus/api/remote/heartbeat.py +++ b/rustplus/api/remote/heartbeat.py @@ -4,13 +4,11 @@ class HeartBeat: def __init__(self, rust_api) -> None: - self.rust_api = rust_api self.next_run = time.time() self.running = False async def start_beat(self) -> None: - if self.running: return @@ -19,21 +17,16 @@ async def start_beat(self) -> None: asyncio.create_task(self._heart_beat()) async def _heart_beat(self) -> None: - while True: - if time.time() >= self.next_run: - await self.beat() else: await asyncio.sleep(1) async def beat(self) -> None: - if self.rust_api.remote.ws is not None and self.rust_api.remote.is_open(): await self.rust_api.send_wakeup_request() def reset_rhythm(self) -> None: - self.next_run = time.time() + 240 diff --git a/rustplus/api/remote/ratelimiter.py b/rustplus/api/remote/ratelimiter.py index 9d28ff6..28e9ca2 100644 --- a/rustplus/api/remote/ratelimiter.py +++ b/rustplus/api/remote/ratelimiter.py @@ -1,6 +1,6 @@ import math import time -import threading +import asyncio from typing import Dict from ...exceptions.exceptions import RateLimitError @@ -37,7 +37,6 @@ def refresh(self) -> None: class RateLimiter: - SERVER_LIMIT = 50 SERVER_REFRESH_AMOUNT = 15 @@ -51,7 +50,7 @@ def default(cls) -> "RateLimiter": def __init__(self) -> None: self.socket_buckets: Dict[ServerID, TokenBucket] = {} self.server_buckets: Dict[str, TokenBucket] = {} - self.lock = threading.Lock() + self.lock = asyncio.Lock() def add_socket( self, @@ -69,66 +68,62 @@ def add_socket( self.SERVER_LIMIT, self.SERVER_LIMIT, 1, self.SERVER_REFRESH_AMOUNT ) - def can_consume(self, server_id: ServerID, amount: int = 1) -> bool: + async def can_consume(self, server_id: ServerID, amount: int = 1) -> bool: """ Returns whether the user can consume the amount of tokens provided """ - self.lock.acquire(blocking=True) - can_consume = True - - for bucket in [ - self.socket_buckets.get(server_id), - self.server_buckets.get(server_id.get_server_string()), - ]: - bucket.refresh() - if not bucket.can_consume(amount): - can_consume = False - - self.lock.release() + async with self.lock: + can_consume = True + + for bucket in [ + self.socket_buckets.get(server_id), + self.server_buckets.get(server_id.get_server_string()), + ]: + bucket.refresh() + if not bucket.can_consume(amount): + can_consume = False + return can_consume - def consume(self, server_id: ServerID, amount: int = 1) -> None: + async def consume(self, server_id: ServerID, amount: int = 1) -> None: """ Consumes an amount of tokens from the bucket. You should first check to see whether it is possible with can_consume """ - self.lock.acquire(blocking=True) - for bucket in [ - self.socket_buckets.get(server_id), - self.server_buckets.get(server_id.get_server_string()), - ]: - bucket.refresh() - if not bucket.can_consume(amount): - self.lock.release() - raise RateLimitError("Not Enough Tokens") - bucket.consume(amount) - self.lock.release() - - def get_estimated_delay_time(self, server_id: ServerID, target_cost: int) -> float: + async with self.lock: + for bucket in [ + self.socket_buckets.get(server_id), + self.server_buckets.get(server_id.get_server_string()), + ]: + bucket.refresh() + if not bucket.can_consume(amount): + self.lock.release() + raise RateLimitError("Not Enough Tokens") + bucket.consume(amount) + + async def get_estimated_delay_time(self, server_id: ServerID, target_cost: int) -> float: """ Returns how long until the amount of tokens needed will be available """ - self.lock.acquire(blocking=True) - delay = 0 - for bucket in [ - self.socket_buckets.get(server_id), - self.server_buckets.get(server_id.get_server_string()), - ]: - val = ( - math.ceil( - (((target_cost - bucket.current) / bucket.refresh_per_second) + 0.1) - * 100 + async with self.lock: + delay = 0 + for bucket in [ + self.socket_buckets.get(server_id), + self.server_buckets.get(server_id.get_server_string()), + ]: + val = ( + math.ceil( + (((target_cost - bucket.current) / bucket.refresh_per_second) + 0.1) + * 100 + ) + / 100 ) - / 100 - ) - if val > delay: - delay = val - self.lock.release() + if val > delay: + delay = val return delay - def remove(self, server_id: ServerID) -> None: + async def remove(self, server_id: ServerID) -> None: """ Removes the limiter """ - self.lock.acquire(blocking=True) - del self.socket_buckets[server_id] - self.lock.release() + async with self.lock: + del self.socket_buckets[server_id] diff --git a/rustplus/api/remote/rust_remote_interface.py b/rustplus/api/remote/rust_remote_interface.py index 7dd972b..97211a3 100644 --- a/rustplus/api/remote/rust_remote_interface.py +++ b/rustplus/api/remote/rust_remote_interface.py @@ -1,19 +1,19 @@ import asyncio import logging from asyncio import Future +from typing import Union, Dict + from .camera.camera_manager import CameraManager from .events import EventLoopManager, EntityEvent, RegisteredListener -from .events.event_handler import EventHandler from .rustplus_proto import AppRequest, AppMessage, AppEmpty, AppCameraSubscribe from .rustws import RustWebsocket, CONNECTED, PENDING_CONNECTION from .ratelimiter import RateLimiter from .expo_bundle_handler import MagicValueGrabber -from ...utils import ServerID +from ...utils import ServerID, YieldingEvent from ...conversation import ConversationFactory from ...commands import CommandHandler from ...exceptions import ( ClientNotConnectedError, - ResponseNotReceivedError, RequestError, SmartDeviceRegistrationError, ) @@ -26,13 +26,11 @@ def __init__( command_options, ratelimit_limit, ratelimit_refill, - websocket_length=600, use_proxy: bool = False, api=None, use_test_server: bool = False, rate_limiter: RateLimiter = None, ) -> None: - self.server_id = server_id self.api = api self.command_options = command_options @@ -48,29 +46,25 @@ def __init__( self.server_id, ratelimit_limit, ratelimit_limit, 1, ratelimit_refill ) self.ws = None - self.websocket_length = websocket_length - self.responses = {} - self.ignored_responses = [] - self.pending_for_response = {} - self.sent_requests = [] - self.command_handler = None + self.logger = logging.getLogger("rustplus.py") + self.ignored_responses = set() + self.pending_response_events: Dict[int, YieldingEvent] = {} + + self.command_handler = None if command_options is None: self.use_commands = False else: self.use_commands = True self.command_handler = CommandHandler(self.command_options, api) - self.event_handler = EventHandler() - self.magic_value = MagicValueGrabber.get_magic_value() self.conversation_factory = ConversationFactory(api) self.use_test_server = use_test_server self.pending_entity_subscriptions = [] - self.camera_manager: CameraManager = None + self.camera_manager: Union[CameraManager, None] = None async def connect(self, retries, delay, on_failure=None) -> None: - self.ws = RustWebsocket( server_id=self.server_id, remote=self, @@ -85,10 +79,9 @@ async def connect(self, retries, delay, on_failure=None) -> None: for entity_id, coroutine in self.pending_entity_subscriptions: self.handle_subscribing_entity(entity_id, coroutine) - def close(self) -> None: - + async def close(self) -> None: if self.ws is not None: - self.ws.close() + await self.ws.close() del self.ws self.ws = None @@ -103,10 +96,10 @@ def is_open(self) -> bool: return False async def send_message(self, request: AppRequest) -> None: - if self.ws is None: raise ClientNotConnectedError("No Current Websocket Connection") + self.pending_response_events[request.seq] = YieldingEvent() await self.ws.send_message(request) async def get_response( @@ -116,36 +109,27 @@ async def get_response( Returns a given response from the server. """ - attempts = 0 - - while seq in self.pending_for_response and seq not in self.responses: - - if seq in self.sent_requests: - - if attempts <= 40: + attempts = 1 - attempts += 1 - await asyncio.sleep(0.1) + while True: + event = self.pending_response_events.get(seq) + if event is None: + raise Exception("Event Doesn't exist") - else: + response: AppMessage = await event.event_wait_for(4) + if response is not None: + break - await self.send_message(app_request) - await asyncio.sleep(0.1) - attempts = 0 - - if attempts <= 10: - await asyncio.sleep(0.1) - attempts += 1 + await self.send_message(app_request) - else: - await self.send_message(app_request) - await asyncio.sleep(1) - attempts = 0 + if attempts % 150 == 0: + self.logger.info( + f"[RustPlus.py] Been waiting 10 minutes for a response for seq {seq}" + ) - if seq not in self.responses: - raise ResponseNotReceivedError("Not Received") + attempts += 1 - response = self.responses.pop(seq) + self.pending_response_events.pop(seq) if response.response.error.error == "rate_limit": logging.getLogger("rustplus.py").warning( @@ -153,27 +137,23 @@ async def get_response( ) # Fully Refill the bucket + bucket = self.ratelimiter.socket_buckets.get(self.server_id) + bucket.current = 0 - self.ratelimiter.socket_buckets.get(self.server_id).current = 0 - - while ( - self.ratelimiter.socket_buckets.get(self.server_id).current - < self.ratelimiter.socket_buckets.get(self.server_id).max - ): + while bucket.current < bucket.max: await asyncio.sleep(1) - self.ratelimiter.socket_buckets.get(self.server_id).refresh() + bucket.refresh() # Reattempt the sending with a full bucket cost = self.ws.get_proto_cost(app_request) while True: - - if self.ratelimiter.can_consume(self.server_id, cost): - self.ratelimiter.consume(self.server_id, cost) + if await self.ratelimiter.can_consume(self.server_id, cost): + await self.ratelimiter.consume(self.server_id, cost) break await asyncio.sleep( - self.ratelimiter.get_estimated_delay_time(self.server_id, cost) + await self.ratelimiter.get_estimated_delay_time(self.server_id, cost) ) await self.send_message(app_request) @@ -185,26 +165,23 @@ async def get_response( return response def handle_subscribing_entity(self, entity_id: int, coroutine) -> None: - if not self.is_open(): self.pending_entity_subscriptions.append((entity_id, coroutine)) return - async def get_entity_info(self: RustRemote, eid): - - await self.api._handle_ratelimit() + async def get_entity_info(remote: RustRemote, eid): + await remote.api._handle_ratelimit() - app_request: AppRequest = self.api._generate_protobuf() + app_request: AppRequest = remote.api._generate_protobuf() app_request.entityId = eid app_request.get_entity_info = AppEmpty() app_request.get_entity_info._serialized_on_wire = True - await self.send_message(app_request) + await remote.send_message(app_request) - return await self.get_response(app_request.seq, app_request, False) + return await remote.get_response(app_request.seq, app_request, False) def entity_event_callback(future_inner: Future) -> None: - entity_info = future_inner.result() if entity_info.response.HasField("error"): @@ -236,12 +213,11 @@ async def subscribe_to_camera( await self.send_message(app_request) if ignore_response: - self.ignored_responses.append(app_request.seq) + await self.add_ignored_response(app_request.seq) return app_request async def create_camera_manager(self, cam_id) -> CameraManager: - if self.camera_manager is not None: if self.camera_manager._cam_id == cam_id: return self.camera_manager @@ -253,3 +229,6 @@ async def create_camera_manager(self, cam_id) -> CameraManager: self.api, cam_id, app_message.response.camera_subscribe_info ) return self.camera_manager + + async def add_ignored_response(self, seq) -> None: + self.ignored_responses.add(seq) diff --git a/rustplus/api/remote/rustws.py b/rustplus/api/remote/rustws.py index a23db39..39425e0 100644 --- a/rustplus/api/remote/rustws.py +++ b/rustplus/api/remote/rustws.py @@ -3,18 +3,19 @@ import logging import time from datetime import datetime -from threading import Thread -from typing import Optional +from typing import Optional, Union import betterproto -import websocket +from asyncio import Task +from websockets.client import connect +from websockets.legacy.client import WebSocketClientProtocol from .camera.structures import RayPacket -from .events import EventLoopManager from .rustplus_proto import AppMessage, AppRequest +from .events import EventHandler from ..structures import RustChatMessage from ...exceptions import ClientNotConnectedError from ...conversation import Conversation -from ...utils import ServerID +from ...utils import ServerID, YieldingEvent CONNECTED = 1 PENDING_CONNECTION = 2 @@ -22,7 +23,7 @@ CLOSED = 3 -class RustWebsocket(websocket.WebSocket): +class RustWebsocket: def __init__( self, server_id: ServerID, @@ -33,9 +34,9 @@ def __init__( on_failure, delay, ): - + self.connection: Union[WebSocketClientProtocol, None] = None + self.task: Union[Task, None] = None self.server_id = server_id - self.thread: Thread = None self.connection_status = CLOSED self.use_proxy = use_proxy self.remote = remote @@ -47,20 +48,15 @@ def __init__( self.on_failure = on_failure self.delay = delay - super().__init__() - async def connect( self, retries=float("inf"), ignore_open_value: bool = False ) -> None: - if ( not self.connection_status == CONNECTED or ignore_open_value ) and not self.remote.is_pending(): - attempts = 0 while True: - if attempts >= retries: raise ConnectionAbortedError("Reached Retry Limit") @@ -81,11 +77,10 @@ async def connect( ) ) address += f"?v={str(self.magic_value)}" - super().connect(address) + self.connection = await connect(address, close_timeout=0) self.connected_time = time.time() break except Exception as exception: - print_error = True if not isinstance(exception, KeyboardInterrupt): @@ -114,16 +109,16 @@ async def connect( self.connection_status = CONNECTED if not ignore_open_value: - self.thread = Thread( - target=self.run, name="[RustPlus.py] WebsocketThread", daemon=True + self.task = asyncio.create_task( + self.run(), name="[RustPlus.py] Websocket Polling Task" ) - self.thread.start() - - def close(self) -> None: + async def close(self) -> None: self.connection_status = CLOSING - self.shutdown() - # super().close() + await self.connection.close() + self.connection = None + self.task.cancel() + self.task = None self.connection_status = CLOSED async def send_message(self, message: AppRequest) -> None: @@ -136,53 +131,46 @@ async def send_message(self, message: AppRequest) -> None: try: if self.use_test_server: - self.send(base64.b64encode(bytes(message)).decode("utf-8")) + await self.connection.send( + base64.b64encode(bytes(message)).decode("utf-8") + ) else: - self.send_binary(bytes(message)) - self.remote.pending_for_response[message.seq] = message - except Exception as e: + await self.connection.send(bytes(message)) + except Exception: + self.logger.exception("An exception occurred whilst sending a message") + while self.remote.is_pending(): await asyncio.sleep(0.5) - return await self.remote.send_message(message) - - def run(self) -> None: + return await self.send_message(message) + async def run(self) -> None: while self.connection_status == CONNECTED: try: - data = self.recv() + data = await self.connection.recv() - self.remote.event_handler.run_proto_event(data, self.server_id) + await EventHandler.run_proto_event(data, self.server_id) app_message = AppMessage() - if self.use_test_server: - app_message.parse(base64.b64decode(data)) - else: - app_message.parse(data) + app_message.parse( + base64.b64decode(data) if self.use_test_server else data + ) except Exception: if self.connection_status == CONNECTED: self.logger.warning( f"{datetime.now().strftime('%d/%m/%Y %H:%M:%S')} [RustPlus.py] Connection interrupted, Retrying" ) - asyncio.run_coroutine_threadsafe( - self.connect(ignore_open_value=True), - EventLoopManager.get_loop(self.server_id), - ).result() + await self.connect(ignore_open_value=True) + continue return try: - del self.remote.pending_for_response[app_message.response.seq] - except KeyError: - pass - - try: - self.handle_message(app_message) - except Exception as e: - self.logger.error(e) - - def handle_message(self, app_message: AppMessage) -> None: + await self.handle_message(app_message) + except Exception: + self.logger.exception("An Error occurred whilst handling the event") + async def handle_message(self, app_message: AppMessage) -> None: if app_message.response.seq in self.remote.ignored_responses: self.remote.ignored_responses.remove(app_message.response.seq) return @@ -195,13 +183,12 @@ def handle_message(self, app_message: AppMessage) -> None: # This means it is a command message = RustChatMessage(app_message.broadcast.team_message.message) - - self.remote.command_handler.run_command(message, prefix) + await self.remote.command_handler.run_command(message, prefix) if self.is_entity_broadcast(app_message): # This means that an entity has changed state - self.remote.event_handler.run_entity_event( + await EventHandler.run_entity_event( app_message.broadcast.entity_changed.entity_id, app_message, self.server_id, @@ -209,13 +196,13 @@ def handle_message(self, app_message: AppMessage) -> None: elif self.is_camera_broadcast(app_message): if self.remote.camera_manager is not None: - self.remote.camera_manager.add_packet( + await self.remote.camera_manager.add_packet( RayPacket(app_message.broadcast.camera_rays) ) elif self.is_team_broadcast(app_message): # This means that the team of the current player has changed - self.remote.event_handler.run_team_event(app_message, self.server_id) + await EventHandler.run_team_event(app_message, self.server_id) elif self.is_message(app_message): # This means that a message has been sent to the team chat @@ -231,36 +218,33 @@ def handle_message(self, app_message: AppMessage) -> None: ) conversation.get_answers().append(message) - conversation.run_coro( - conversation.get_current_prompt().on_response, args=[message] - ) + await conversation.get_current_prompt().on_response(message) if conversation.has_next(): conversation.increment_prompt() prompt = conversation.get_current_prompt() - prompt_string = conversation.run_coro(prompt.prompt, args=[]) - conversation.run_coro( - conversation.send_prompt, args=[prompt_string] - ) + prompt_string = await prompt.prompt() + await conversation.send_prompt(prompt_string) + else: prompt = conversation.get_current_prompt() - prompt_string = conversation.run_coro(prompt.on_finish, args=[]) + prompt_string = await prompt.on_finish() if prompt_string != "": - conversation.run_coro( - conversation.send_prompt, args=[prompt_string] - ) + await conversation.send_prompt(prompt_string) self.remote.conversation_factory.abort_conversation(steam_id) else: self.outgoing_conversation_messages.remove(message) # Conversation API end - self.remote.event_handler.run_chat_event(app_message, self.server_id) + await EventHandler.run_chat_event(app_message, self.server_id) else: # This means that it wasn't sent by the server and is a message from the server in response to an action - - self.remote.responses[app_message.response.seq] = app_message + event: YieldingEvent = self.remote.pending_response_events[ + app_message.response.seq + ] + event.set_with_value(app_message) def get_prefix(self, message: str) -> Optional[str]: if self.remote.use_commands: diff --git a/rustplus/api/remote/server_checker.py b/rustplus/api/remote/server_checker.py index da2bdf1..c4c4aa0 100644 --- a/rustplus/api/remote/server_checker.py +++ b/rustplus/api/remote/server_checker.py @@ -21,4 +21,6 @@ def _check_server(self) -> None: if "does not match your outgoing IP address" not in msg: self.logger.warning(f"Error from server Checker: {msg}") except Exception: - self.logger.exception(f"Unable to test connection to server - {self.ip}:{self.port}") + self.logger.exception( + f"Unable to test connection to server - {self.ip}:{self.port}" + ) diff --git a/rustplus/api/rust_api.py b/rustplus/api/rust_api.py index 41537bb..f901ff4 100644 --- a/rustplus/api/rust_api.py +++ b/rustplus/api/rust_api.py @@ -72,7 +72,6 @@ def __init__( ) async def get_time(self) -> RustTime: - await self._handle_ratelimit() app_request = self._generate_protobuf() @@ -86,7 +85,6 @@ async def get_time(self) -> RustTime: return format_time(response) async def send_team_message(self, message: str) -> None: - await self._handle_ratelimit(2) app_send_message = AppSendMessage() @@ -95,12 +93,11 @@ async def send_team_message(self, message: str) -> None: app_request = self._generate_protobuf() app_request.send_team_message = app_send_message - self.remote.ignored_responses.append(app_request.seq) + await self.remote.add_ignored_response(app_request.seq) await self.remote.send_message(app_request) async def get_info(self) -> RustInfo: - await self._handle_ratelimit() app_request = self._generate_protobuf() @@ -114,7 +111,6 @@ async def get_info(self) -> RustInfo: return RustInfo(response.response.info) async def get_team_chat(self) -> List[RustChatMessage]: - await self._handle_ratelimit() app_request = self._generate_protobuf() @@ -130,7 +126,6 @@ async def get_team_chat(self) -> List[RustChatMessage]: return [RustChatMessage(message) for message in messages] async def get_team_info(self) -> RustTeamInfo: - await self._handle_ratelimit() app_request = self._generate_protobuf() @@ -144,7 +139,6 @@ async def get_team_info(self) -> RustTeamInfo: return RustTeamInfo(app_message.response.team_info) async def get_markers(self) -> List[RustMarker]: - await self._handle_ratelimit() app_request = self._generate_protobuf() @@ -160,7 +154,6 @@ async def get_markers(self) -> List[RustMarker]: ] async def get_raw_map_data(self) -> RustMap: - await self._handle_ratelimit(5) app_request = self._generate_protobuf() @@ -181,7 +174,6 @@ async def get_map( override_images: dict = None, add_grid: bool = False, ) -> Image.Image: - if override_images is None: override_images = {} @@ -295,7 +287,6 @@ async def get_map( return game_map.resize((2000, 2000), Image.ANTIALIAS) async def get_entity_info(self, eid: int = None) -> RustEntityInfo: - await self._handle_ratelimit() if eid is None: @@ -313,7 +304,6 @@ async def get_entity_info(self, eid: int = None) -> RustEntityInfo: return RustEntityInfo(app_message.response.entity_info) async def _update_smart_device(self, eid: int, value: bool) -> None: - await self._handle_ratelimit() entity_value = AppSetEntityValue() @@ -324,26 +314,23 @@ async def _update_smart_device(self, eid: int, value: bool) -> None: app_request.entity_id = eid app_request.set_entity_value = entity_value - self.remote.ignored_responses.append(app_request.seq) + await self.remote.add_ignored_response(app_request.seq) await self.remote.send_message(app_request) async def turn_on_smart_switch(self, eid: int = None) -> None: - if eid is None: raise ValueError("EID cannot be None") await self._update_smart_device(eid, True) async def turn_off_smart_switch(self, eid: int = None) -> None: - if eid is None: raise ValueError("EID cannot be None") await self._update_smart_device(eid, False) async def promote_to_team_leader(self, steam_id: int = None) -> None: - if steam_id is None: raise ValueError("SteamID cannot be None") @@ -355,12 +342,11 @@ async def promote_to_team_leader(self, steam_id: int = None) -> None: app_request = self._generate_protobuf() app_request.promote_to_leader = leader_packet - self.remote.ignored_responses.append(app_request.seq) + await self.remote.add_ignored_response(app_request.seq) await self.remote.send_message(app_request) async def get_current_events(self) -> List[RustMarker]: - return [ marker for marker in (await self.get_markers()) @@ -374,7 +360,6 @@ async def get_current_events(self) -> List[RustMarker]: async def get_contents( self, eid: int = None, combine_stacks: bool = False ) -> RustContents: - if eid is None: raise ValueError("EID cannot be None") diff --git a/rustplus/api/structures/rust_marker.py b/rustplus/api/structures/rust_marker.py index 1a50317..cc1cc1d 100644 --- a/rustplus/api/structures/rust_marker.py +++ b/rustplus/api/structures/rust_marker.py @@ -85,7 +85,6 @@ def __str__(self) -> str: class RustMarker: - PlayerMarker = 1 ExplosionMarker = 2 VendingMachineMarker = 3 diff --git a/rustplus/api/structures/util.py b/rustplus/api/structures/util.py index 6ee683d..1dfac38 100644 --- a/rustplus/api/structures/util.py +++ b/rustplus/api/structures/util.py @@ -3,6 +3,5 @@ @dataclasses.dataclass class Vector: - x: float = 0 y: float = 0 diff --git a/rustplus/commands/command_handler.py b/rustplus/commands/command_handler.py index 2182688..c921d30 100644 --- a/rustplus/commands/command_handler.py +++ b/rustplus/commands/command_handler.py @@ -1,12 +1,11 @@ import asyncio -from asyncio.futures import Future from datetime import datetime from . import Command, CommandTime from ..api.structures import RustChatMessage from ..commands.command_options import CommandOptions from ..commands.command_data import CommandData -from ..api.remote.events import RegisteredListener, EventLoopManager +from ..api.remote.events import RegisteredListener class CommandHandler: @@ -16,22 +15,12 @@ def __init__(self, command_options: CommandOptions, api) -> None: self.api = api def register_command(self, data: CommandData) -> None: - if not asyncio.iscoroutinefunction(data.coro): raise TypeError("The event registered must be a coroutine") self.commands[data.coro.__name__] = data - @staticmethod - def _schedule_event(loop, coro, arg) -> None: - def callback(inner_future: Future): - inner_future.result() - - future: Future = asyncio.run_coroutine_threadsafe(coro(arg), loop) - future.add_done_callback(callback) - - def run_command(self, message: RustChatMessage, prefix) -> None: - + async def run_command(self, message: RustChatMessage, prefix) -> None: if prefix == self.command_options.prefix: command = message.message.split(" ")[0][len(prefix) :] else: @@ -40,16 +29,14 @@ def run_command(self, message: RustChatMessage, prefix) -> None: if command in self.commands: data = self.commands[command] - self._schedule_event( - EventLoopManager.get_loop(self.api.server_id), - data.coro, + await data.coro( Command( message.name, message.steam_id, CommandTime(datetime.utcfromtimestamp(message.time), message.time), command, message.message.split(" ")[1:], - ), + ) ) else: for command_name, data in self.commands.items(): @@ -57,9 +44,7 @@ def run_command(self, message: RustChatMessage, prefix) -> None: # or if it matches the callable function if command in data.aliases or data.callable_func(command): - self._schedule_event( - EventLoopManager.get_loop(self.api.server_id), - data.coro, + data.coro( Command( message.name, message.steam_id, diff --git a/rustplus/commands/command_options.py b/rustplus/commands/command_options.py index 3f56ac0..c12d463 100644 --- a/rustplus/commands/command_options.py +++ b/rustplus/commands/command_options.py @@ -7,7 +7,6 @@ class CommandOptions: def __init__( self, prefix: str = None, overruling_commands: List[str] = None ) -> None: - if prefix is None: raise PrefixNotDefinedError("No prefix") diff --git a/rustplus/conversation/conversation.py b/rustplus/conversation/conversation.py index 0d71544..87042cb 100644 --- a/rustplus/conversation/conversation.py +++ b/rustplus/conversation/conversation.py @@ -1,5 +1,5 @@ import asyncio -from typing import List +from typing import List, Any from .conversation_prompt import ConversationPrompt from ..api.remote.events import EventLoopManager @@ -12,7 +12,6 @@ def __init__( prompts: List[ConversationPrompt] = None, register=None, ) -> None: - if target is None: raise ValueError("target must be specified") self._target = target @@ -54,7 +53,7 @@ async def start(self) -> None: self._register(self._target, self) await self.send_prompt(await self._prompts[0].prompt()) - def run_coro(self, coro, args) -> None: + def run_coro(self, coro, args) -> Any: return asyncio.run_coroutine_threadsafe( coro(*args), EventLoopManager.get_loop(self._api.server_id) ).result() diff --git a/rustplus/conversation/conversation_factory.py b/rustplus/conversation/conversation_factory.py index 1c011b0..33b3b85 100644 --- a/rustplus/conversation/conversation_factory.py +++ b/rustplus/conversation/conversation_factory.py @@ -13,7 +13,6 @@ def __init__(self, api) -> None: self.gc_thread.start() def create_conversation(self, steamid: int) -> Conversation: - if steamid in self.conversations: raise ValueError("Conversation already exists") diff --git a/rustplus/utils/__init__.py b/rustplus/utils/__init__.py index d99f665..36cd83a 100644 --- a/rustplus/utils/__init__.py +++ b/rustplus/utils/__init__.py @@ -2,3 +2,4 @@ from .deprecated import deprecated from .grab_items import translate_id_to_stack from .server_id import ServerID +from .yielding_event import YieldingEvent diff --git a/rustplus/utils/deprecated.py b/rustplus/utils/deprecated.py index 604ebef..75f20f8 100644 --- a/rustplus/utils/deprecated.py +++ b/rustplus/utils/deprecated.py @@ -11,7 +11,6 @@ def deprecated(reason): """ def decorator(func1): - if inspect.isclass(func1): fmt1 = "Call to deprecated class {name} ({reason})." else: diff --git a/rustplus/utils/yielding_event.py b/rustplus/utils/yielding_event.py new file mode 100644 index 0000000..a10d21e --- /dev/null +++ b/rustplus/utils/yielding_event.py @@ -0,0 +1,27 @@ +import asyncio +import contextlib +from typing import Any, Union + + +class YieldingEvent(asyncio.Event): + def __init__(self) -> None: + self.value: Union[Any, None] = None + super().__init__() + + def set_with_value(self, value: Any) -> None: + self.value = value + super().set() + + def clear(self) -> None: + self.value = None + super().clear() + + async def wait(self) -> Any: + await super().wait() + return self.value + + async def event_wait_for(self, timeout) -> Any: + # suppress TimeoutError because we'll return False in case of timeout + with contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(self.wait(), timeout) + return self.value if self.is_set() else None