From d0f971f6812de45cef12e482777adacf4f5bd47e Mon Sep 17 00:00:00 2001
From: Danil Akhtarov <daxartio@gmail.com>
Date: Sun, 26 Nov 2023 19:12:30 +0300
Subject: [PATCH] refactor: use default selector in server and client for
 teamwork

---
 sportorg/common/broker.py            |  23 +--
 sportorg/modules/teamwork/client.py  |  44 ++--
 sportorg/modules/teamwork/command.py |  29 +++
 sportorg/modules/teamwork/server.py  | 294 +++++++++++++--------------
 tests/test_teamwork.py               |   4 +-
 5 files changed, 198 insertions(+), 196 deletions(-)
 create mode 100644 sportorg/modules/teamwork/command.py

diff --git a/sportorg/common/broker.py b/sportorg/common/broker.py
index 8e07365a..83da2e7e 100644
--- a/sportorg/common/broker.py
+++ b/sportorg/common/broker.py
@@ -36,23 +36,10 @@ def subscribe(self, name, call, priority=0):
         return self.add(name, priority).subscribe(call)
 
     def produce(self, name, *args, **kwargs):
-        logging.debug(
-            str(datetime.datetime.now()) + ' Broker.produce started for ' + name
-        )
         if name not in self._consumers:
-            logging.debug(
-                str(datetime.datetime.now())
-                + ' Broker.produce finished (no consumers) for '
-                + name
-            )
             return None
 
         if not isinstance(self._consumers[name], list):
-            logging.debug(
-                str(datetime.datetime.now())
-                + ' Broker.produce finished (no consumers) for '
-                + name
-            )
             return None
 
         result = []
@@ -67,16 +54,12 @@ def produce(self, name, *args, **kwargs):
                     r = method(*args, **kwargs)
                 except AttributeError:
                     self._logger.error(
-                        'Class `{}` does not implement `{}`'.format(
-                            cls.__class__.__name__, method_name
-                        )
+                        'Class `%s` does not implement `%s`',
+                        cls.__class__.__name__,
+                        method_name,
                     )
                     r = None
 
             if r:
                 result.append(r)
-
-        logging.debug(
-            str(datetime.datetime.now()) + ' Broker.produce finished for ' + name
-        )
         return result if len(result) else None
diff --git a/sportorg/modules/teamwork/client.py b/sportorg/modules/teamwork/client.py
index d6657d00..603bdd8a 100644
--- a/sportorg/modules/teamwork/client.py
+++ b/sportorg/modules/teamwork/client.py
@@ -1,20 +1,20 @@
 import queue
-import select
+import selectors
 import socket
 from threading import Event, Thread
-from typing import Tuple
+from typing import Tuple, cast
 
 import orjson
 
+from .command import Command
 from .packet_header import Header, Operations
-from .server import Command
 
 
 class ClientSender:
     def __init__(self, in_queue: queue.Queue):
         self._in_queue = in_queue
 
-    def send(self, conn: socket.socket) -> None:
+    def __call__(self, conn: socket.socket) -> None:
         try:
             while True:
                 cmd = self._in_queue.get_nowait()
@@ -32,7 +32,7 @@ def __init__(self, out_queue: queue.Queue):
         self._hdr = Header()
         self._is_new_pack = True
 
-    def read(self, conn: socket.socket) -> None:
+    def __call__(self, conn: socket.socket) -> None:
         data = conn.recv(self.MSG_SIZE)
         if not data:
             return
@@ -76,39 +76,37 @@ def __init__(
         self._out_queue = out_queue
         self._stop_event = stop_event
         self._logger = logger
-        self._client_started = Event()
+        self._started = Event()
 
-    def join_client(self) -> None:
-        self._client_started.wait()
+    def wait(self) -> None:
+        self._started.wait()
 
     def run(self) -> None:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            selector = selectors.DefaultSelector()
             try:
                 s.connect(self._addr)
                 s.settimeout(5)
                 s.setblocking(False)
+                selector.register(s, selectors.EVENT_READ | selectors.EVENT_WRITE)
                 self._logger.info('Client started')
-                self._client_started.set()
+                self._started.set()
                 sender = ClientSender(self._in_queue)
                 receiver = ClientReceiver(self._out_queue)
-                sockets = [s]
                 while True:
                     if self._stop_event.is_set():
                         break
-                    rread, rwrite, err = select.select(sockets, sockets, [], 1)
-                    if rread:
-                        receiver.read(s)
-                    if rwrite:
-                        sender.send(s)
-
-            except ConnectionRefusedError as e:
-                self._logger.exception(e)
-                self._stop_event.set()
-                return
+                    events = selector.select(timeout=1)
+                    if not events:
+                        continue
+                    for key, mask in events:
+                        if mask & selectors.EVENT_READ:
+                            receiver(cast(socket.socket, key.fileobj))
+                        if mask & selectors.EVENT_WRITE:
+                            sender(cast(socket.socket, key.fileobj))
             except Exception as e:
                 self._logger.exception(e)
                 self._stop_event.set()
-                return
-
-        s.close()
+            finally:
+                selector.close()
         self._logger.info('Client stopped')
diff --git a/sportorg/modules/teamwork/command.py b/sportorg/modules/teamwork/command.py
new file mode 100644
index 00000000..906f9527
--- /dev/null
+++ b/sportorg/modules/teamwork/command.py
@@ -0,0 +1,29 @@
+import socket
+from typing import Optional
+
+import orjson
+
+from .packet_header import Header, ObjectTypes, Operations
+
+
+class Command:
+    def __init__(
+        self,
+        data=None,
+        op=Operations.Update.name,
+        sender: Optional[socket.socket] = None,
+    ):
+        self.data = data
+        self.header = Header(data, op)
+        self.next_cmd_obj_type = ObjectTypes.Unknown.value
+        self._sender = sender
+
+    def __repr__(self) -> str:
+        return str(self.data)
+
+    def is_sender(self, sender: socket.socket) -> bool:
+        return self._sender is sender
+
+    def get_packet(self) -> bytes:
+        pack_data = orjson.dumps(self.data)
+        return self.header.pack_header(len(pack_data)) + pack_data
diff --git a/sportorg/modules/teamwork/server.py b/sportorg/modules/teamwork/server.py
index 9aedd588..99e9528e 100644
--- a/sportorg/modules/teamwork/server.py
+++ b/sportorg/modules/teamwork/server.py
@@ -1,159 +1,143 @@
+import selectors
 import socket
 from queue import Empty, Queue
-from threading import Event, Thread, main_thread
+from threading import Event, Thread
+from typing import List, Tuple, cast
 
 import orjson
 
-from .packet_header import Header, ObjectTypes, Operations
+from .command import Command
+from .packet_header import Header, Operations
 
 
-class Command:
-    def __init__(self, data=None, op=Operations.Update.name, addr=None):
-        self.data = data
-        self.addr = addr
-        self.header = Header(data, op)
-        self.addr_exclude = []
-        self.next_cmd_obj_type = ObjectTypes.Unknown.value
-
-    def __repr__(self) -> str:
-        return str(self.data)
-
-    def exclude(self, addr):
-        self.addr_exclude.append(addr)
-        return self
-
-    def get_packet(self):
-        pack_data = orjson.dumps(self.data)
-        return self.header.pack_header(len(pack_data)) + pack_data
-
-
-class Connect:
-    def __init__(self, conn, addr):
-        self.conn = conn
-        self.addr = addr
-        self._alive = Event()
-
-    def died(self):
-        self._alive.set()
+class ServerReceiver:
+    MSG_SIZE = 1024
 
-    def is_alive(self):
-        return not self._alive.is_set()
-
-
-class ServerReceiverThread(Thread):
-    def __init__(self, conn, in_queue, out_queue, stop_event, logger):
-        super().__init__(daemon=True)
-        self.connect = conn
+    def __init__(
+        self,
+        selector: selectors.BaseSelector,
+        in_queue: Queue,
+        out_queue: Queue,
+        logger,
+    ):
+        self._selector = selector
         self._in_queue = in_queue
         self._out_queue = out_queue
-        self._stop_event = stop_event
         self._logger = logger
-
-    def run(self):
-        with self.connect.conn:
-            self._logger.debug('Server receiver started')
-            self._logger.info('Connected by {}'.format(self.connect.addr))
-            full_data = b''
-            self.connect.conn.settimeout(5)
-            hdr = Header()
-            is_new_pack = True
+        self._full_data = b''
+        self._hdr = Header()
+        self._is_new_pack = True
+
+    def __call__(self, sock: socket.socket) -> None:
+        try:
+            data = sock.recv(self.MSG_SIZE)
+            if not data:
+                self._selector.unregister(sock)
+                sock.close()
+                return
+        except OSError as e:
+            self._logger.error(str(e))
+            self._selector.unregister(sock)
+            sock.close()
+            return
+        try:
+            self._full_data += data
             while True:
-                try:
-                    data = self.connect.conn.recv(1024)
-                    if not data:
+                # getting Header
+                if self._is_new_pack:
+                    if len(self._full_data) >= self._hdr.header_size:
+                        self._hdr.unpack_header(
+                            self._full_data[: self._hdr.header_size]
+                        )
+                        self._full_data = self._full_data[self._hdr.header_size :]
+                        self._is_new_pack = False
+                    else:
                         break
-                    full_data += data
-                    while True:
-                        # getting Header
-                        if is_new_pack:
-                            if len(full_data) >= hdr.header_size:
-                                hdr.unpack_header(full_data[: hdr.header_size])
-                                full_data = full_data[hdr.header_size :]
-                                is_new_pack = False
-                            else:
-                                break
-                        # Getting JSON data
-                        else:
-                            if len(full_data) >= hdr.size:
-                                command = Command(
-                                    orjson.loads(full_data[: hdr.size].decode()),
-                                    Operations(hdr.op_type).name,
-                                    self.connect.addr,
-                                )
-                                command.exclude(self.connect.addr)
-                                self._out_queue.put(command)  # for local
-                                self._in_queue.put(command)  # for child
-                                full_data = full_data[hdr.size :]
-                                is_new_pack = True
-                            else:
-                                break
-
-                except socket.timeout:
-                    if not main_thread().is_alive() or self._stop_event.is_set():
+                # Getting JSON data
+                else:
+                    if len(self._full_data) >= self._hdr.size:
+                        command = Command(
+                            orjson.loads(self._full_data[: self._hdr.size].decode()),
+                            Operations(self._hdr.op_type).name,
+                            sender=sock,
+                        )
+                        self._out_queue.put(command)  # for local
+                        self._in_queue.put(command)  # for child
+                        self._full_data = self._full_data[self._hdr.size :]
+                        self._is_new_pack = True
+                    else:
                         break
-                except ConnectionResetError as e:
-                    self._logger.error(str(e))
-                    break
-                except Exception as e:
-                    self._logger.error(str(e))
-                    break
-        self.connect.conn.close()
-        self.connect.died()
-        self._logger.info('Disconnect {}'.format(self.connect.addr))
+        except Exception as e:
+            self._logger.error(str(e))
 
 
-class ServerSenderThread(Thread):
-    def __init__(self, in_queue, connections_queue, stop_event, logger):
-        super().__init__(daemon=True)
-        self.setName(self.__class__.__name__)
-        self._connections_queue = connections_queue
-        self._connections = []
+class ServerSender:
+    def __init__(self, selector: selectors.BaseSelector, in_queue: Queue, logger):
+        self._selector = selector
         self._in_queue = in_queue
-        self._stop_event = stop_event
         self._logger = logger
 
-    def run(self):
-        self._logger.debug('Server sender start')
-        while True:
-            try:
-                command = self._in_queue.get(timeout=5)
-                for connect in self._connections:
+    def __call__(self, socks: List[socket.socket]) -> None:
+        try:
+            while True:
+                command = self._in_queue.get_nowait()
+                for sock in socks:
+                    if command.is_sender(sock):
+                        continue
                     try:
-                        if (
-                            connect.addr not in command.addr_exclude
-                            and connect.is_alive()
-                        ):
-                            connect.conn.sendall(command.get_packet())
-                    except ConnectionResetError as e:
-                        self._logger.error(str(e))
-                        connect.died()
+                        sock.sendall(command.get_packet())
                     except OSError as e:
                         self._logger.error(str(e))
-                        connect.died()
-            except Empty:
-                while not self._connections_queue.empty():
-                    self._connections.append(self._connections_queue.get())
-                if not main_thread().is_alive() or self._stop_event.is_set():
-                    break
-            except Exception as e:
-                self._logger.error(str(e))
-        self._logger.debug('Server sender shutdown')
-        self._stop_event.set()
+                        self._selector.unregister(sock)
+                        sock.close()
+        except Empty:
+            pass
+
+
+class ConnectionAcceptor:
+    def __init__(
+        self,
+        selector: selectors.BaseSelector,
+        in_queue: Queue,
+        out_queue: Queue,
+        logger,
+    ):
+        self._selector = selector
+        self._logger = logger
+        self._in_queue = in_queue
+        self._out_queue = out_queue
+
+    def __call__(self, sock: socket.socket) -> None:
+        conn, addr = sock.accept()
+        self._selector.register(
+            conn,
+            selectors.EVENT_READ | selectors.EVENT_WRITE,
+            data=ServerReceiver(
+                self._selector, self._in_queue, self._out_queue, self._logger
+            ),
+        )
 
 
 class ServerThread(Thread):
-    def __init__(self, addr, in_queue, out_queue, stop_event, logger):
+    def __init__(
+        self,
+        addr: Tuple[str, int],
+        in_queue: Queue,
+        out_queue: Queue,
+        stop_event: Event,
+        logger,
+    ):
         super().__init__(daemon=True)
-        self.setName(self.__class__.__name__)
+        self.setName('Teamwork Server')
         self.addr = addr
         self._in_queue = in_queue
         self._out_queue = out_queue
         self._stop_event = stop_event
         self._logger = logger
-        self._server_started = Event()
+        self._started = Event()
 
-    def join_server(self) -> None:
-        self._server_started.wait()
+    def wait(self) -> None:
+        self._started.wait()
 
     def run(self) -> None:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -162,42 +146,50 @@ def run(self) -> None:
             try:
                 s.bind(self.addr)
             except Exception as e:
+                self._logger.error('Server start error')
                 self._logger.debug(str(e))
+                self._stop_event.set()
                 return
             s.listen(1)
             s.settimeout(5)
+            s.setblocking(False)
+            selector = selectors.DefaultSelector()
+            selector.register(
+                s,
+                selectors.EVENT_READ,
+                data=ConnectionAcceptor(
+                    selector, self._in_queue, self._out_queue, self._logger
+                ),
+            )
 
             self._logger.info('Server started')
 
-            conns_queue = Queue()  # type: ignore
-            sender = ServerSenderThread(
-                self._in_queue, conns_queue, self._stop_event, self._logger
-            )
-            sender.start()
-            self._server_started.set()
-
-            connections = []
+            sender = ServerSender(selector, self._in_queue, self._logger)
+            self._started.set()
 
             while True:
                 try:
-                    conn, addr = s.accept()
-                    connect = Connect(conn, addr)
-                    conns_queue.put(connect)
-                    srt = ServerReceiverThread(
-                        connect,
-                        self._in_queue,
-                        self._out_queue,
-                        self._stop_event,
-                        self._logger,
-                    )
-                    srt.start()
-                    connections.append(srt)
-                except socket.timeout:
-                    if not main_thread().is_alive() or self._stop_event.is_set():
+                    if self._stop_event.is_set():
                         break
+                    events = selector.select(timeout=1)
+                    if not events:
+                        continue
+
+                    for key, mask in events:
+                        if mask & selectors.EVENT_READ:
+                            callback = key.data
+                            callback(key.fileobj)
+
+                    ready_to_write = [
+                        key.fileobj
+                        for key, mask in events
+                        if mask & selectors.EVENT_WRITE
+                    ]
+                    if ready_to_write:
+                        sender(cast(List[socket.socket], ready_to_write))
+
                 except Exception as e:
-                    self._logger.error(str(e))
-            sender.join()
-            for srt in connections:
-                srt.join()
-            self._logger.info('Server stopped')
+                    self._logger.exception(str(e))
+
+            selector.close()
+        self._logger.info('Server stopped')
diff --git a/tests/test_teamwork.py b/tests/test_teamwork.py
index 51d897b9..d8cdf417 100644
--- a/tests/test_teamwork.py
+++ b/tests/test_teamwork.py
@@ -13,14 +13,14 @@ def test_teamwork():
     event = Event()
     server = ServerThread(('0.0.0.0', 50010), in_queue, out_queue, event, logging.root)
     server.start()
-    server.join_server()
+    server.wait()
     client_in_queue = Queue()
     client_out_queue = Queue()
     client = ClientThread(
         ('localhost', 50010), client_in_queue, client_out_queue, event, logging.root
     )
     client.start()
-    client.join_client()
+    client.wait()
     time.sleep(5)
 
     in_queue.put(