From 62705b85643e3cad3e1984287fe113bde2438b25 Mon Sep 17 00:00:00 2001 From: Timon Engelke Date: Thu, 18 Apr 2024 19:19:13 +0200 Subject: [PATCH] Fix remote code execution due to pickle --- udp_bridge/message_handler.py | 8 ++++- udp_bridge/receiver.py | 5 ++- udp_bridge/sender.py | 65 ++++++++++++++++++----------------- 3 files changed, 45 insertions(+), 33 deletions(-) diff --git a/udp_bridge/message_handler.py b/udp_bridge/message_handler.py index c0d9c68..e58904d 100644 --- a/udp_bridge/message_handler.py +++ b/udp_bridge/message_handler.py @@ -1,9 +1,15 @@ +import io import pickle import zlib from udp_bridge.aes_helper import AESCipher +class RestrictedUnpickler(pickle.Unpickler): + def find_class(self, module, name): + raise pickle.UnpicklingError("pickle loading restricted to base types") + + class MessageHandler: def __init__(self, encryption_key: str | None): self.cipher = AESCipher(encryption_key) @@ -15,4 +21,4 @@ def encrypt_and_encode(self, data: dict) -> bytes: def decrypt_and_decode(self, msg: bytes): decrypted_msg = self.cipher.decrypt(msg) binary_msg = zlib.decompress(decrypted_msg) - return pickle.loads(binary_msg) + return RestrictedUnpickler(io.BytesIO(binary_msg)).load() diff --git a/udp_bridge/receiver.py b/udp_bridge/receiver.py index 004970f..759753a 100755 --- a/udp_bridge/receiver.py +++ b/udp_bridge/receiver.py @@ -6,6 +6,8 @@ import rclpy from rclpy.node import Node from rclpy.qos import DurabilityPolicy, QoSProfile +from rclpy.serialization import deserialize_message +from rosidl_runtime_py.utilities import get_message from udp_bridge.message_handler import MessageHandler @@ -47,7 +49,8 @@ def handle_message(self, msg: bytes): """ try: deserialized_msg = self.message_handler.decrypt_and_decode(msg) - data = deserialized_msg.get("data") + msg_type_name = deserialized_msg.get("msg_type_name") + data = deserialize_message(deserialized_msg.get("data"), get_message(msg_type_name)) topic: str = deserialized_msg.get("topic") hostname: str = deserialized_msg.get("hostname") latched: bool = deserialized_msg.get("latched") diff --git a/udp_bridge/sender.py b/udp_bridge/sender.py index 03702ba..3d702e6 100755 --- a/udp_bridge/sender.py +++ b/udp_bridge/sender.py @@ -9,9 +9,10 @@ from rclpy.logging import LoggingSeverity from rclpy.node import Node from rclpy.qos import DurabilityPolicy, QoSProfile +from rclpy.serialization import serialize_message from rclpy.subscription import Subscription from rclpy.timer import Timer -from ros2topic.api import get_msg_class, get_topic_names +from rosidl_runtime_py.utilities import get_message from udp_bridge.message_handler import MessageHandler @@ -36,6 +37,7 @@ def __init__(self, topic: str, queue_size: int, message_handler: MessageHandler, self.message_handler: MessageHandler = message_handler self.node: Node = node self.timer: Timer | None = None + self.msg_type_name: str = None self.__subscriber: Subscription | None = None self.__latched_subscriber: Subscription | None = None @@ -53,43 +55,44 @@ def __subscribe(self, backoff=1.0): self.timer.cancel() data_class = None - topics = get_topic_names(node=self.node) - topic = next(filter(lambda t: t == self.topic, topics), None) - - if topic is not None: - data_class = get_msg_class(self.node, topic) - - if data_class is not None: - # topic is known - self.node.get_logger().debug(f"Want to subscribe to topic {self.topic}") - # find out if topic is latched / transient local - publisher_infos = self.node.get_publishers_info_by_topic(topic) - latched = any(info.qos_profile.durability == DurabilityPolicy.TRANSIENT_LOCAL for info in publisher_infos) - self.__subscriber = self.node.create_subscription(data_class, self.topic, self.__message_callback, 1) - if latched: - self.__latched_subscriber = self.node.create_subscription( - data_class, - self.topic, - lambda msg: self.__message_callback(msg, latched=True), - QoSProfile(depth=1, durability=DurabilityPolicy.TRANSIENT_LOCAL), + for topic, msg_type_names in self.node.get_topic_names_and_types(): + if topic == self.topic: + self.msg_type_name = msg_type_names[0] + data_class = get_message(self.msg_type_name) + # topic is known + self.node.get_logger().debug(f"Want to subscribe to topic {self.topic}") + # find out if topic is latched / transient local + publisher_infos = self.node.get_publishers_info_by_topic(topic) + latched = any( + info.qos_profile.durability == DurabilityPolicy.TRANSIENT_LOCAL for info in publisher_infos ) - self.node.get_logger().debug(f"Subscribed to topic {self.topic}") + self.__subscriber = self.node.create_subscription(data_class, self.topic, self.__message_callback, 1) + if latched: + self.__latched_subscriber = self.node.create_subscription( + data_class, + self.topic, + lambda msg: self.__message_callback(msg, latched=True), + QoSProfile(depth=1, durability=DurabilityPolicy.TRANSIENT_LOCAL), + ) + self.node.get_logger().debug(f"Subscribed to topic {self.topic}") + return + + # topic is not yet known + if backoff > 10: + logging_severity = LoggingSeverity.WARN else: - # topic is not yet known - if backoff > 10: - logging_severity = LoggingSeverity.WARN - else: - logging_severity = LoggingSeverity.DEBUG - self.node.get_logger().log( - f"Topic {self.topic} is not yet known. Retrying in {backoff} seconds", logging_severity - ) - self.timer = self.node.create_timer(backoff, lambda: self.__subscribe(backoff * 1.2)) + logging_severity = LoggingSeverity.DEBUG + self.node.get_logger().log( + f"Topic {self.topic} is not yet known. Retrying in {backoff} seconds", logging_severity + ) + self.timer = self.node.create_timer(backoff, lambda: self.__subscribe(backoff * 1.2)) def __message_callback(self, data, latched=False): encrypted_msg = self.message_handler.encrypt_and_encode( { - "data": data, + "data": serialize_message(data), "topic": self.topic, + "msg_type_name": self.msg_type_name, "hostname": HOSTNAME, "latched": latched, }