Skip to content

Commit

Permalink
Avoid a dict lookup and int conversion to process every packet (#946)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Sep 3, 2024
1 parent 88a256c commit 7d112a8
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 12 deletions.
4 changes: 3 additions & 1 deletion aioesphomeapi/_frame_helper/noise.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
@cython.locals(
msg=bytes,
type_high="unsigned char",
type_low="unsigned char"
type_low="unsigned char",
msg_type="unsigned int",
payload=bytes
)
cdef void _handle_frame(self, bytes frame)

Expand Down
4 changes: 3 additions & 1 deletion aioesphomeapi/_frame_helper/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,9 @@ def _handle_frame(self, frame: bytes) -> None:
# N bytes: message data
type_high = msg[0]
type_low = msg[1]
self._connection.process_packet((type_high << 8) | type_low, msg[4:])
msg_type = (type_high << 8) | type_low
payload = msg[4:]
self._connection.process_packet(msg_type, payload)

def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
"""Handle a closed frame."""
Expand Down
4 changes: 3 additions & 1 deletion aioesphomeapi/connection.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ cpdef void handle_complex_message(
cdef object _handle_timeout
cdef object _handle_complex_message

cdef tuple MESSAGE_NUMBER_TO_PROTO


@cython.dataclasses.dataclass
cdef class ConnectionParams:
Expand Down Expand Up @@ -119,7 +121,7 @@ cdef class APIConnection:
cdef void send_messages(self, tuple messages)

@cython.locals(handlers=set, handlers_copy=set)
cpdef void process_packet(self, object msg_type_proto, object data)
cpdef void process_packet(self, unsigned int msg_type_proto, object data)

cdef void _async_cancel_pong_timer(self)

Expand Down
26 changes: 17 additions & 9 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@

_LOGGER = logging.getLogger(__name__)

MESSAGE_NUMBER_TO_PROTO = tuple(MESSAGE_TYPE_TO_PROTO.values())


PREFERRED_BUFFER_SIZE = 2097152 # Set buffer limit to 2MB
MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use

Expand Down Expand Up @@ -888,22 +891,27 @@ def _set_fatal_exception_if_unset(self, err: Exception) -> None:
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
"""Process an incoming packet."""
debug_enabled = self._debug_enabled
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
if debug_enabled:
_LOGGER.debug(
"%s: Skipping unknown message type %s",
self.log_name,
msg_type_proto,
)
return

try:
# MESSAGE_NUMBER_TO_PROTO is 0-indexed
# but the message type is 1-indexed
klass = MESSAGE_NUMBER_TO_PROTO[msg_type_proto - 1]
msg: message.Message = klass()
# MergeFromString instead of ParseFromString since
# ParseFromString will clear the message first and
# the msg is already empty.
msg.MergeFromString(data)
except Exception as e:
# IndexError will be very rare so we check for it
# after the broad exception catch to avoid having
# to check the exception type twice for the common case
if isinstance(e, IndexError):
if debug_enabled:
_LOGGER.debug(
"%s: Skipping unknown message type %s",
self.log_name,
msg_type_proto,
)
return
_LOGGER.error(
"%s: Invalid protobuf message: type=%s data=%s: %s",
self.log_name,
Expand Down
2 changes: 2 additions & 0 deletions aioesphomeapi/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,5 @@ def __init__(self, error: BluetoothGATTError) -> None:
117: UpdateStateResponse,
118: UpdateCommandRequest,
}

MESSAGE_NUMBER_TO_PROTO = tuple(MESSAGE_TYPE_TO_PROTO.values())
9 changes: 9 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO


def test_order_and_no_missing_numbers_in_message_type_to_proto():
"""Test that MESSAGE_TYPE_TO_PROTO has no missing numbers."""
for idx, (k, v) in enumerate(MESSAGE_TYPE_TO_PROTO.items()):
assert idx + 1 == k

0 comments on commit 7d112a8

Please sign in to comment.