Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support latched topics in udp bridge #33

Merged
merged 3 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions udp_bridge/aes_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,18 @@ def __init__(self, key: str | None):
else:
self.encryption_key = None

def encrypt(self, message: str) -> bytes:
def encrypt(self, message: bytes) -> bytes:
if message == "":
raise ValueError("Cannot encrypt empty message")
if self.encryption_key is None:
return bytes(message, encoding="UTF-8")

return Fernet(key=self.encryption_key).encrypt(bytes(message, encoding="UTF-8"))
return Fernet(key=self.encryption_key).encrypt(message)

def decrypt(self, enc: bytes) -> str:
def decrypt(self, enc: bytes) -> bytes:
if len(enc) == 0:
raise ValueError("Cannot decrypt empty data")
if self.encryption_key is None:
return str(enc, encoding="UTF-8")

return str(Fernet(key=self.encryption_key).decrypt(enc), encoding="UTF-8")
return Fernet(key=self.encryption_key).decrypt(enc)
10 changes: 4 additions & 6 deletions udp_bridge/message_handler.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import base64
import pickle
import zlib

from udp_bridge.aes_helper import AESCipher


class MessageHandler:
PACKAGE_DELIMITER = b"\xff\xff\xff"

def __init__(self, encryption_key: str | None):
self.cipher = AESCipher(encryption_key)

def encrypt_and_encode(self, data: dict) -> bytes:
serialized_data = base64.b64encode(pickle.dumps(data, pickle.HIGHEST_PROTOCOL)).decode("ASCII")
serialized_data = zlib.compress(pickle.dumps(data, pickle.HIGHEST_PROTOCOL))
return self.cipher.encrypt(serialized_data)

def dencrypt_and_decode(self, msg: bytes):
def decrypt_and_decode(self, msg: bytes):
decrypted_msg = self.cipher.decrypt(msg)
binary_msg = base64.b64decode(decrypted_msg)
binary_msg = zlib.decompress(decrypted_msg)
return pickle.loads(binary_msg)
12 changes: 4 additions & 8 deletions udp_bridge/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,11 @@ def recv_message(self):
"""
Receive a message from the network, process it and publish it into ROS
"""
acc = b""
while rclpy.ok():
try:
acc += self.sock.recv(10240)

if acc[-3:] == MessageHandler.PACKAGE_DELIMITER:
self.handle_message(acc[:-3])
acc = b""

# 65535 is the upper limit for the size because of network properties
msg = self.sock.recv(65535)
self.handle_message(msg)
except socket.timeout:
pass

Expand All @@ -49,7 +45,7 @@ def handle_message(self, msg: bytes):
Handle a new message which came in from the socket
"""
try:
deserialized_msg = self.message_handler.dencrypt_and_decode(msg)
deserialized_msg = self.message_handler.decrypt_and_decode(msg)
data = deserialized_msg.get("data")
topic: str = deserialized_msg.get("topic")
hostname: str = deserialized_msg.get("hostname")
Expand Down
36 changes: 30 additions & 6 deletions udp_bridge/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import rclpy
from bitbots_utils.utils import get_parameters_from_other_node
from rclpy.executors import SingleThreadedExecutor
from rclpy.logging import LoggingSeverity
from rclpy.node import Node
from rclpy.qos import DurabilityPolicy, QoSProfile
from rclpy.subscription import Subscription
from rclpy.timer import Timer
from ros2topic.api import get_msg_class, get_topic_names
Expand Down Expand Up @@ -36,6 +38,7 @@ def __init__(self, topic: str, queue_size: int, message_handler: MessageHandler,
self.timer: Timer | None = None

self.__subscriber: Subscription | None = None
self.__latched_subscriber: Subscription | None = None
self.__subscribe()

def __subscribe(self, backoff=1.0):
Expand All @@ -55,19 +58,34 @@ def __subscribe(self, backoff=1.0):

if topic is not None:
data_class = get_msg_class(self.node, topic)
self.node.get_logger().info(str(data_class))

if data_class is not None:
# topic is known
self.node.get_logger().info(f"Want to subscribe to topic {self.topic}")
self.node.get_logger().debug(f"Want to subscribe to topic {self.topic}")
# find out if topic is latched / transient local
publisher_infos = self.node.get_publishers_info_by_topic(topic)
latched = any(info.qos_profile.durability == DurabilityPolicy.TRANSIENT_LOCAL for info in publisher_infos)
self.__subscriber = self.node.create_subscription(data_class, self.topic, self.__message_callback, 1)
self.node.get_logger().info(f"Subscribed to topic {self.topic}")
if latched:
self.__latched_subscriber = self.node.create_subscription(
data_class,
self.topic,
lambda msg: self.__message_callback(msg, latched=True),
QoSProfile(depth=1, durability=DurabilityPolicy.TRANSIENT_LOCAL),
)
self.node.get_logger().debug(f"Subscribed to topic {self.topic}")
else:
# topic is not yet known
self.node.get_logger().info(f"Topic {self.topic} is not yet known. Retrying in {backoff} seconds")
if backoff > 10:
logging_severity = LoggingSeverity.WARN
else:
logging_severity = LoggingSeverity.DEBUG
self.node.get_logger().log(
f"Topic {self.topic} is not yet known. Retrying in {backoff} seconds", logging_severity
)
self.timer = self.node.create_timer(backoff, lambda: self.__subscribe(backoff * 1.2))

def __message_callback(self, data):
def __message_callback(self, data, latched=False):
encrypted_msg = self.message_handler.encrypt_and_encode(
{
"data": data,
Expand All @@ -81,6 +99,12 @@ def __message_callback(self, data):
except Full:
self.node.get_logger().warn(f"Could not enqueue new message of topic {self.topic}. Queue full.")

# for latched messages, republish them every ten seconds because we cannot latch on the other side
if latched:
if self.timer:
self.timer.cancel()
self.timer = self.node.create_timer(10.0, lambda: self.__message_callback(data, latched=True))


# @TODO: replace by usage of https://github.com/PickNikRobotics/generate_parameter_library
def validate_params(node: Node) -> bool:
Expand Down Expand Up @@ -157,7 +181,7 @@ def send_messages_in_queue(self):
data = subscriber.queue.get_nowait()

try:
self.sock.sendto(data + MessageHandler.PACKAGE_DELIMITER, (self.target, self.port))
self.sock.sendto(data, (self.target, self.port))
except Exception as e:
self.node.get_logger().error(
f"Could not send data of topic {subscriber.topic} to {self.target} with error {str(e)}"
Expand Down
Loading