Skip to content

Commit

Permalink
Merge pull request #11 from intercreate/fix/serial-coexistence
Browse files Browse the repository at this point in the history
fix: allow non SMP traffic on serial transport; test
  • Loading branch information
JPHutchins authored Mar 18, 2024
2 parents 5bd2fd4 + 0350710 commit 31717db
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 14 deletions.
124 changes: 110 additions & 14 deletions smpclient/transport/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import logging
import math
from enum import IntEnum, unique
from functools import cached_property
from typing import Final

Expand Down Expand Up @@ -33,6 +34,31 @@ def _base64_max(size: int) -> int:
class SMPSerialTransport:
_POLLING_INTERVAL_S = 0.005

class _ReadBuffer:
"""The state of the read buffer."""

@unique
class State(IntEnum):
SMP = 0
"""An SMP start or continue delimiter has been received and the
`smp_buffer` is being filled with the remainder of the SMP packet.
"""

SER = 1
"""The SMP start delimiter has not been received and the
`ser_buffer` is being filled with data.
"""

def __init__(self) -> None:
self.smp = bytearray([])
"""The buffer for the SMP packet."""

self.ser = bytearray([])
"""The buffer for serial data that is not part of an SMP packet."""

self.state = SMPSerialTransport._ReadBuffer.State.SER
"""The state of the read buffer."""

def __init__(
self,
mtu: int = 4096,
Expand Down Expand Up @@ -62,7 +88,7 @@ def __init__(
inter_byte_timeout=inter_byte_timeout,
exclusive=exclusive,
)
self._buffer = bytearray([])
self._buffer = SMPSerialTransport._ReadBuffer()
logger.debug(f"Initialized {self.__class__.__name__}")

async def connect(self, address: str) -> None:
Expand Down Expand Up @@ -105,30 +131,100 @@ async def receive(self) -> bytes:
logger.debug(f"Finished receiving {len(e.value)} byte response")
return e.value

async def _readuntil(self, delimiter: bytes = b"\n") -> bytes:
async def _readuntil(self) -> bytes:
"""Read `bytes` until the `delimiter` then return the `bytes` including the `delimiter`."""

START_DELIMITER: Final = smppacket.SIXTY_NINE
CONTINUE_DELIMITER: Final = smppacket.FOUR_TWENTY
END_DELIMITER: Final = b"\n"

# fake async until I get around to replacing pyserial

i = 0
i_smp_start = 0
i_smp_end = 0
i_start: int | None = None
i_continue: int | None = None
while True:
# read the entire OS buffer
self._buffer.extend(self._conn.read_all() or [])

try: # search the buffer for the index of the delimiter
i = self._buffer.index(delimiter, i) + len(delimiter)
if self._buffer.state == SMPSerialTransport._ReadBuffer.State.SER:
# read the entire OS buffer
try:
self._buffer.ser.extend(self._conn.read_all() or [])
except StopIteration:
pass

try: # search the buffer for the index of the start delimiter
i_start = self._buffer.ser.index(START_DELIMITER)
except ValueError:
i_start = None

try: # search the buffer for the index of the continue delimiter
i_continue = self._buffer.ser.index(CONTINUE_DELIMITER)
except ValueError:
i_continue = None

if i_start is not None and i_continue is not None:
i_smp_start = min(i_start, i_continue)
elif i_start is not None:
i_smp_start = i_start
elif i_continue is not None:
i_smp_start = i_continue
else: # no delimiters found yet, clear non SMP data and wait
while True:
try: # search the buffer for newline characters
i = self._buffer.ser.index(b"\n")
try: # log as a string if possible
logger.warning(
f"{self._conn.port}: {self._buffer.ser[:i].decode()}"
)
except UnicodeDecodeError: # log as bytes if not
logger.warning(f"{self._conn.port}: {self._buffer.ser[:i].hex()}")
self._buffer.ser = self._buffer.ser[i + 1 :]
except ValueError:
break
await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S)
continue

if i_smp_start != 0: # log the rest of the serial buffer
try: # log as a string if possible
logger.warning(
f"{self._conn.port}: {self._buffer.ser[:i_smp_start].decode()}"
)
except UnicodeDecodeError: # log as bytes if not
logger.warning(f"{self._conn.port}: {self._buffer.ser[:i_smp_start].hex()}")

self._buffer.smp = self._buffer.ser[i_smp_start:]
self._buffer.ser.clear()
self._buffer.state = SMPSerialTransport._ReadBuffer.State.SMP
i_smp_end = 0

# don't await since the buffer may already contain the end delimiter

elif self._buffer.state == SMPSerialTransport._ReadBuffer.State.SMP:
# read the entire OS buffer
try:
self._buffer.smp.extend(self._conn.read_all() or [])
except StopIteration:
pass

try: # search the buffer for the index of the delimiter
i_smp_end = self._buffer.smp.index(END_DELIMITER, i_smp_end) + len(
END_DELIMITER
)
except ValueError: # delimiter not found yet, wait
await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S)
continue

# out is everything up to and including the delimiter
out = self._buffer[:i]
out = self._buffer.smp[:i_smp_end]
logger.debug(f"Received {len(out)} byte chunk")

# there may be some leftover to save for the next read
self._buffer = self._buffer[i:]
# there may be some leftover to save for the next read, but
# it's not necessarily SMP data
self._buffer.ser = self._buffer.smp[i_smp_end:]

return out
self._buffer.state = SMPSerialTransport._ReadBuffer.State.SER

except ValueError: # delimiter not found yet, wait
await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S)
return out

async def send_and_receive(self, data: bytes) -> bytes:
await self.send(data)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_smp_serial_transport.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for `SMPSerialTransport`."""

import logging
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch

import pytest
Expand Down Expand Up @@ -122,6 +123,45 @@ async def test_readuntil() -> None:
assert p == await t._readuntil()


@pytest.mark.asyncio
async def test_readuntil_with_smp_server_logging(caplog: pytest.LogCaptureFixture) -> None:
t = SMPSerialTransport()
m1 = EchoWrite.Response(sequence=0, r="Hello pytest!")
m2 = EchoWrite.Response(sequence=1, r="Hello computer!")
p1 = [p for p in smppacket.encode(m1.BYTES, 8)]
p2 = [p for p in smppacket.encode(m2.BYTES, 8)]
packets = p1 + p2

t._conn.read_all = MagicMock( # type: ignore
side_effect=(
[b"Hi, there!"]
+ [b"newline \n"]
+ [b"Another line\nAgain \n"]
+ [b"log with no newline"]
+ p1
+ [b"Thought \n I'd just say hi!\n"]
+ [bytes([0, 1, 2, 3])]
+ [b"Bye!\n"]
+ p2
+ [b"One more thing...\n"]
+ [b"We \n could \n use \n newlines\n"]
)
)

t._conn.port = "/dev/ttyUSB0"

with caplog.at_level(logging.WARNING):
for p in packets:
assert p == await t._readuntil()

messages = {r.message for r in caplog.records}
assert "/dev/ttyUSB0: Hi, there!newline " in messages
assert "/dev/ttyUSB0: Another line" in messages
assert "/dev/ttyUSB0: Again " in messages
assert "/dev/ttyUSB0: log with no newline" in messages
assert "/dev/ttyUSB0: Thought \n I'd just say hi!\n\x00\x01\x02\x03Bye!\n" in messages


@pytest.mark.asyncio
async def test_send_and_receive() -> None:
t = SMPSerialTransport()
Expand Down

0 comments on commit 31717db

Please sign in to comment.