diff --git a/pychromecast/socket_client.py b/pychromecast/socket_client.py index 57655ac10..e7ea12585 100644 --- a/pychromecast/socket_client.py +++ b/pychromecast/socket_client.py @@ -10,31 +10,31 @@ from __future__ import annotations import abc -from dataclasses import dataclass import errno import json import logging -import select +import selectors import socket import ssl import threading import time from collections import defaultdict +from dataclasses import dataclass from struct import pack, unpack import zeroconf -from .controllers import CallbackType, BaseController +from .const import MESSAGE_TYPE, REQUEST_ID, SESSION_ID +from .controllers import BaseController, CallbackType from .controllers.media import MediaController from .controllers.receiver import CastStatus, CastStatusListener, ReceiverController -from .const import MESSAGE_TYPE, REQUEST_ID, SESSION_ID from .dial import get_host_from_service from .error import ( ChromecastConnectionError, ControllerNotRegistered, - UnsupportedNamespace, NotConnected, PyChromecastStopped, + UnsupportedNamespace, ) # pylint: disable-next=no-name-in-module @@ -64,8 +64,6 @@ CONNECTION_STATUS_FAILED_RESOLVE = "FAILED_RESOLVE" # The socket connection was lost and needs to be retried CONNECTION_STATUS_LOST = "LOST" -# Check for select poll method -SELECT_HAS_POLL = hasattr(select, "poll") HB_PING_TIME = 10 HB_PONG_TIME = 10 @@ -213,6 +211,11 @@ def __init__( self.connecting = True self.first_connection = True self.socket: socket.socket | ssl.SSLSocket | None = None + self.selector = selectors.DefaultSelector() + self.wakeup_selector_key = self.selector.register( + self.socketpair[0], selectors.EVENT_READ + ) + self.remote_selector_key: selectors.SelectorKey | None = None # dict mapping namespace on Controller objects self._handlers: dict[str, set[BaseController]] = defaultdict(set) @@ -236,8 +239,10 @@ def initialize_connection( # pylint:disable=too-many-statements, too-many-branc tries = self.tries if self.socket is not None: + self.selector.unregister(self.socket) self.socket.close() self.socket = None + self.remote_selector_key = None # Make sure nobody is blocking. for callback_function in self._request_callbacks.values(): @@ -286,10 +291,15 @@ def mdns_backoff( try: if self.socket is not None: # If we retry connecting, we need to clean up the socket again - self.socket.close() # type: ignore[unreachable] + self.selector.unregister(self.socket) # type: ignore[unreachable] + self.socket.close() self.socket = None + self.remote_selector_key = None self.socket = new_socket() + self.remote_selector_key = self.selector.register( + self.socket, selectors.EVENT_READ + ) self.socket.settimeout(self.timeout) self._report_connection_status( ConnectionStatus( @@ -557,20 +567,8 @@ def _run_once(self) -> int: assert self.socket is not None # poll the socket, as well as the socketpair to allow us to be interrupted - rlist = [self.socket, self.socketpair[0]] try: - if SELECT_HAS_POLL is True: - # Map file descriptors to socket objects because select.select does not support fd > 1024 - # https://stackoverflow.com/questions/14250751/how-to-increase-filedescriptors-range-in-python-select - fd_to_socket = {rlist_item.fileno(): rlist_item for rlist_item in rlist} - - poll_obj = select.poll() - for poll_fd in rlist: - poll_obj.register(poll_fd, select.POLLIN) - poll_result = poll_obj.poll() - can_read = [fd_to_socket[fd] for fd, _status in poll_result] - else: - can_read, _, _ = select.select(rlist, [], [], None) + ready = self.selector.select() except (ValueError, OSError) as exc: self.logger.error( "[%s(%s):%s] Error in select call: %s", @@ -582,9 +580,10 @@ def _run_once(self) -> int: self._force_recon = True return 0 + can_read = {key for key, _ in ready} # read message from chromecast message = None - if self.socket in can_read and not self._force_recon: + if self.remote_selector_key in can_read and not self._force_recon: try: message = self._read_message() except InterruptLoop as exc: @@ -620,7 +619,7 @@ def _run_once(self) -> int: else: data = _dict_from_message_payload(message) - if self.socketpair[0] in can_read: + if self.wakeup_selector_key in can_read: # Clear the socket's buffer self.socketpair[0].recv(128) @@ -765,6 +764,7 @@ def _cleanup(self) -> None: self.socketpair[0].close() self.socketpair[1].close() + self.selector.close() self.connecting = True