Fix remote code execution due to pickle #37

Merged
merged 1 commit on Apr 19, 2024
8 changes: 7 additions & 1 deletion udp_bridge/message_handler.py
@@ -1,9 +1,15 @@
+import io
 import pickle
 import zlib
 
 from udp_bridge.aes_helper import AESCipher
 
 
+class RestrictedUnpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        raise pickle.UnpicklingError("pickle loading restricted to base types")
+
+
 class MessageHandler:
     def __init__(self, encryption_key: str | None):
         self.cipher = AESCipher(encryption_key)
@@ -15,4 +21,4 @@ def encrypt_and_encode(self, data: dict) -> bytes:
     def decrypt_and_decode(self, msg: bytes):
         decrypted_msg = self.cipher.decrypt(msg)
         binary_msg = zlib.decompress(decrypted_msg)
-        return pickle.loads(binary_msg)
+        return RestrictedUnpickler(io.BytesIO(binary_msg)).load()
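
The new RestrictedUnpickler refuses to resolve any global during unpickling, so a crafted payload can no longer import a callable such as os.system; payloads made only of built-in types (dicts, strings, bytes, bools) never reach find_class and still load. A minimal stdlib-only sketch of that behaviour, not part of this diff, using print as a harmless stand-in for a dangerous global:

import io
import pickle


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        raise pickle.UnpicklingError("pickle loading restricted to base types")


# A payload made only of built-in types never triggers find_class and loads fine.
safe = pickle.dumps({"topic": "/clock", "data": b"\x00" * 8, "latched": False})
print(RestrictedUnpickler(io.BytesIO(safe)).load())

# Any pickle that references a global (class, function, ...) is rejected instead of imported.
malicious = pickle.dumps(print)  # stand-in for a gadget such as os.system
try:
    RestrictedUnpickler(io.BytesIO(malicious)).load()
except pickle.UnpicklingError as err:
    print("blocked:", err)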
5 changes: 4 additions & 1 deletion udp_bridge/receiver.py
@@ -6,6 +6,8 @@
 import rclpy
 from rclpy.node import Node
 from rclpy.qos import DurabilityPolicy, QoSProfile
+from rclpy.serialization import deserialize_message
+from rosidl_runtime_py.utilities import get_message
 
 from udp_bridge.message_handler import MessageHandler
 
@@ -47,7 +49,8 @@ def handle_message(self, msg: bytes):
"""
try:
deserialized_msg = self.message_handler.decrypt_and_decode(msg)
data = deserialized_msg.get("data")
msg_type_name = deserialized_msg.get("msg_type_name")
data = deserialize_message(deserialized_msg.get("data"), get_message(msg_type_name))
topic: str = deserialized_msg.get("topic")
hostname: str = deserialized_msg.get("hostname")
latched: bool = deserialized_msg.get("latched")
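
With this change the receiver no longer unpickles ROS message objects; it looks up the message class by its transported name and deserializes the raw CDR bytes with rclpy. A small sketch of that round trip, assuming a sourced ROS 2 environment and std_msgs as an example package (neither is prescribed by the diff):

from rclpy.serialization import deserialize_message, serialize_message
from rosidl_runtime_py.utilities import get_message

msg_type_name = "std_msgs/msg/String"     # as carried in the payload's "msg_type_name" field
msg_class = get_message(msg_type_name)    # resolves the Python message class by name

original = msg_class(data="hello")
wire_bytes = serialize_message(original)  # what the sender now places under "data"
restored = deserialize_message(wire_bytes, msg_class)
assert restored.data == "hello"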
65 changes: 34 additions & 31 deletions udp_bridge/sender.py
@@ -9,9 +9,10 @@
 from rclpy.logging import LoggingSeverity
 from rclpy.node import Node
 from rclpy.qos import DurabilityPolicy, QoSProfile
+from rclpy.serialization import serialize_message
 from rclpy.subscription import Subscription
 from rclpy.timer import Timer
-from ros2topic.api import get_msg_class, get_topic_names
+from rosidl_runtime_py.utilities import get_message
 
 from udp_bridge.message_handler import MessageHandler
 
@@ -36,6 +37,7 @@ def __init__(self, topic: str, queue_size: int, message_handler: MessageHandler,
         self.message_handler: MessageHandler = message_handler
         self.node: Node = node
         self.timer: Timer | None = None
+        self.msg_type_name: str = None
 
         self.__subscriber: Subscription | None = None
         self.__latched_subscriber: Subscription | None = None
@@ -53,43 +55,44 @@ def __subscribe(self, backoff=1.0):
         self.timer.cancel()
 
         data_class = None
-        topics = get_topic_names(node=self.node)
-        topic = next(filter(lambda t: t == self.topic, topics), None)
-
-        if topic is not None:
-            data_class = get_msg_class(self.node, topic)
-
-        if data_class is not None:
-            # topic is known
-            self.node.get_logger().debug(f"Want to subscribe to topic {self.topic}")
-            # find out if topic is latched / transient local
-            publisher_infos = self.node.get_publishers_info_by_topic(topic)
-            latched = any(info.qos_profile.durability == DurabilityPolicy.TRANSIENT_LOCAL for info in publisher_infos)
-            self.__subscriber = self.node.create_subscription(data_class, self.topic, self.__message_callback, 1)
-            if latched:
-                self.__latched_subscriber = self.node.create_subscription(
-                    data_class,
-                    self.topic,
-                    lambda msg: self.__message_callback(msg, latched=True),
-                    QoSProfile(depth=1, durability=DurabilityPolicy.TRANSIENT_LOCAL),
+        for topic, msg_type_names in self.node.get_topic_names_and_types():
+            if topic == self.topic:
+                self.msg_type_name = msg_type_names[0]
+                data_class = get_message(self.msg_type_name)
+                # topic is known
+                self.node.get_logger().debug(f"Want to subscribe to topic {self.topic}")
+                # find out if topic is latched / transient local
+                publisher_infos = self.node.get_publishers_info_by_topic(topic)
+                latched = any(
+                    info.qos_profile.durability == DurabilityPolicy.TRANSIENT_LOCAL for info in publisher_infos
                 )
-            self.node.get_logger().debug(f"Subscribed to topic {self.topic}")
+                self.__subscriber = self.node.create_subscription(data_class, self.topic, self.__message_callback, 1)
+                if latched:
+                    self.__latched_subscriber = self.node.create_subscription(
+                        data_class,
+                        self.topic,
+                        lambda msg: self.__message_callback(msg, latched=True),
+                        QoSProfile(depth=1, durability=DurabilityPolicy.TRANSIENT_LOCAL),
+                    )
+                self.node.get_logger().debug(f"Subscribed to topic {self.topic}")
+                return
+
+        # topic is not yet known
+        if backoff > 10:
+            logging_severity = LoggingSeverity.WARN
         else:
-            # topic is not yet known
-            if backoff > 10:
-                logging_severity = LoggingSeverity.WARN
-            else:
-                logging_severity = LoggingSeverity.DEBUG
-            self.node.get_logger().log(
-                f"Topic {self.topic} is not yet known. Retrying in {backoff} seconds", logging_severity
-            )
-            self.timer = self.node.create_timer(backoff, lambda: self.__subscribe(backoff * 1.2))
+            logging_severity = LoggingSeverity.DEBUG
+        self.node.get_logger().log(
+            f"Topic {self.topic} is not yet known. Retrying in {backoff} seconds", logging_severity
+        )
+        self.timer = self.node.create_timer(backoff, lambda: self.__subscribe(backoff * 1.2))
 
     def __message_callback(self, data, latched=False):
         encrypted_msg = self.message_handler.encrypt_and_encode(
             {
-                "data": data,
+                "data": serialize_message(data),
                 "topic": self.topic,
+                "msg_type_name": self.msg_type_name,
                 "hostname": HOSTNAME,
                 "latched": latched,
             }
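
On the sender side the topic's message type is now resolved from the ROS graph with get_topic_names_and_types() instead of the ros2topic helpers, and its name travels with every payload so the receiver can rebuild the message. A sketch of that lookup, assuming a running ROS 2 graph, an example topic /chatter, and a hypothetical probe node name:

import rclpy
from rosidl_runtime_py.utilities import get_message

rclpy.init()
node = rclpy.create_node("udp_bridge_type_probe")  # hypothetical node name for illustration

wanted = "/chatter"  # example topic
for topic, msg_type_names in node.get_topic_names_and_types():
    if topic == wanted:
        msg_type_name = msg_type_names[0]        # e.g. "std_msgs/msg/String"
        data_class = get_message(msg_type_name)  # class passed to create_subscription()
        print(topic, msg_type_name, data_class)

node.destroy_node()
rclpy.shutdown()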