Skip to content

Commit

Permalink
added more type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
rob miller committed Apr 24, 2022
1 parent efddec9 commit 5dedd1d
Show file tree
Hide file tree
Showing 23 changed files with 250 additions and 183 deletions.
8 changes: 4 additions & 4 deletions amqtt/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class WriterAdapter:
the protocol used
"""

def write(self, data):
def write(self, data: bytes):
"""
write some data to the protocol layer
"""
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(self, protocol: WebSocketCommonProtocol):
self._protocol = protocol
self._stream = io.BytesIO(b"")

def write(self, data):
def write(self, data: bytes):
"""
write some data to the protocol layer
"""
Expand Down Expand Up @@ -161,7 +161,7 @@ def __init__(self, writer: StreamWriter):
self._writer = writer
self.is_closed = False # StreamWriter has no test for closed...we use our own

def write(self, data):
def write(self, data: bytes):
if not self.is_closed:
self._writer.write(data)

Expand Down Expand Up @@ -208,7 +208,7 @@ class BufferWriter(WriterAdapter):
def __init__(self, buffer=b""):
self._stream = io.BytesIO(buffer)

def write(self, data):
def write(self, data: bytes):
"""
write some data to the protocol layer
"""
Expand Down
37 changes: 22 additions & 15 deletions amqtt/broker.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
from typing import Optional
from asyncio.events import AbstractEventLoop
from typing import Optional, Union, Type, Tuple, Dict
import logging
import ssl
import websockets
import asyncio
import re
from re import Match
from asyncio import CancelledError
from collections import deque
from enum import Enum
Expand Down Expand Up @@ -57,15 +59,15 @@ class RetainedApplicationMessage:

__slots__ = ("source_session", "topic", "data", "qos")

def __init__(self, source_session, topic, data, qos=None):
def __init__(self, source_session: Optional[Session], topic: str, data: bytes, qos: int = None):
self.source_session = source_session
self.topic = topic
self.data = data
self.qos = qos


class Server:
def __init__(self, listener_name, server_instance, max_connections=-1):
def __init__(self, listener_name: str, server_instance, max_connections: int = -1):
self.logger = logging.getLogger(__name__)
self.instance = server_instance
self.conn_count = 0
Expand Down Expand Up @@ -124,10 +126,10 @@ def __init__(self, broker: "Broker") -> None:
self.config = None
self._broker_instance = broker

async def broadcast_message(self, topic, data, qos=None):
async def broadcast_message(self, topic: str, data: bytes, qos: Optional[int] = None):
await self._broker_instance.internal_message_broadcast(topic, data, qos)

def retain_message(self, topic_name, data, qos=None):
def retain_message(self, topic_name: str, data: bytes, qos: Optional[int] = None):
self._broker_instance.retain_message(None, topic_name, data, qos)

@property
Expand Down Expand Up @@ -165,7 +167,11 @@ class Broker:
"stopped",
]

def __init__(self, config=None, loop=None, plugin_namespace=None):
_sessions: Dict[str, Tuple[Session, BrokerProtocolHandler]]
_subscriptions: Dict[str, Tuple[Session, int]]
_retained_messages: Dict[str, RetainedApplicationMessage]

def __init__(self, config=None, loop: AbstractEventLoop = None, plugin_namespace: str = None):
self.logger = logging.getLogger(__name__)
self.config = _defaults
if config is not None:
Expand All @@ -179,6 +185,7 @@ def __init__(self, config=None, loop=None, plugin_namespace=None):

self._servers = dict()
self._init_states()

self._sessions = dict()
self._subscriptions = dict()
self._retained_messages = dict()
Expand Down Expand Up @@ -381,7 +388,7 @@ async def shutdown(self):
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
self.transitions.stopping_success()

async def internal_message_broadcast(self, topic, data, qos=None):
async def internal_message_broadcast(self, topic: str, data: bytes, qos: Optional[int] = None):
return await self._broadcast_message(None, topic, data)

async def ws_connected(self, websocket, uri, listener_name):
Expand Down Expand Up @@ -652,7 +659,7 @@ async def client_connected(
self.logger.debug("%s Client disconnected" % client_session.client_id)
server.release_connection()

def _init_handler(self, session, reader, writer):
def _init_handler(self, session: Session, reader: Type[ReaderAdapter], writer: Type[WriterAdapter]):
"""
Create a BrokerProtocolHandler and attach to a session
:return:
Expand Down Expand Up @@ -753,7 +760,7 @@ async def topic_filtering(self, session: Session, topic, action: Action):

def retain_message(
self,
source_session: Session,
source_session: Optional[Session],
topic_name: str,
data: bytearray,
qos: Optional[int] = None,
Expand All @@ -771,7 +778,7 @@ def retain_message(
self.logger.debug("Clear retained messages for topic '%s'" % topic_name)
del self._retained_messages[topic_name]

async def add_subscription(self, subscription, session):
async def add_subscription(self, subscription, session: Session):
try:
a_filter = subscription[0]
if "#" in a_filter and not a_filter.endswith("#"):
Expand Down Expand Up @@ -851,7 +858,7 @@ def _del_all_subscriptions(self, session: Session) -> None:
if not self._subscriptions[topic]:
del self._subscriptions[topic]

def matches(self, topic, a_filter):
def matches(self, topic: str, a_filter: str) -> Union[bool, None, Match[str]]:
if "#" not in a_filter and "+" not in a_filter:
# if filter doesn't contain wildcard, return exact match
return a_filter == topic
Expand Down Expand Up @@ -941,13 +948,13 @@ async def _broadcast_loop(self):
await asyncio.wait(running_tasks)
raise # reraise per CancelledError semantics

async def _broadcast_message(self, session, topic, data, force_qos=None):
async def _broadcast_message(self, session: Optional[Session], topic: str, data: bytes, force_qos: Optional[bool] = None):
broadcast = {"session": session, "topic": topic, "data": data}
if force_qos:
broadcast["qos"] = force_qos
await self._broadcast_queue.put(broadcast)

async def publish_session_retained_messages(self, session):
async def publish_session_retained_messages(self, session: Session):
self.logger.debug(
"Publishing %d messages retained for session %s"
% (
Expand All @@ -969,7 +976,7 @@ async def publish_session_retained_messages(self, session):
if publish_tasks:
await asyncio.wait(publish_tasks)

async def publish_retained_messages_for_subscription(self, subscription, session):
async def publish_retained_messages_for_subscription(self, subscription, session: Session):
self.logger.debug(
"Begin broadcasting messages retained due to subscription on '%s' from %s"
% (subscription[0], format_client_message(session=session))
Expand Down Expand Up @@ -1018,7 +1025,7 @@ def delete_session(self, client_id: str) -> None:
)
del self._sessions[client_id]

def _get_handler(self, session):
def _get_handler(self, session: Session):
client_id = session.client_id
if client_id:
try:
Expand Down
7 changes: 4 additions & 3 deletions amqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import copy
from urllib.parse import urlparse, urlunparse
from functools import wraps
from typing import List, Tuple

from amqtt.session import Session
from amqtt.mqtt.connack import CONNECTION_ACCEPTED
Expand Down Expand Up @@ -310,7 +311,7 @@ def get_retain_and_qos():
)

@mqtt_connected
async def subscribe(self, topics):
async def subscribe(self, topics: List[Tuple[str,int]]):
"""
Subscribe to some topics.
Expand All @@ -332,7 +333,7 @@ async def subscribe(self, topics):
return await self._handler.mqtt_subscribe(topics, self.session.next_packet_id)

@mqtt_connected
async def unsubscribe(self, topics):
async def unsubscribe(self, topics: List[str]):
"""
Unsubscribe from some topics.
Expand All @@ -349,7 +350,7 @@ async def unsubscribe(self, topics):
"""
await self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id)

async def deliver_message(self, timeout=None):
async def deliver_message(self, timeout: int = None):
"""
Deliver next received message.
Expand Down
11 changes: 6 additions & 5 deletions amqtt/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import asyncio
from struct import pack, unpack
from amqtt.errors import NoDataException
from amqtt.adapters import ReaderAdapter


def bytes_to_hex_str(data):
Expand All @@ -15,7 +16,7 @@ def bytes_to_hex_str(data):
return "0x" + "".join(format(b, "02x") for b in data)


def bytes_to_int(data):
def bytes_to_int(data: bytes):
"""
convert a sequence of bytes to an integer using big endian byte ordering
:param data: byte sequence
Expand All @@ -41,7 +42,7 @@ def int_to_bytes(int_value: int, length: int) -> bytes:
return pack(fmt, int_value)


async def read_or_raise(reader, n=-1):
async def read_or_raise(reader: ReaderAdapter, n=-1):
"""
Read a given byte number from Stream. NoDataException is raised if read gives no data
:param reader: reader adapter
Expand All @@ -57,7 +58,7 @@ async def read_or_raise(reader, n=-1):
return data


async def decode_string(reader) -> str:
async def decode_string(reader: ReaderAdapter) -> str:
"""
Read a string from a reader and decode it according to MQTT string specification
:param reader: Stream reader
Expand All @@ -75,7 +76,7 @@ async def decode_string(reader) -> str:
return ""


async def decode_data_with_length(reader) -> bytes:
async def decode_data_with_length(reader: ReaderAdapter) -> bytes:
"""
Read data from a reader. Data is prefixed with 2 bytes length
:param reader: Stream reader
Expand All @@ -98,7 +99,7 @@ def encode_data_with_length(data: bytes) -> bytes:
return int_to_bytes(data_length, 2) + data


async def decode_packet_id(reader) -> int:
async def decode_packet_id(reader: ReaderAdapter) -> int:
"""
Read a packet ID as 2-bytes int from stream according to MQTT specification (2.3.1)
:param reader: Stream reader
Expand Down
30 changes: 17 additions & 13 deletions amqtt/mqtt/connack.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
# Required for type hints in classes that self reference for python < v3.10
from __future__ import annotations
from typing import Optional

from amqtt.mqtt.packet import CONNACK, MQTTPacket, MQTTFixedHeader, MQTTVariableHeader
from amqtt.codecs import read_or_raise, bytes_to_int
from amqtt.errors import AMQTTException
from amqtt.adapters import ReaderAdapter

CONNECTION_ACCEPTED = 0x00
UNACCEPTABLE_PROTOCOL_VERSION = 0x01
IDENTIFIER_REJECTED = 0x02
SERVER_UNAVAILABLE = 0x03
BAD_USERNAME_PASSWORD = 0x04
NOT_AUTHORIZED = 0x05
CONNECTION_ACCEPTED: int = 0x00
UNACCEPTABLE_PROTOCOL_VERSION: int = 0x01
IDENTIFIER_REJECTED: int = 0x02
SERVER_UNAVAILABLE: int = 0x03
BAD_USERNAME_PASSWORD: int = 0x04
NOT_AUTHORIZED: int = 0x05


class ConnackVariableHeader(MQTTVariableHeader):

__slots__ = ("session_parent", "return_code")

def __init__(self, session_parent=None, return_code=None):
def __init__(self, session_parent: Optional[int] = None, return_code: Optional[int] = None):
super().__init__()
self.session_parent = session_parent
self.return_code = return_code
Expand Down Expand Up @@ -57,22 +61,22 @@ def return_code(self):
return self.variable_header.return_code

@return_code.setter
def return_code(self, return_code):
def return_code(self, return_code: int):
self.variable_header.return_code = return_code

@property
def session_parent(self):
return self.variable_header.session_parent

@session_parent.setter
def session_parent(self, session_parent):
def session_parent(self, session_parent: int):
self.variable_header.session_parent = session_parent

def __init__(
self,
fixed: MQTTFixedHeader = None,
variable_header: ConnackVariableHeader = None,
payload=None,
fixed: Optional[MQTTFixedHeader] = None,
variable_header: Optional[ConnackVariableHeader] = None,
payload = None,
):
if fixed is None:
header = MQTTFixedHeader(CONNACK, 0x00)
Expand All @@ -88,7 +92,7 @@ def __init__(
self.payload = None

@classmethod
def build(cls, session_parent=None, return_code=None):
def build(cls, session_parent: int = None, return_code: int = None) -> ConnackPacket:
v_header = ConnackVariableHeader(session_parent, return_code)
packet = ConnackPacket(variable_header=v_header)
return packet
Loading

0 comments on commit 5dedd1d

Please sign in to comment.