From 5dedd1d8f59cfbc9107456821cf69215c462224a Mon Sep 17 00:00:00 2001 From: rob miller Date: Sun, 24 Apr 2022 17:30:20 +0800 Subject: [PATCH] added more type hints --- amqtt/adapters.py | 8 +-- amqtt/broker.py | 37 ++++++----- amqtt/client.py | 7 ++- amqtt/codecs.py | 11 ++-- amqtt/mqtt/connack.py | 30 +++++---- amqtt/mqtt/connect.py | 91 ++++++++++++++------------- amqtt/mqtt/constants.py | 6 +- amqtt/mqtt/packet.py | 48 +++++++------- amqtt/mqtt/protocol/broker_handler.py | 17 +++-- amqtt/mqtt/protocol/client_handler.py | 15 ++--- amqtt/mqtt/protocol/handler.py | 26 ++++---- amqtt/mqtt/puback.py | 2 +- amqtt/mqtt/pubcomp.py | 2 +- amqtt/mqtt/publish.py | 16 ++--- amqtt/mqtt/pubrec.py | 6 +- amqtt/mqtt/pubrel.py | 6 +- amqtt/mqtt/suback.py | 16 ++--- amqtt/mqtt/subscribe.py | 10 +-- amqtt/mqtt/unsuback.py | 3 +- amqtt/mqtt/unsubscribe.py | 8 ++- amqtt/session.py | 57 ++++++++++++----- amqtt/utils.py | 7 +-- amqtt/version.py | 4 +- 23 files changed, 250 insertions(+), 183 deletions(-) diff --git a/amqtt/adapters.py b/amqtt/adapters.py index b37a86c6..b16b6eda 100644 --- a/amqtt/adapters.py +++ b/amqtt/adapters.py @@ -38,7 +38,7 @@ class WriterAdapter: the protocol used """ - def write(self, data): + def write(self, data: bytes): """ write some data to the protocol layer """ @@ -103,7 +103,7 @@ def __init__(self, protocol: WebSocketCommonProtocol): self._protocol = protocol self._stream = io.BytesIO(b"") - def write(self, data): + def write(self, data: bytes): """ write some data to the protocol layer """ @@ -161,7 +161,7 @@ def __init__(self, writer: StreamWriter): self._writer = writer self.is_closed = False # StreamWriter has no test for closed...we use our own - def write(self, data): + def write(self, data: bytes): if not self.is_closed: self._writer.write(data) @@ -208,7 +208,7 @@ class BufferWriter(WriterAdapter): def __init__(self, buffer=b""): self._stream = io.BytesIO(buffer) - def write(self, data): + def write(self, data: bytes): """ write some data to the protocol layer """ diff --git a/amqtt/broker.py b/amqtt/broker.py index c91a9af4..b99ad7d8 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -1,12 +1,14 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. -from typing import Optional +from asyncio.events import AbstractEventLoop +from typing import Optional, Union, Type, Tuple, Dict import logging import ssl import websockets import asyncio import re +from re import Match from asyncio import CancelledError from collections import deque from enum import Enum @@ -57,7 +59,7 @@ class RetainedApplicationMessage: __slots__ = ("source_session", "topic", "data", "qos") - def __init__(self, source_session, topic, data, qos=None): + def __init__(self, source_session: Optional[Session], topic: str, data: bytes, qos: int = None): self.source_session = source_session self.topic = topic self.data = data @@ -65,7 +67,7 @@ def __init__(self, source_session, topic, data, qos=None): class Server: - def __init__(self, listener_name, server_instance, max_connections=-1): + def __init__(self, listener_name: str, server_instance, max_connections: int = -1): self.logger = logging.getLogger(__name__) self.instance = server_instance self.conn_count = 0 @@ -124,10 +126,10 @@ def __init__(self, broker: "Broker") -> None: self.config = None self._broker_instance = broker - async def broadcast_message(self, topic, data, qos=None): + async def broadcast_message(self, topic: str, data: bytes, qos: Optional[int] = None): await self._broker_instance.internal_message_broadcast(topic, data, qos) - def retain_message(self, topic_name, data, qos=None): + def retain_message(self, topic_name: str, data: bytes, qos: Optional[int] = None): self._broker_instance.retain_message(None, topic_name, data, qos) @property @@ -165,7 +167,11 @@ class Broker: "stopped", ] - def __init__(self, config=None, loop=None, plugin_namespace=None): + _sessions: Dict[str, Tuple[Session, BrokerProtocolHandler]] + _subscriptions: Dict[str, Tuple[Session, int]] + _retained_messages: Dict[str, RetainedApplicationMessage] + + def __init__(self, config=None, loop: AbstractEventLoop = None, plugin_namespace: str = None): self.logger = logging.getLogger(__name__) self.config = _defaults if config is not None: @@ -179,6 +185,7 @@ def __init__(self, config=None, loop=None, plugin_namespace=None): self._servers = dict() self._init_states() + self._sessions = dict() self._subscriptions = dict() self._retained_messages = dict() @@ -381,7 +388,7 @@ async def shutdown(self): await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN) self.transitions.stopping_success() - async def internal_message_broadcast(self, topic, data, qos=None): + async def internal_message_broadcast(self, topic: str, data: bytes, qos: Optional[int] = None): return await self._broadcast_message(None, topic, data) async def ws_connected(self, websocket, uri, listener_name): @@ -652,7 +659,7 @@ async def client_connected( self.logger.debug("%s Client disconnected" % client_session.client_id) server.release_connection() - def _init_handler(self, session, reader, writer): + def _init_handler(self, session: Session, reader: Type[ReaderAdapter], writer: Type[WriterAdapter]): """ Create a BrokerProtocolHandler and attach to a session :return: @@ -753,7 +760,7 @@ async def topic_filtering(self, session: Session, topic, action: Action): def retain_message( self, - source_session: Session, + source_session: Optional[Session], topic_name: str, data: bytearray, qos: Optional[int] = None, @@ -771,7 +778,7 @@ def retain_message( self.logger.debug("Clear retained messages for topic '%s'" % topic_name) del self._retained_messages[topic_name] - async def add_subscription(self, subscription, session): + async def add_subscription(self, subscription, session: Session): try: a_filter = subscription[0] if "#" in a_filter and not a_filter.endswith("#"): @@ -851,7 +858,7 @@ def _del_all_subscriptions(self, session: Session) -> None: if not self._subscriptions[topic]: del self._subscriptions[topic] - def matches(self, topic, a_filter): + def matches(self, topic: str, a_filter: str) -> Union[bool, None, Match[str]]: if "#" not in a_filter and "+" not in a_filter: # if filter doesn't contain wildcard, return exact match return a_filter == topic @@ -941,13 +948,13 @@ async def _broadcast_loop(self): await asyncio.wait(running_tasks) raise # reraise per CancelledError semantics - async def _broadcast_message(self, session, topic, data, force_qos=None): + async def _broadcast_message(self, session: Optional[Session], topic: str, data: bytes, force_qos: Optional[bool] = None): broadcast = {"session": session, "topic": topic, "data": data} if force_qos: broadcast["qos"] = force_qos await self._broadcast_queue.put(broadcast) - async def publish_session_retained_messages(self, session): + async def publish_session_retained_messages(self, session: Session): self.logger.debug( "Publishing %d messages retained for session %s" % ( @@ -969,7 +976,7 @@ async def publish_session_retained_messages(self, session): if publish_tasks: await asyncio.wait(publish_tasks) - async def publish_retained_messages_for_subscription(self, subscription, session): + async def publish_retained_messages_for_subscription(self, subscription, session: Session): self.logger.debug( "Begin broadcasting messages retained due to subscription on '%s' from %s" % (subscription[0], format_client_message(session=session)) @@ -1018,7 +1025,7 @@ def delete_session(self, client_id: str) -> None: ) del self._sessions[client_id] - def _get_handler(self, session): + def _get_handler(self, session: Session): client_id = session.client_id if client_id: try: diff --git a/amqtt/client.py b/amqtt/client.py index d14cb01a..5b4d46a7 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -8,6 +8,7 @@ import copy from urllib.parse import urlparse, urlunparse from functools import wraps +from typing import List, Tuple from amqtt.session import Session from amqtt.mqtt.connack import CONNECTION_ACCEPTED @@ -310,7 +311,7 @@ def get_retain_and_qos(): ) @mqtt_connected - async def subscribe(self, topics): + async def subscribe(self, topics: List[Tuple[str,int]]): """ Subscribe to some topics. @@ -332,7 +333,7 @@ async def subscribe(self, topics): return await self._handler.mqtt_subscribe(topics, self.session.next_packet_id) @mqtt_connected - async def unsubscribe(self, topics): + async def unsubscribe(self, topics: List[str]): """ Unsubscribe from some topics. @@ -349,7 +350,7 @@ async def unsubscribe(self, topics): """ await self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id) - async def deliver_message(self, timeout=None): + async def deliver_message(self, timeout: int = None): """ Deliver next received message. diff --git a/amqtt/codecs.py b/amqtt/codecs.py index 45cb7c52..65f2cef6 100644 --- a/amqtt/codecs.py +++ b/amqtt/codecs.py @@ -4,6 +4,7 @@ import asyncio from struct import pack, unpack from amqtt.errors import NoDataException +from amqtt.adapters import ReaderAdapter def bytes_to_hex_str(data): @@ -15,7 +16,7 @@ def bytes_to_hex_str(data): return "0x" + "".join(format(b, "02x") for b in data) -def bytes_to_int(data): +def bytes_to_int(data: bytes): """ convert a sequence of bytes to an integer using big endian byte ordering :param data: byte sequence @@ -41,7 +42,7 @@ def int_to_bytes(int_value: int, length: int) -> bytes: return pack(fmt, int_value) -async def read_or_raise(reader, n=-1): +async def read_or_raise(reader: ReaderAdapter, n=-1): """ Read a given byte number from Stream. NoDataException is raised if read gives no data :param reader: reader adapter @@ -57,7 +58,7 @@ async def read_or_raise(reader, n=-1): return data -async def decode_string(reader) -> str: +async def decode_string(reader: ReaderAdapter) -> str: """ Read a string from a reader and decode it according to MQTT string specification :param reader: Stream reader @@ -75,7 +76,7 @@ async def decode_string(reader) -> str: return "" -async def decode_data_with_length(reader) -> bytes: +async def decode_data_with_length(reader: ReaderAdapter) -> bytes: """ Read data from a reader. Data is prefixed with 2 bytes length :param reader: Stream reader @@ -98,7 +99,7 @@ def encode_data_with_length(data: bytes) -> bytes: return int_to_bytes(data_length, 2) + data -async def decode_packet_id(reader) -> int: +async def decode_packet_id(reader: ReaderAdapter) -> int: """ Read a packet ID as 2-bytes int from stream according to MQTT specification (2.3.1) :param reader: Stream reader diff --git a/amqtt/mqtt/connack.py b/amqtt/mqtt/connack.py index 7af0b53f..c114a3e6 100644 --- a/amqtt/mqtt/connack.py +++ b/amqtt/mqtt/connack.py @@ -1,24 +1,28 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +# Required for type hints in classes that self reference for python < v3.10 +from __future__ import annotations +from typing import Optional + from amqtt.mqtt.packet import CONNACK, MQTTPacket, MQTTFixedHeader, MQTTVariableHeader from amqtt.codecs import read_or_raise, bytes_to_int from amqtt.errors import AMQTTException from amqtt.adapters import ReaderAdapter -CONNECTION_ACCEPTED = 0x00 -UNACCEPTABLE_PROTOCOL_VERSION = 0x01 -IDENTIFIER_REJECTED = 0x02 -SERVER_UNAVAILABLE = 0x03 -BAD_USERNAME_PASSWORD = 0x04 -NOT_AUTHORIZED = 0x05 +CONNECTION_ACCEPTED: int = 0x00 +UNACCEPTABLE_PROTOCOL_VERSION: int = 0x01 +IDENTIFIER_REJECTED: int = 0x02 +SERVER_UNAVAILABLE: int = 0x03 +BAD_USERNAME_PASSWORD: int = 0x04 +NOT_AUTHORIZED: int = 0x05 class ConnackVariableHeader(MQTTVariableHeader): __slots__ = ("session_parent", "return_code") - def __init__(self, session_parent=None, return_code=None): + def __init__(self, session_parent: Optional[int] = None, return_code: Optional[int] = None): super().__init__() self.session_parent = session_parent self.return_code = return_code @@ -57,7 +61,7 @@ def return_code(self): return self.variable_header.return_code @return_code.setter - def return_code(self, return_code): + def return_code(self, return_code: int): self.variable_header.return_code = return_code @property @@ -65,14 +69,14 @@ def session_parent(self): return self.variable_header.session_parent @session_parent.setter - def session_parent(self, session_parent): + def session_parent(self, session_parent: int): self.variable_header.session_parent = session_parent def __init__( self, - fixed: MQTTFixedHeader = None, - variable_header: ConnackVariableHeader = None, - payload=None, + fixed: Optional[MQTTFixedHeader] = None, + variable_header: Optional[ConnackVariableHeader] = None, + payload = None, ): if fixed is None: header = MQTTFixedHeader(CONNACK, 0x00) @@ -88,7 +92,7 @@ def __init__( self.payload = None @classmethod - def build(cls, session_parent=None, return_code=None): + def build(cls, session_parent: int = None, return_code: int = None) -> ConnackPacket: v_header = ConnackVariableHeader(session_parent, return_code) packet = ConnackPacket(variable_header=v_header) return packet diff --git a/amqtt/mqtt/connect.py b/amqtt/mqtt/connect.py index 3f7debea..61773cd7 100644 --- a/amqtt/mqtt/connect.py +++ b/amqtt/mqtt/connect.py @@ -1,6 +1,9 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +# Required for type hints in classes that self reference for python < v3.10 +from __future__ import annotations +from typing import Optional from amqtt.codecs import ( bytes_to_int, @@ -27,16 +30,16 @@ class ConnectVariableHeader(MQTTVariableHeader): __slots__ = ("proto_name", "proto_level", "flags", "keep_alive") - USERNAME_FLAG = 0x80 - PASSWORD_FLAG = 0x40 - WILL_RETAIN_FLAG = 0x20 - WILL_FLAG = 0x04 - WILL_QOS_MASK = 0x18 - CLEAN_SESSION_FLAG = 0x02 - RESERVED_FLAG = 0x01 + USERNAME_FLAG: int = 0x80 + PASSWORD_FLAG: int = 0x40 + WILL_RETAIN_FLAG: int = 0x20 + WILL_FLAG: int = 0x04 + WILL_QOS_MASK: int = 0x18 + CLEAN_SESSION_FLAG: int = 0x02 + RESERVED_FLAG: int = 0x01 def __init__( - self, connect_flags=0x00, keep_alive=0, proto_name="MQTT", proto_level=0x04 + self, connect_flags: int = 0x00, keep_alive: int = 0, proto_name: str = "MQTT", proto_level: int = 0x04 ): super().__init__() self.proto_name = proto_name @@ -115,7 +118,7 @@ def will_qos(self, val: int): self.flags |= val << 3 @classmethod - async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader): + async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader) -> ConnectVariableHeader: # protocol name protocol_name = await decode_string(reader) @@ -133,7 +136,7 @@ async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader) return cls(flags, keep_alive, protocol_name, protocol_level) - def to_bytes(self): + def to_bytes(self) -> bytearray: out = bytearray() # Protocol name @@ -161,11 +164,11 @@ class ConnectPayload(MQTTPayload): def __init__( self, - client_id=None, - will_topic=None, - will_message=None, - username=None, - password=None, + client_id: Optional[str] = None, + will_topic: str = None, + will_message: Optional[bytes] = None, + username: Optional[str] = None, + password: Optional[str] = None, ): super().__init__() self.client_id_is_random = False @@ -190,7 +193,7 @@ async def from_stream( reader: ReaderAdapter, fixed_header: MQTTFixedHeader, variable_header: ConnectVariableHeader, - ): + ) -> ConnectPayload: payload = cls() # Client identifier try: @@ -230,7 +233,7 @@ async def from_stream( def to_bytes( self, fixed_header: MQTTFixedHeader, variable_header: ConnectVariableHeader - ): + ) -> bytearray: out = bytearray() # Client identifier out.extend(encode_string(self.client_id)) @@ -253,7 +256,7 @@ class ConnectPacket(MQTTPacket): PAYLOAD = ConnectPayload @property - def proto_name(self): + def proto_name(self) -> str: return self.variable_header.proto_name @proto_name.setter @@ -261,55 +264,55 @@ def proto_name(self, name: str): self.variable_header.proto_name = name @property - def proto_level(self): + def proto_level(self) -> int: return self.variable_header.proto_level @proto_level.setter - def proto_level(self, level): + def proto_level(self, level: int): self.variable_header.proto_level = level @property - def username_flag(self): + def username_flag(self) -> bool: return self.variable_header.username_flag @username_flag.setter - def username_flag(self, flag): + def username_flag(self, flag: bool): self.variable_header.username_flag = flag @property - def password_flag(self): + def password_flag(self) -> bool: return self.variable_header.password_flag @password_flag.setter - def password_flag(self, flag): + def password_flag(self, flag: bool): self.variable_header.password_flag = flag @property - def clean_session_flag(self): + def clean_session_flag(self) -> bool: return self.variable_header.clean_session_flag @clean_session_flag.setter - def clean_session_flag(self, flag): + def clean_session_flag(self, flag: bool): self.variable_header.clean_session_flag = flag @property - def will_retain_flag(self): + def will_retain_flag(self) -> bool: return self.variable_header.will_retain_flag @will_retain_flag.setter - def will_retain_flag(self, flag): + def will_retain_flag(self, flag: bool): self.variable_header.will_retain_flag = flag @property - def will_qos(self): + def will_qos(self) -> int: return self.variable_header.will_qos @will_qos.setter - def will_qos(self, flag): - self.variable_header.will_qos = flag + def will_qos(self, will_qos: int): + self.variable_header.will_qos = will_qos @property - def will_flag(self): + def will_flag(self) -> bool: return self.variable_header.will_flag @will_flag.setter @@ -325,11 +328,11 @@ def reserved_flag(self, flag): self.variable_header.reserved_flag = flag @property - def client_id(self): + def client_id(self) -> str: return self.payload.client_id @client_id.setter - def client_id(self, client_id): + def client_id(self, client_id: str): self.payload.client_id = client_id @property @@ -341,43 +344,43 @@ def client_id_is_random(self, client_id_is_random: bool): self.payload.client_id_is_random = client_id_is_random @property - def will_topic(self): + def will_topic(self) -> str: return self.payload.will_topic @will_topic.setter - def will_topic(self, will_topic): + def will_topic(self, will_topic: str): self.payload.will_topic = will_topic @property - def will_message(self): + def will_message(self) -> bytes: return self.payload.will_message @will_message.setter - def will_message(self, will_message): + def will_message(self, will_message: bytes): self.payload.will_message = will_message @property - def username(self): + def username(self) -> str: return self.payload.username @username.setter - def username(self, username): + def username(self, username: str): self.payload.username = username @property - def password(self): + def password(self) -> str: return self.payload.password @password.setter - def password(self, password): + def password(self, password: str): self.payload.password = password @property - def keep_alive(self): + def keep_alive(self) -> int: return self.variable_header.keep_alive @keep_alive.setter - def keep_alive(self, keep_alive): + def keep_alive(self, keep_alive: int): self.variable_header.keep_alive = keep_alive def __init__( diff --git a/amqtt/mqtt/constants.py b/amqtt/mqtt/constants.py index 841d2819..656aeecf 100644 --- a/amqtt/mqtt/constants.py +++ b/amqtt/mqtt/constants.py @@ -2,6 +2,6 @@ # # See the file license.txt for copying permission. -QOS_0 = 0x00 -QOS_1 = 0x01 -QOS_2 = 0x02 +QOS_0: int = 0x00 +QOS_1: int = 0x01 +QOS_2: int = 0x02 diff --git a/amqtt/mqtt/packet.py b/amqtt/mqtt/packet.py index 01f9b1be..3da18b5f 100644 --- a/amqtt/mqtt/packet.py +++ b/amqtt/mqtt/packet.py @@ -1,7 +1,9 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +from __future__ import annotations import asyncio +from typing import Optional from amqtt.codecs import ( bytes_to_hex_str, @@ -15,34 +17,34 @@ from struct import unpack -RESERVED_0 = 0x00 -CONNECT = 0x01 -CONNACK = 0x02 -PUBLISH = 0x03 -PUBACK = 0x04 -PUBREC = 0x05 -PUBREL = 0x06 -PUBCOMP = 0x07 -SUBSCRIBE = 0x08 -SUBACK = 0x09 -UNSUBSCRIBE = 0x0A -UNSUBACK = 0x0B -PINGREQ = 0x0C -PINGRESP = 0x0D -DISCONNECT = 0x0E -RESERVED_15 = 0x0F +RESERVED_0: int = 0x00 +CONNECT: int = 0x01 +CONNACK: int = 0x02 +PUBLISH: int = 0x03 +PUBACK: int = 0x04 +PUBREC: int = 0x05 +PUBREL: int = 0x06 +PUBCOMP: int = 0x07 +SUBSCRIBE: int = 0x08 +SUBACK: int = 0x09 +UNSUBSCRIBE: int = 0x0A +UNSUBACK: int = 0x0B +PINGREQ: int = 0x0C +PINGRESP: int = 0x0D +DISCONNECT: int = 0x0E +RESERVED_15: int = 0x0F class MQTTFixedHeader: __slots__ = ("packet_type", "remaining_length", "flags") - def __init__(self, packet_type, flags=0, length=0): + def __init__(self, packet_type, flags: int = 0, length: int = 0): self.packet_type = packet_type self.remaining_length = length self.flags = flags - def to_bytes(self): + def to_bytes(self) -> bytearray: def encode_remaining_length(length: int): encoded = bytearray() while True: @@ -74,11 +76,11 @@ async def to_stream(self, writer: WriterAdapter): writer.write(self.to_bytes()) @property - def bytes_length(self): + def bytes_length(self) -> int: return len(self.to_bytes()) @classmethod - async def from_stream(cls, reader: ReaderAdapter): + async def from_stream(cls, reader: ReaderAdapter) -> Optional[MQTTFixedHeader]: """ Read and decode MQTT message fixed header from stream :return: FixedHeader instance @@ -133,14 +135,14 @@ async def to_stream(self, writer: asyncio.StreamWriter): writer.write(self.to_bytes()) await writer.drain() - def to_bytes(self) -> bytes: + def to_bytes(self) -> bytearray: """ Serialize header data to a byte array conforming to MQTT protocol :return: serialized data """ @property - def bytes_length(self): + def bytes_length(self) -> int: return len(self.to_bytes()) @classmethod @@ -268,7 +270,7 @@ async def from_stream( return instance @property - def bytes_length(self): + def bytes_length(self) -> int: return len(self.to_bytes()) def __repr__(self): diff --git a/amqtt/mqtt/protocol/broker_handler.py b/amqtt/mqtt/protocol/broker_handler.py index f7f6c5ad..be03a68b 100644 --- a/amqtt/mqtt/protocol/broker_handler.py +++ b/amqtt/mqtt/protocol/broker_handler.py @@ -1,7 +1,12 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. -from asyncio import futures, Queue +# Required for type hints in classes that self reference for python < v3.10 +from __future__ import annotations +from asyncio import futures, Queue, AbstractEventLoop +from typing import Tuple + +from amqtt.mqtt.disconnect import DisconnectPacket from amqtt.mqtt.protocol.handler import ProtocolHandler from amqtt.mqtt.connack import ( CONNECTION_ACCEPTED, @@ -28,7 +33,7 @@ class BrokerProtocolHandler(ProtocolHandler): def __init__( - self, plugins_manager: PluginManager, session: Session = None, loop=None + self, plugins_manager: PluginManager, session: Session = None, loop: AbstractEventLoop = None ): super().__init__(plugins_manager, session, loop) self._disconnect_waiter = None @@ -45,7 +50,7 @@ async def stop(self): if self._disconnect_waiter is not None and not self._disconnect_waiter.done(): self._disconnect_waiter.set_result(None) - async def wait_disconnect(self): + async def wait_disconnect(self) -> futures.Future: return await self._disconnect_waiter def handle_write_timeout(self): @@ -55,7 +60,7 @@ def handle_read_timeout(self): if self._disconnect_waiter is not None and not self._disconnect_waiter.done(): self._disconnect_waiter.set_result(None) - async def handle_disconnect(self, disconnect): + async def handle_disconnect(self, disconnect: DisconnectPacket): self.logger.debug("Client disconnecting") if self._disconnect_waiter and not self._disconnect_waiter.done(): self.logger.debug("Setting waiter result to %r" % disconnect) @@ -116,8 +121,8 @@ async def mqtt_connack_authorize(self, authorize: bool): @classmethod async def init_from_connect( - cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None - ): + cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop: AbstractEventLoop = None + ) -> Tuple[BrokerProtocolHandler, Session]: """ :param reader: diff --git a/amqtt/mqtt/protocol/client_handler.py b/amqtt/mqtt/protocol/client_handler.py index 3fd83a27..9189b1f8 100644 --- a/amqtt/mqtt/protocol/client_handler.py +++ b/amqtt/mqtt/protocol/client_handler.py @@ -2,7 +2,8 @@ # # See the file license.txt for copying permission. import asyncio -from asyncio import futures +from asyncio import futures, AbstractEventLoop +from typing import List, Tuple from amqtt.mqtt.protocol.handler import ProtocolHandler, EVENT_MQTT_PACKET_RECEIVED from amqtt.mqtt.disconnect import DisconnectPacket from amqtt.mqtt.pingreq import PingReqPacket @@ -19,7 +20,7 @@ class ClientProtocolHandler(ProtocolHandler): def __init__( - self, plugins_manager: PluginManager, session: Session = None, loop=None + self, plugins_manager: PluginManager, session: Session = None, loop: AbstractEventLoop = None ): super().__init__(plugins_manager, session, loop=loop) self._ping_task = None @@ -42,7 +43,7 @@ async def stop(self): if not self._disconnect_waiter.done(): self._disconnect_waiter.cancel() - def _build_connect_packet(self): + def _build_connect_packet(self) -> ConnectPacket: vh = ConnectVariableHeader() payload = ConnectPayload() @@ -73,7 +74,7 @@ def _build_connect_packet(self): packet = ConnectPacket(vh=vh, payload=payload) return packet - async def mqtt_connect(self): + async def mqtt_connect(self) -> int: connect_packet = self._build_connect_packet() await self._send_packet(connect_packet) connack = await ConnackPacket.from_stream(self.reader) @@ -93,9 +94,9 @@ def handle_write_timeout(self): def handle_read_timeout(self): pass - async def mqtt_subscribe(self, topics, packet_id): + async def mqtt_subscribe(self, topics: List[Tuple[str,int]], packet_id: int) -> List[int]: """ - :param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...] + :param topics: array of topics [('$SYS/broker/uptime', QOS_1), ('$SYS/broker/load/#', QOS_2),] :return: """ @@ -122,7 +123,7 @@ async def handle_suback(self, suback: SubackPacket): % packet_id ) - async def mqtt_unsubscribe(self, topics, packet_id): + async def mqtt_unsubscribe(self, topics: List[str], packet_id: int): """ :param topics: array of topics ['/a/b', ...] diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index 8eb5a45c..ca91711e 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -1,12 +1,14 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +# Required for type hints in classes that self reference for python < v3.10 +from __future__ import annotations import logging import collections import itertools - import asyncio -from asyncio import InvalidStateError +from asyncio import InvalidStateError, AbstractEventLoop +from typing import Type, TypeVar from amqtt.mqtt import packet_class from amqtt.mqtt.connack import ConnackPacket @@ -29,6 +31,7 @@ DISCONNECT, RESERVED_15, MQTTFixedHeader, + MQTTPacket, ) from amqtt.mqtt.pingresp import PingRespPacket from amqtt.mqtt.pingreq import PingReqPacket @@ -44,6 +47,7 @@ from amqtt.mqtt.disconnect import DisconnectPacket from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.session import ( + ApplicationMessage, Session, OutgoingApplicationMessage, IncomingApplicationMessage, @@ -69,7 +73,7 @@ class ProtocolHandler: """ def __init__( - self, plugins_manager: PluginManager, session: Session = None, loop=None + self, plugins_manager: PluginManager, session: Session = None, loop: AbstractEventLoop = None ): self.logger = logging.getLogger(__name__) if session: @@ -105,7 +109,7 @@ def _init_session(self, session: Session): if self.keepalive_timeout <= 0: self.keepalive_timeout = None - def attach(self, session, reader: ReaderAdapter, writer: WriterAdapter): + def attach(self, session, reader: Type[ReaderAdapter], writer: Type[WriterAdapter]): if self.session: raise ProtocolHandlerException("Handler is already attached to a session") self._init_session(session) @@ -117,7 +121,7 @@ def detach(self): self.reader = None self.writer = None - def _is_attached(self): + def _is_attached(self) -> bool: if self.session: return True else: @@ -189,7 +193,7 @@ async def _retry_deliveries(self): ) self.logger.debug("End messages delivery retries") - async def mqtt_publish(self, topic, data, qos, retain, ack_timeout=None): + async def mqtt_publish(self, topic: str, data: bytes, qos: int, retain: bool, ack_timeout: int = None): """ Sends a MQTT publish message and manages messages flows. This methods doesn't return until the message has been acknowledged by receiver or timeout occur @@ -220,7 +224,7 @@ async def mqtt_publish(self, topic, data, qos, retain, ack_timeout=None): return message - async def _handle_message_flow(self, app_message): + async def _handle_message_flow(self, app_message: ApplicationMessage): """ Handle protocol flow for incoming and outgoing messages, depending on service level and according to MQTT spec. paragraph 4.3-Quality of Service levels and protocol flows @@ -236,7 +240,7 @@ async def _handle_message_flow(self, app_message): else: raise AMQTTException("Unexcepted QOS value '%d" % str(app_message.qos)) - async def _handle_qos0_message_flow(self, app_message): + async def _handle_qos0_message_flow(self, app_message: Type[ApplicationMessage]): """ Handle QOS_0 application message acknowledgment For incoming messages, this method stores the message @@ -264,7 +268,7 @@ async def _handle_qos0_message_flow(self, app_message): "delivered messages queue full. QOS_0 message discarded" ) - async def _handle_qos1_message_flow(self, app_message): + async def _handle_qos1_message_flow(self, app_message: Type[ApplicationMessage]): """ Handle QOS_1 application message acknowledgment For incoming messages, this method stores the message and reply with PUBACK @@ -308,7 +312,7 @@ async def _handle_qos1_message_flow(self, app_message): await self._send_packet(puback) app_message.puback_packet = puback - async def _handle_qos2_message_flow(self, app_message): + async def _handle_qos2_message_flow(self, app_message: Type[ApplicationMessage]): """ Handle QOS_2 application message acknowledgment For incoming messages, this method stores the message, sends PUBREC, waits for PUBREL, initiate delivery @@ -510,7 +514,7 @@ async def _reader_loop(self): self.logger.debug("Reader coro stopped") await self.stop() - async def _send_packet(self, packet): + async def _send_packet(self, packet: MQTTPacket): try: async with self._write_lock: await packet.to_stream(self.writer) diff --git a/amqtt/mqtt/puback.py b/amqtt/mqtt/puback.py index d2b99e88..7bb341ed 100644 --- a/amqtt/mqtt/puback.py +++ b/amqtt/mqtt/puback.py @@ -15,7 +15,7 @@ class PubackPacket(MQTTPacket): PAYLOAD = None @property - def packet_id(self): + def packet_id(self) -> int: return self.variable_header.packet_id @packet_id.setter diff --git a/amqtt/mqtt/pubcomp.py b/amqtt/mqtt/pubcomp.py index 2356c012..66c41377 100644 --- a/amqtt/mqtt/pubcomp.py +++ b/amqtt/mqtt/pubcomp.py @@ -15,7 +15,7 @@ class PubcompPacket(MQTTPacket): PAYLOAD = None @property - def packet_id(self): + def packet_id(self) -> int: return self.variable_header.packet_id @packet_id.setter diff --git a/amqtt/mqtt/publish.py b/amqtt/mqtt/publish.py index d0828117..251fae2a 100644 --- a/amqtt/mqtt/publish.py +++ b/amqtt/mqtt/publish.py @@ -1,6 +1,8 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +# Required for type hints in classes that self reference for python < v3.10 +from __future__ import annotations import asyncio from amqtt.mqtt.packet import ( @@ -32,7 +34,7 @@ def __repr__(self): self.topic_name, self.packet_id ) - def to_bytes(self): + def to_bytes(self) -> bytearray: out = bytearray() out.extend(encode_string(self.topic_name)) if self.packet_id is not None: @@ -42,7 +44,7 @@ def to_bytes(self): @classmethod async def from_stream( cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader - ): + ) -> PublishVariableHeader: topic_name = await decode_string(reader) has_qos = (fixed_header.flags >> 1) & 0x03 if has_qos: @@ -71,7 +73,7 @@ async def from_stream( reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader, - ): + ) -> PublishPayload: data = bytearray() data_length = fixed_header.remaining_length - variable_header.bytes_length length_read = 0 @@ -156,7 +158,7 @@ def qos(self, val: int): self.fixed_header.flags |= val << 1 @property - def packet_id(self): + def packet_id(self) -> int: return self.variable_header.packet_id @packet_id.setter @@ -164,7 +166,7 @@ def packet_id(self, val: int): self.variable_header.packet_id = val @property - def data(self): + def data(self) -> bytes: return self.payload.data @data.setter @@ -172,7 +174,7 @@ def data(self, data: bytes): self.payload.data = data @property - def topic_name(self): + def topic_name(self) -> str: return self.variable_header.topic_name @topic_name.setter @@ -182,7 +184,7 @@ def topic_name(self, name: str): @classmethod def build( cls, topic_name: str, message: bytes, packet_id: int, dup_flag, qos, retain - ): + ) -> PublishPacket: v_header = PublishVariableHeader(topic_name, packet_id) payload = PublishPayload(message) packet = PublishPacket(variable_header=v_header, payload=payload) diff --git a/amqtt/mqtt/pubrec.py b/amqtt/mqtt/pubrec.py index 8d7b0bcb..4f8ce349 100644 --- a/amqtt/mqtt/pubrec.py +++ b/amqtt/mqtt/pubrec.py @@ -1,6 +1,8 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +from __future__ import annotations + from amqtt.mqtt.packet import ( MQTTPacket, MQTTFixedHeader, @@ -15,7 +17,7 @@ class PubrecPacket(MQTTPacket): PAYLOAD = None @property - def packet_id(self): + def packet_id(self) -> int: return self.variable_header.packet_id @packet_id.setter @@ -41,7 +43,7 @@ def __init__( self.payload = None @classmethod - def build(cls, packet_id: int): + def build(cls, packet_id: int) -> PubrecPacket: v_header = PacketIdVariableHeader(packet_id) packet = PubrecPacket(variable_header=v_header) return packet diff --git a/amqtt/mqtt/pubrel.py b/amqtt/mqtt/pubrel.py index 34a8c115..fb1bcba4 100644 --- a/amqtt/mqtt/pubrel.py +++ b/amqtt/mqtt/pubrel.py @@ -1,6 +1,8 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +from __future__ import annotations + from amqtt.mqtt.packet import ( MQTTPacket, MQTTFixedHeader, @@ -15,7 +17,7 @@ class PubrelPacket(MQTTPacket): PAYLOAD = None @property - def packet_id(self): + def packet_id(self) -> int: return self.variable_header.packet_id @packet_id.setter @@ -41,6 +43,6 @@ def __init__( self.payload = None @classmethod - def build(cls, packet_id): + def build(cls, packet_id) -> PubrelPacket: variable_header = PacketIdVariableHeader(packet_id) return PubrelPacket(variable_header=variable_header) diff --git a/amqtt/mqtt/suback.py b/amqtt/mqtt/suback.py index 83a1a7f2..5e9820fa 100644 --- a/amqtt/mqtt/suback.py +++ b/amqtt/mqtt/suback.py @@ -1,6 +1,8 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +from __future__ import annotations + from amqtt.mqtt.packet import ( MQTTPacket, MQTTFixedHeader, @@ -18,10 +20,10 @@ class SubackPayload(MQTTPayload): __slots__ = ("return_codes",) - RETURN_CODE_00 = 0x00 - RETURN_CODE_01 = 0x01 - RETURN_CODE_02 = 0x02 - RETURN_CODE_80 = 0x80 + RETURN_CODE_00: int = 0x00 + RETURN_CODE_01: int = 0x01 + RETURN_CODE_02: int = 0x02 + RETURN_CODE_80: int = 0x80 def __init__(self, return_codes=None): super().__init__() @@ -32,7 +34,7 @@ def __repr__(self): def to_bytes( self, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader - ): + ) -> bytes: out = b"" for return_code in self.return_codes: out += int_to_bytes(return_code, 1) @@ -44,7 +46,7 @@ async def from_stream( reader: ReaderAdapter, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader, - ): + ) -> SubackPayload: return_codes = [] bytes_to_read = fixed_header.remaining_length - variable_header.bytes_length for i in range(0, bytes_to_read): @@ -82,7 +84,7 @@ def __init__( self.payload = payload @classmethod - def build(cls, packet_id, return_codes): + def build(cls, packet_id, return_codes) -> SubackPacket: variable_header = cls.VARIABLE_HEADER(packet_id) payload = cls.PAYLOAD(return_codes) return cls(variable_header=variable_header, payload=payload) diff --git a/amqtt/mqtt/subscribe.py b/amqtt/mqtt/subscribe.py index fdff8404..c5b802e4 100644 --- a/amqtt/mqtt/subscribe.py +++ b/amqtt/mqtt/subscribe.py @@ -1,6 +1,8 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +# Required for type hints in classes that self reference for python < v3.10 +from __future__ import annotations import asyncio from amqtt.mqtt.packet import ( @@ -25,13 +27,13 @@ class SubscribePayload(MQTTPayload): __slots__ = ("topics",) - def __init__(self, topics=None): + def __init__(self, topics: list = None): super().__init__() self.topics = topics or [] def to_bytes( self, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader - ): + ) -> bytes: out = b"" for topic in self.topics: out += encode_string(topic[0]) @@ -44,7 +46,7 @@ async def from_stream( reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader, - ): + ) -> SubscribePayload: topics = [] payload_length = fixed_header.remaining_length - variable_header.bytes_length read_bytes = 0 @@ -88,7 +90,7 @@ def __init__( self.payload = payload @classmethod - def build(cls, topics, packet_id): + def build(cls, topics: list, packet_id: int) -> SubscribePacket: v_header = PacketIdVariableHeader(packet_id) payload = SubscribePayload(topics) return SubscribePacket(variable_header=v_header, payload=payload) diff --git a/amqtt/mqtt/unsuback.py b/amqtt/mqtt/unsuback.py index bd6155ae..91c4aa52 100644 --- a/amqtt/mqtt/unsuback.py +++ b/amqtt/mqtt/unsuback.py @@ -1,6 +1,7 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +from __future__ import annotations from amqtt.mqtt.packet import ( MQTTPacket, MQTTFixedHeader, @@ -35,6 +36,6 @@ def __init__( self.payload = payload @classmethod - def build(cls, packet_id): + def build(cls, packet_id) -> UnsubackPacket: variable_header = PacketIdVariableHeader(packet_id) return cls(variable_header=variable_header) diff --git a/amqtt/mqtt/unsubscribe.py b/amqtt/mqtt/unsubscribe.py index 4ca1942b..7875e4b6 100644 --- a/amqtt/mqtt/unsubscribe.py +++ b/amqtt/mqtt/unsubscribe.py @@ -1,6 +1,8 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +# Required for type hints in classes that self reference for python < v3.10 +from __future__ import annotations import asyncio from amqtt.mqtt.packet import ( @@ -25,7 +27,7 @@ def __init__(self, topics=None): def to_bytes( self, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader - ): + ) -> bytes: out = b"" for topic in self.topics: out += encode_string(topic) @@ -37,7 +39,7 @@ async def from_stream( reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader, - ): + ) -> UnubscribePayload: topics = [] payload_length = fixed_header.remaining_length - variable_header.bytes_length read_bytes = 0 @@ -76,7 +78,7 @@ def __init__( self.payload = payload @classmethod - def build(cls, topics, packet_id): + def build(cls, topics, packet_id) -> UnsubscribePacket: v_header = PacketIdVariableHeader(packet_id) payload = UnubscribePayload(topics) return UnsubscribePacket(variable_header=v_header, payload=payload) diff --git a/amqtt/session.py b/amqtt/session.py index 7213e223..fab97cc0 100644 --- a/amqtt/session.py +++ b/amqtt/session.py @@ -1,10 +1,17 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. -from transitions import Machine from asyncio import Queue +from typing import Optional from collections import OrderedDict + +from transitions import Machine + from amqtt.mqtt.publish import PublishPacket +from amqtt.mqtt.puback import PubackPacket +from amqtt.mqtt.pubrec import PubrecPacket +from amqtt.mqtt.pubrel import PubrelPacket +from amqtt.mqtt.pubcomp import PubcompPacket from amqtt.errors import AMQTTException OUTGOING = 0 @@ -29,8 +36,8 @@ class ApplicationMessage: "pubrel_packet", "pubcomp_packet", ) - - def __init__(self, packet_id, topic, qos, data, retain): + + def __init__(self, packet_id: int, topic: str, qos: int, data: bytes, retain: bool): self.packet_id = packet_id """ Publish message `packet identifier `_""" @@ -46,22 +53,22 @@ def __init__(self, packet_id, topic, qos, data, retain): self.retain = retain """ Publish message retain flag""" - self.publish_packet = None + self.publish_packet: Optional[PublishPacket] = None """ :class:`amqtt.mqtt.publish.PublishPacket` instance corresponding to the `PUBLISH `_ packet in the messages flow. ``None`` if the PUBLISH packet has not already been received or sent.""" - self.puback_packet = None + self.puback_packet: Optional[PubackPacket] = None """ :class:`amqtt.mqtt.puback.PubackPacket` instance corresponding to the `PUBACK `_ packet in the messages flow. ``None`` if QoS != QOS_1 or if the PUBACK packet has not already been received or sent.""" - self.pubrec_packet = None + self.pubrec_packet: Optional[PubrecPacket] = None """ :class:`amqtt.mqtt.puback.PubrecPacket` instance corresponding to the `PUBREC `_ packet in the messages flow. ``None`` if QoS != QOS_2 or if the PUBREC packet has not already been received or sent.""" - self.pubrel_packet = None + self.pubrel_packet: Optional[PubrelPacket] = None """ :class:`amqtt.mqtt.puback.PubrelPacket` instance corresponding to the `PUBREL `_ packet in the messages flow. ``None`` if QoS != QOS_2 or if the PUBREL packet has not already been received or sent.""" - self.pubcomp_packet = None + self.pubcomp_packet: Optional[PubcompPacket] = None """ :class:`amqtt.mqtt.puback.PubrelPacket` instance corresponding to the `PUBCOMP `_ packet in the messages flow. ``None`` if QoS != QOS_2 or if the PUBCOMP packet has not already been received or sent.""" - def build_publish_packet(self, dup=False): + def build_publish_packet(self, dup: bool = False) -> PublishPacket: """ Build :class:`amqtt.mqtt.publish.PublishPacket` from attributes @@ -84,7 +91,7 @@ class IncomingApplicationMessage(ApplicationMessage): __slots__ = ("direction",) - def __init__(self, packet_id, topic, qos, data, retain): + def __init__(self, packet_id: int, topic: str, qos, data, retain): super().__init__(packet_id, topic, qos, data, retain) self.direction = INCOMING @@ -97,7 +104,7 @@ class OutgoingApplicationMessage(ApplicationMessage): __slots__ = ("direction",) - def __init__(self, packet_id, topic, qos, data, retain): + def __init__(self, packet_id: int, topic, qos, data, retain): super().__init__(packet_id, topic, qos, data, retain) self.direction = OUTGOING @@ -105,6 +112,26 @@ def __init__(self, packet_id, topic, qos, data, retain): class Session: states = ["new", "connected", "disconnected"] + remote_address: Optional[str] = None + remote_port: Optional[int] = None + client_idL: Optional[str] = None + clean_session = None + will_flag: bool = False + will_message = None + will_qos = None + will_retain = None + will_topic = None + keep_alive = 0 + publish_retry_delay: int = 0 + broker_uri = None + username = None + password = None + cafile = None + capath = None + cadata = None + _packet_id: int = 0 + parent: int = 0 + def __init__(self): self._init_states() self.remote_address = None @@ -158,7 +185,7 @@ def _init_states(self): ) @property - def next_packet_id(self): + def next_packet_id(self) -> int: self._packet_id += 1 if self._packet_id > 65535: self._packet_id = 1 @@ -174,15 +201,15 @@ def next_packet_id(self): return self._packet_id @property - def inflight_in_count(self): + def inflight_in_count(self) -> int: return len(self.inflight_in) @property - def inflight_out_count(self): + def inflight_out_count(self) -> int: return len(self.inflight_out) @property - def retained_messages_count(self): + def retained_messages_count(self) -> int: return self.retained_messages.qsize() def __repr__(self): diff --git a/amqtt/utils.py b/amqtt/utils.py index 08f4a715..7f9c5c04 100644 --- a/amqtt/utils.py +++ b/amqtt/utils.py @@ -3,14 +3,13 @@ # See the file license.txt for copying permission. from __future__ import annotations - +from typing import Any, TYPE_CHECKING import logging import random import yaml -import typing -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from amqtt.session import Session logger = logging.getLogger(__name__) @@ -36,7 +35,7 @@ def gen_client_id() -> str: return gen_id -def read_yaml_config(config_file: str) -> dict: +def read_yaml_config(config_file: str) -> Any: config = None try: with open(config_file) as stream: diff --git a/amqtt/version.py b/amqtt/version.py index 6d2241a5..3d23ba50 100644 --- a/amqtt/version.py +++ b/amqtt/version.py @@ -7,14 +7,14 @@ import amqtt -def get_version(): +def get_version() -> str: warnings.warn( "amqtt.version.get_version() is deprecated, use amqtt.__version__ instead" ) return amqtt.__version__ -def get_git_changeset(): +def get_git_changeset() -> str: """Returns a numeric identifier of the latest git changeset. The result is the UTC timestamp of the changeset in YYYYMMDDHHMMSS format. This value isn't guaranteed to be unique, but collisions are very unlikely,