Skip to content

Commit

Permalink
added more type hints
Browse files Browse the repository at this point in the history
* added __future__ annotations for < 3.10 compatibility
* added as much type hints as possible
  • Loading branch information
rob miller committed Apr 30, 2022
1 parent efddec9 commit d8b1144
Show file tree
Hide file tree
Showing 24 changed files with 463 additions and 254 deletions.
21 changes: 14 additions & 7 deletions amqtt/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from websockets import WebSocketCommonProtocol
from websockets import ConnectionClosed
from asyncio import StreamReader, StreamWriter
from typing import Tuple, Any
import logging


Expand All @@ -23,11 +24,13 @@ async def read(self, n=-1) -> bytes:
return all read bytes. If the EOF was received and the internal buffer is
empty, return an empty bytes object. :return: packet read as bytes data
"""
raise NotImplementedError()

def feed_eof(self):
"""
Acknowledge EOF
"""
raise NotImplementedError()


class WriterAdapter:
Expand All @@ -38,25 +41,29 @@ class WriterAdapter:
the protocol used
"""

def write(self, data):
def write(self, data: bytes) -> None:
"""
write some data to the protocol layer
"""
raise NotImplementedError()

async def drain(self):
async def drain(self) -> None:
"""
Let the write buffer of the underlying transport a chance to be flushed.
"""
raise NotImplementedError()

def get_peer_info(self):
def get_peer_info(self) -> Tuple[Any, Any]:
"""
Return peer socket info (remote address and remote port as tuple
"""
raise NotImplementedError()

async def close(self):
async def close(self) -> None:
"""
Close the protocol connection
"""
raise NotImplementedError()


class WebSocketsReader(ReaderAdapter):
Expand Down Expand Up @@ -103,7 +110,7 @@ def __init__(self, protocol: WebSocketCommonProtocol):
self._protocol = protocol
self._stream = io.BytesIO(b"")

def write(self, data):
def write(self, data: bytes):
"""
write some data to the protocol layer
"""
Expand Down Expand Up @@ -161,7 +168,7 @@ def __init__(self, writer: StreamWriter):
self._writer = writer
self.is_closed = False # StreamWriter has no test for closed...we use our own

def write(self, data):
def write(self, data: bytes):
if not self.is_closed:
self._writer.write(data)

Expand Down Expand Up @@ -208,7 +215,7 @@ class BufferWriter(WriterAdapter):
def __init__(self, buffer=b""):
self._stream = io.BytesIO(buffer)

def write(self, data):
def write(self, data: bytes):
"""
write some data to the protocol layer
"""
Expand Down
115 changes: 87 additions & 28 deletions amqtt/broker.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
from typing import Optional
from asyncio.events import AbstractEventLoop
from asyncio.locks import Semaphore
from asyncio.streams import StreamReader, StreamWriter
from typing import Optional, Union, Tuple, Dict, Any, List
import logging
import ssl
import websockets
import asyncio
import re
from re import Match
from asyncio import CancelledError
from collections import deque
from enum import Enum

from functools import partial
from transitions import Machine, MachineError
from amqtt.session import Session
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.mqtt.protocol.broker_handler import ProtocolHandler, BrokerProtocolHandler
from amqtt.errors import AMQTTException, MQTTException
from amqtt.utils import format_client_message, gen_client_id
from amqtt.adapters import (
Expand Down Expand Up @@ -57,15 +61,30 @@ class RetainedApplicationMessage:

__slots__ = ("source_session", "topic", "data", "qos")

def __init__(self, source_session, topic, data, qos=None):
def __init__(
self,
source_session: Optional[Session],
topic: str,
data: bytes,
qos: int = None,
):
self.source_session = source_session
self.topic = topic
self.data = data
self.qos = qos


class Server:
def __init__(self, listener_name, server_instance, max_connections=-1):
semaphore: Optional[Semaphore] = None
# TODO instance is asyncio.create_task
instance: Any
listener_name: str
conn_count: int = 0
max_connections: int

def __init__(
self, listener_name: str, server_instance: Any, max_connections: int = -1
):
self.logger = logging.getLogger(__name__)
self.instance = server_instance
self.conn_count = 0
Expand All @@ -74,8 +93,6 @@ def __init__(self, listener_name, server_instance, max_connections=-1):
self.max_connections = max_connections
if self.max_connections > 0:
self.semaphore = asyncio.Semaphore(self.max_connections)
else:
self.semaphore = None

async def acquire_connection(self):
if self.semaphore:
Expand Down Expand Up @@ -116,18 +133,25 @@ async def close_instance(self):
class BrokerContext(BaseContext):
"""
BrokerContext is used as the context passed to plugins interacting with the broker.
It act as an adapter to broker services from plugins developed for HBMQTT broker
It act as an adapter to broker services from plugins developed for aMQTT broker
"""

_broker_instance: "Broker"
config: Optional[dict]

def __init__(self, broker: "Broker") -> None:
super().__init__()
self.config = None
self._broker_instance = broker

async def broadcast_message(self, topic, data, qos=None):
async def broadcast_message(
self, topic: str, data: Union[bytes, bytearray], qos: Optional[int] = None
) -> None:
await self._broker_instance.internal_message_broadcast(topic, data, qos)

def retain_message(self, topic_name, data, qos=None):
def retain_message(
self, topic_name: str, data: Union[bytes, bytearray], qos: Optional[int] = None
):
self._broker_instance.retain_message(None, topic_name, data, qos)

@property
Expand Down Expand Up @@ -165,7 +189,23 @@ class Broker:
"stopped",
]

def __init__(self, config=None, loop=None, plugin_namespace=None):
# TODO: Client id str is Optional in session.py, however should be mandatory
_sessions: Dict[str, Tuple[Session, BrokerProtocolHandler]]
_subscriptions: Dict[str, List[Tuple[Session, int]]]
_retained_messages: Dict[str, RetainedApplicationMessage]
_servers: Dict[str, Server] = dict()

# TODO convert to a class or namedtuple?
# {"session": session, "topic": topic, "data": data}
_broadcast_queue: asyncio.Queue[Dict[str, Any]]
_config: Optional[Dict[Any, Any]]

def __init__(
self,
config: Optional[Dict[Any, Any]] = None,
loop: AbstractEventLoop = None,
plugin_namespace: str = None,
):
self.logger = logging.getLogger(__name__)
self.config = _defaults
if config is not None:
Expand All @@ -179,6 +219,7 @@ def __init__(self, config=None, loop=None, plugin_namespace=None):

self._servers = dict()
self._init_states()

self._sessions = dict()
self._subscriptions = dict()
self._retained_messages = dict()
Expand Down Expand Up @@ -381,21 +422,25 @@ async def shutdown(self):
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
self.transitions.stopping_success()

async def internal_message_broadcast(self, topic, data, qos=None):
async def internal_message_broadcast(
self, topic: str, data: bytes, qos: Optional[int] = None
):
return await self._broadcast_message(None, topic, data)

async def ws_connected(self, websocket, uri, listener_name):
async def ws_connected(self, websocket, uri, listener_name: str) -> None:
await self.client_connected(
listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket)
)

async def stream_connected(self, reader, writer, listener_name):
async def stream_connected(
self, reader: StreamReader, writer: StreamWriter, listener_name: str
) -> None:
await self.client_connected(
listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)
)

async def client_connected(
self, listener_name, reader: ReaderAdapter, writer: WriterAdapter
self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter
):
# Wait for connection available on listener
server = self._servers.get(listener_name, None)
Expand Down Expand Up @@ -652,12 +697,16 @@ async def client_connected(
self.logger.debug("%s Client disconnected" % client_session.client_id)
server.release_connection()

def _init_handler(self, session, reader, writer):
def _init_handler(
self, session: Session, reader: ReaderAdapter, writer: WriterAdapter
):
"""
Create a BrokerProtocolHandler and attach to a session
:return:
"""
handler = BrokerProtocolHandler(self.plugins_manager, self._loop)
handler = BrokerProtocolHandler(
plugins_manager=self.plugins_manager, loop=self._loop
)
handler.attach(session, reader, writer)
return handler

Expand Down Expand Up @@ -753,9 +802,9 @@ async def topic_filtering(self, session: Session, topic, action: Action):

def retain_message(
self,
source_session: Session,
source_session: Optional[Session],
topic_name: str,
data: bytearray,
data: Union[bytes, bytearray, None],
qos: Optional[int] = None,
) -> None:
if data is not None and data != b"":
Expand All @@ -771,7 +820,7 @@ def retain_message(
self.logger.debug("Clear retained messages for topic '%s'" % topic_name)
del self._retained_messages[topic_name]

async def add_subscription(self, subscription, session):
async def add_subscription(self, subscription, session: Session):
try:
a_filter = subscription[0]
if "#" in a_filter and not a_filter.endswith("#"):
Expand Down Expand Up @@ -843,15 +892,15 @@ def _del_all_subscriptions(self, session: Session) -> None:
:param session:
:return:
"""
filter_queue = deque()
filter_queue: deque = deque()
for topic in self._subscriptions:
if self._del_subscription(topic, session):
filter_queue.append(topic)
for topic in filter_queue:
if not self._subscriptions[topic]:
del self._subscriptions[topic]

def matches(self, topic, a_filter):
def matches(self, topic: str, a_filter: str) -> Union[bool, None, Match[str]]:
if "#" not in a_filter and "+" not in a_filter:
# if filter doesn't contain wildcard, return exact match
return a_filter == topic
Expand All @@ -865,6 +914,7 @@ def matches(self, topic, a_filter):
)
return match_pattern.fullmatch(topic)

# TODO define return type
async def _broadcast_loop(self):
running_tasks = deque()
try:
Expand Down Expand Up @@ -905,8 +955,8 @@ async def _broadcast_loop(self):
task = asyncio.ensure_future(
handler.mqtt_publish(
broadcast["topic"],
broadcast["data"],
qos,
broadcast["data"],
retain=False,
),
)
Expand Down Expand Up @@ -941,13 +991,19 @@ async def _broadcast_loop(self):
await asyncio.wait(running_tasks)
raise # reraise per CancelledError semantics

async def _broadcast_message(self, session, topic, data, force_qos=None):
broadcast = {"session": session, "topic": topic, "data": data}
async def _broadcast_message(
self,
session: Optional[Session],
topic: str,
data: bytes,
force_qos: Optional[int] = None,
) -> None:
broadcast: Dict[str, Any] = {"session": session, "topic": topic, "data": data}
if force_qos:
broadcast["qos"] = force_qos
await self._broadcast_queue.put(broadcast)

async def publish_session_retained_messages(self, session):
async def publish_session_retained_messages(self, session: Session) -> None:
self.logger.debug(
"Publishing %d messages retained for session %s"
% (
Expand All @@ -969,7 +1025,9 @@ async def publish_session_retained_messages(self, session):
if publish_tasks:
await asyncio.wait(publish_tasks)

async def publish_retained_messages_for_subscription(self, subscription, session):
async def publish_retained_messages_for_subscription(
self, subscription: Tuple[str, int], session: Session
) -> None:
self.logger.debug(
"Begin broadcasting messages retained due to subscription on '%s' from %s"
% (subscription[0], format_client_message(session=session))
Expand Down Expand Up @@ -1001,8 +1059,9 @@ def delete_session(self, client_id: str) -> None:
:param client_id:
:return:
"""
# TODO: Cleanup session delete logic
try:
session = self._sessions[client_id][0]
session: Optional[Session] = self._sessions[client_id][0]
except KeyError:
session = None
if session is None:
Expand All @@ -1018,7 +1077,7 @@ def delete_session(self, client_id: str) -> None:
)
del self._sessions[client_id]

def _get_handler(self, session):
def _get_handler(self, session: Session) -> Optional[ProtocolHandler]:
client_id = session.client_id
if client_id:
try:
Expand Down
Loading

0 comments on commit d8b1144

Please sign in to comment.