From d0f971f6812de45cef12e482777adacf4f5bd47e Mon Sep 17 00:00:00 2001 From: Danil Akhtarov <daxartio@gmail.com> Date: Sun, 26 Nov 2023 19:12:30 +0300 Subject: [PATCH] refactor: use default selector in server and client for teamwork --- sportorg/common/broker.py | 23 +-- sportorg/modules/teamwork/client.py | 44 ++-- sportorg/modules/teamwork/command.py | 29 +++ sportorg/modules/teamwork/server.py | 294 +++++++++++++-------------- tests/test_teamwork.py | 4 +- 5 files changed, 198 insertions(+), 196 deletions(-) create mode 100644 sportorg/modules/teamwork/command.py diff --git a/sportorg/common/broker.py b/sportorg/common/broker.py index 8e07365a..83da2e7e 100644 --- a/sportorg/common/broker.py +++ b/sportorg/common/broker.py @@ -36,23 +36,10 @@ def subscribe(self, name, call, priority=0): return self.add(name, priority).subscribe(call) def produce(self, name, *args, **kwargs): - logging.debug( - str(datetime.datetime.now()) + ' Broker.produce started for ' + name - ) if name not in self._consumers: - logging.debug( - str(datetime.datetime.now()) - + ' Broker.produce finished (no consumers) for ' - + name - ) return None if not isinstance(self._consumers[name], list): - logging.debug( - str(datetime.datetime.now()) - + ' Broker.produce finished (no consumers) for ' - + name - ) return None result = [] @@ -67,16 +54,12 @@ def produce(self, name, *args, **kwargs): r = method(*args, **kwargs) except AttributeError: self._logger.error( - 'Class `{}` does not implement `{}`'.format( - cls.__class__.__name__, method_name - ) + 'Class `%s` does not implement `%s`', + cls.__class__.__name__, + method_name, ) r = None if r: result.append(r) - - logging.debug( - str(datetime.datetime.now()) + ' Broker.produce finished for ' + name - ) return result if len(result) else None diff --git a/sportorg/modules/teamwork/client.py b/sportorg/modules/teamwork/client.py index d6657d00..603bdd8a 100644 --- a/sportorg/modules/teamwork/client.py +++ b/sportorg/modules/teamwork/client.py @@ -1,20 +1,20 @@ import queue -import select +import selectors import socket from threading import Event, Thread -from typing import Tuple +from typing import Tuple, cast import orjson +from .command import Command from .packet_header import Header, Operations -from .server import Command class ClientSender: def __init__(self, in_queue: queue.Queue): self._in_queue = in_queue - def send(self, conn: socket.socket) -> None: + def __call__(self, conn: socket.socket) -> None: try: while True: cmd = self._in_queue.get_nowait() @@ -32,7 +32,7 @@ def __init__(self, out_queue: queue.Queue): self._hdr = Header() self._is_new_pack = True - def read(self, conn: socket.socket) -> None: + def __call__(self, conn: socket.socket) -> None: data = conn.recv(self.MSG_SIZE) if not data: return @@ -76,39 +76,37 @@ def __init__( self._out_queue = out_queue self._stop_event = stop_event self._logger = logger - self._client_started = Event() + self._started = Event() - def join_client(self) -> None: - self._client_started.wait() + def wait(self) -> None: + self._started.wait() def run(self) -> None: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + selector = selectors.DefaultSelector() try: s.connect(self._addr) s.settimeout(5) s.setblocking(False) + selector.register(s, selectors.EVENT_READ | selectors.EVENT_WRITE) self._logger.info('Client started') - self._client_started.set() + self._started.set() sender = ClientSender(self._in_queue) receiver = ClientReceiver(self._out_queue) - sockets = [s] while True: if self._stop_event.is_set(): break - rread, rwrite, err = select.select(sockets, sockets, [], 1) - if rread: - receiver.read(s) - if rwrite: - sender.send(s) - - except ConnectionRefusedError as e: - self._logger.exception(e) - self._stop_event.set() - return + events = selector.select(timeout=1) + if not events: + continue + for key, mask in events: + if mask & selectors.EVENT_READ: + receiver(cast(socket.socket, key.fileobj)) + if mask & selectors.EVENT_WRITE: + sender(cast(socket.socket, key.fileobj)) except Exception as e: self._logger.exception(e) self._stop_event.set() - return - - s.close() + finally: + selector.close() self._logger.info('Client stopped') diff --git a/sportorg/modules/teamwork/command.py b/sportorg/modules/teamwork/command.py new file mode 100644 index 00000000..906f9527 --- /dev/null +++ b/sportorg/modules/teamwork/command.py @@ -0,0 +1,29 @@ +import socket +from typing import Optional + +import orjson + +from .packet_header import Header, ObjectTypes, Operations + + +class Command: + def __init__( + self, + data=None, + op=Operations.Update.name, + sender: Optional[socket.socket] = None, + ): + self.data = data + self.header = Header(data, op) + self.next_cmd_obj_type = ObjectTypes.Unknown.value + self._sender = sender + + def __repr__(self) -> str: + return str(self.data) + + def is_sender(self, sender: socket.socket) -> bool: + return self._sender is sender + + def get_packet(self) -> bytes: + pack_data = orjson.dumps(self.data) + return self.header.pack_header(len(pack_data)) + pack_data diff --git a/sportorg/modules/teamwork/server.py b/sportorg/modules/teamwork/server.py index 9aedd588..99e9528e 100644 --- a/sportorg/modules/teamwork/server.py +++ b/sportorg/modules/teamwork/server.py @@ -1,159 +1,143 @@ +import selectors import socket from queue import Empty, Queue -from threading import Event, Thread, main_thread +from threading import Event, Thread +from typing import List, Tuple, cast import orjson -from .packet_header import Header, ObjectTypes, Operations +from .command import Command +from .packet_header import Header, Operations -class Command: - def __init__(self, data=None, op=Operations.Update.name, addr=None): - self.data = data - self.addr = addr - self.header = Header(data, op) - self.addr_exclude = [] - self.next_cmd_obj_type = ObjectTypes.Unknown.value - - def __repr__(self) -> str: - return str(self.data) - - def exclude(self, addr): - self.addr_exclude.append(addr) - return self - - def get_packet(self): - pack_data = orjson.dumps(self.data) - return self.header.pack_header(len(pack_data)) + pack_data - - -class Connect: - def __init__(self, conn, addr): - self.conn = conn - self.addr = addr - self._alive = Event() - - def died(self): - self._alive.set() +class ServerReceiver: + MSG_SIZE = 1024 - def is_alive(self): - return not self._alive.is_set() - - -class ServerReceiverThread(Thread): - def __init__(self, conn, in_queue, out_queue, stop_event, logger): - super().__init__(daemon=True) - self.connect = conn + def __init__( + self, + selector: selectors.BaseSelector, + in_queue: Queue, + out_queue: Queue, + logger, + ): + self._selector = selector self._in_queue = in_queue self._out_queue = out_queue - self._stop_event = stop_event self._logger = logger - - def run(self): - with self.connect.conn: - self._logger.debug('Server receiver started') - self._logger.info('Connected by {}'.format(self.connect.addr)) - full_data = b'' - self.connect.conn.settimeout(5) - hdr = Header() - is_new_pack = True + self._full_data = b'' + self._hdr = Header() + self._is_new_pack = True + + def __call__(self, sock: socket.socket) -> None: + try: + data = sock.recv(self.MSG_SIZE) + if not data: + self._selector.unregister(sock) + sock.close() + return + except OSError as e: + self._logger.error(str(e)) + self._selector.unregister(sock) + sock.close() + return + try: + self._full_data += data while True: - try: - data = self.connect.conn.recv(1024) - if not data: + # getting Header + if self._is_new_pack: + if len(self._full_data) >= self._hdr.header_size: + self._hdr.unpack_header( + self._full_data[: self._hdr.header_size] + ) + self._full_data = self._full_data[self._hdr.header_size :] + self._is_new_pack = False + else: break - full_data += data - while True: - # getting Header - if is_new_pack: - if len(full_data) >= hdr.header_size: - hdr.unpack_header(full_data[: hdr.header_size]) - full_data = full_data[hdr.header_size :] - is_new_pack = False - else: - break - # Getting JSON data - else: - if len(full_data) >= hdr.size: - command = Command( - orjson.loads(full_data[: hdr.size].decode()), - Operations(hdr.op_type).name, - self.connect.addr, - ) - command.exclude(self.connect.addr) - self._out_queue.put(command) # for local - self._in_queue.put(command) # for child - full_data = full_data[hdr.size :] - is_new_pack = True - else: - break - - except socket.timeout: - if not main_thread().is_alive() or self._stop_event.is_set(): + # Getting JSON data + else: + if len(self._full_data) >= self._hdr.size: + command = Command( + orjson.loads(self._full_data[: self._hdr.size].decode()), + Operations(self._hdr.op_type).name, + sender=sock, + ) + self._out_queue.put(command) # for local + self._in_queue.put(command) # for child + self._full_data = self._full_data[self._hdr.size :] + self._is_new_pack = True + else: break - except ConnectionResetError as e: - self._logger.error(str(e)) - break - except Exception as e: - self._logger.error(str(e)) - break - self.connect.conn.close() - self.connect.died() - self._logger.info('Disconnect {}'.format(self.connect.addr)) + except Exception as e: + self._logger.error(str(e)) -class ServerSenderThread(Thread): - def __init__(self, in_queue, connections_queue, stop_event, logger): - super().__init__(daemon=True) - self.setName(self.__class__.__name__) - self._connections_queue = connections_queue - self._connections = [] +class ServerSender: + def __init__(self, selector: selectors.BaseSelector, in_queue: Queue, logger): + self._selector = selector self._in_queue = in_queue - self._stop_event = stop_event self._logger = logger - def run(self): - self._logger.debug('Server sender start') - while True: - try: - command = self._in_queue.get(timeout=5) - for connect in self._connections: + def __call__(self, socks: List[socket.socket]) -> None: + try: + while True: + command = self._in_queue.get_nowait() + for sock in socks: + if command.is_sender(sock): + continue try: - if ( - connect.addr not in command.addr_exclude - and connect.is_alive() - ): - connect.conn.sendall(command.get_packet()) - except ConnectionResetError as e: - self._logger.error(str(e)) - connect.died() + sock.sendall(command.get_packet()) except OSError as e: self._logger.error(str(e)) - connect.died() - except Empty: - while not self._connections_queue.empty(): - self._connections.append(self._connections_queue.get()) - if not main_thread().is_alive() or self._stop_event.is_set(): - break - except Exception as e: - self._logger.error(str(e)) - self._logger.debug('Server sender shutdown') - self._stop_event.set() + self._selector.unregister(sock) + sock.close() + except Empty: + pass + + +class ConnectionAcceptor: + def __init__( + self, + selector: selectors.BaseSelector, + in_queue: Queue, + out_queue: Queue, + logger, + ): + self._selector = selector + self._logger = logger + self._in_queue = in_queue + self._out_queue = out_queue + + def __call__(self, sock: socket.socket) -> None: + conn, addr = sock.accept() + self._selector.register( + conn, + selectors.EVENT_READ | selectors.EVENT_WRITE, + data=ServerReceiver( + self._selector, self._in_queue, self._out_queue, self._logger + ), + ) class ServerThread(Thread): - def __init__(self, addr, in_queue, out_queue, stop_event, logger): + def __init__( + self, + addr: Tuple[str, int], + in_queue: Queue, + out_queue: Queue, + stop_event: Event, + logger, + ): super().__init__(daemon=True) - self.setName(self.__class__.__name__) + self.setName('Teamwork Server') self.addr = addr self._in_queue = in_queue self._out_queue = out_queue self._stop_event = stop_event self._logger = logger - self._server_started = Event() + self._started = Event() - def join_server(self) -> None: - self._server_started.wait() + def wait(self) -> None: + self._started.wait() def run(self) -> None: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -162,42 +146,50 @@ def run(self) -> None: try: s.bind(self.addr) except Exception as e: + self._logger.error('Server start error') self._logger.debug(str(e)) + self._stop_event.set() return s.listen(1) s.settimeout(5) + s.setblocking(False) + selector = selectors.DefaultSelector() + selector.register( + s, + selectors.EVENT_READ, + data=ConnectionAcceptor( + selector, self._in_queue, self._out_queue, self._logger + ), + ) self._logger.info('Server started') - conns_queue = Queue() # type: ignore - sender = ServerSenderThread( - self._in_queue, conns_queue, self._stop_event, self._logger - ) - sender.start() - self._server_started.set() - - connections = [] + sender = ServerSender(selector, self._in_queue, self._logger) + self._started.set() while True: try: - conn, addr = s.accept() - connect = Connect(conn, addr) - conns_queue.put(connect) - srt = ServerReceiverThread( - connect, - self._in_queue, - self._out_queue, - self._stop_event, - self._logger, - ) - srt.start() - connections.append(srt) - except socket.timeout: - if not main_thread().is_alive() or self._stop_event.is_set(): + if self._stop_event.is_set(): break + events = selector.select(timeout=1) + if not events: + continue + + for key, mask in events: + if mask & selectors.EVENT_READ: + callback = key.data + callback(key.fileobj) + + ready_to_write = [ + key.fileobj + for key, mask in events + if mask & selectors.EVENT_WRITE + ] + if ready_to_write: + sender(cast(List[socket.socket], ready_to_write)) + except Exception as e: - self._logger.error(str(e)) - sender.join() - for srt in connections: - srt.join() - self._logger.info('Server stopped') + self._logger.exception(str(e)) + + selector.close() + self._logger.info('Server stopped') diff --git a/tests/test_teamwork.py b/tests/test_teamwork.py index 51d897b9..d8cdf417 100644 --- a/tests/test_teamwork.py +++ b/tests/test_teamwork.py @@ -13,14 +13,14 @@ def test_teamwork(): event = Event() server = ServerThread(('0.0.0.0', 50010), in_queue, out_queue, event, logging.root) server.start() - server.join_server() + server.wait() client_in_queue = Queue() client_out_queue = Queue() client = ClientThread( ('localhost', 50010), client_in_queue, client_out_queue, event, logging.root ) client.start() - client.join_client() + client.wait() time.sleep(5) in_queue.put(