Skip to content

Commit

Permalink
Refactored stream API to make malformed packet detection more reliable
Browse files Browse the repository at this point in the history
ajmcquilkin committed Mar 15, 2024
1 parent c2b4622 commit 7cd485e
Showing 1 changed file with 115 additions and 33 deletions.
148 changes: 115 additions & 33 deletions src/connections/stream_buffer.rs
Original file line number Diff line number Diff line change
@@ -42,6 +42,8 @@ pub enum StreamBufferError {
DecodeFailure(#[from] prost::DecodeError),
}

const PACKET_HEADER_SIZE: usize = 4;

impl StreamBuffer {
/// Creates a new StreamBuffer instance that will send decoded FromRadio packets
/// to the given broadcast channel.
@@ -111,11 +113,11 @@ impl StreamBuffer {
trace!("Packet buffer: {:?}", self.buffer);

// Check that the buffer can potentially contain a packet header
if self.buffer.len() < 4 {
if self.buffer.len() < PACKET_HEADER_SIZE {
debug!("Buffer data is shorter than packet header size, failing");
return Err(StreamBufferError::IncompletePacket {
buffer_size: self.buffer.len(),
packet_size: 4,
packet_size: PACKET_HEADER_SIZE,
});
}

@@ -135,17 +137,53 @@ impl StreamBuffer {
);

self.buffer = self.buffer[framing_index..].to_vec();

log::trace!("Buffer after shifting: {:?}", self.buffer);

framing_index = self.get_framing_index()?;
}

// Note: the framing index should always be 0 at this point, keeping for clarity
let incoming_packet_data_size = self.extract_data_size_from_header(framing_index)?;

self.validate_packet_in_buffer(incoming_packet_data_size, framing_index)?;

// Get packet data, excluding magic bytes
let packet_data =
self.extract_packet_from_buffer(incoming_packet_data_size, framing_index)?;

// Attempt to decode the current packet
let decoded_packet = protobufs::FromRadio::decode(packet_data.as_slice())?;

Ok(decoded_packet)
}

// All valid packets start with the sequence [0x94 0xc3 size_msb size_lsb], where
// size_msb and size_lsb collectively give the size of the incoming packet
// Note that the maximum packet size currently stands at 240 bytes, meaning an MSB is not needed
fn get_framing_index(&mut self) -> Result<usize, StreamBufferError> {
match self.buffer.iter().position(|&b| b == 0x94) {
Some(idx) => Ok(idx),
None => {
warn!("Could not find index of 0x94, purging buffer");
self.buffer.clear(); // Clear buffer since no packets exist
Err(StreamBufferError::MissingHeaderByte)
}
}
}

fn extract_data_size_from_header(
&self,
framing_index: usize,
) -> Result<usize, StreamBufferError> {
// Get the "framing byte" after the start of the packet header, or fail if not found
let framing_byte = match self.buffer.get(framing_index + 1) {
Some(val) => val,
None => {
debug!("Could not find framing byte, waiting for more data");
return Err(StreamBufferError::IncompletePacket {
buffer_size: self.buffer.len(),
packet_size: 4,
packet_size: PACKET_HEADER_SIZE,
});
}
};
@@ -180,26 +218,47 @@ impl StreamBuffer {

// Combine MSB and LSB of incoming packet size bytes
// Recall that packet size doesn't include the first four magic bytes
let incoming_packet_size: usize = usize::from(4 + u16::from_le_bytes([*lsb, *msb]));
let incoming_packet_data_size: usize = usize::from(u16::from_le_bytes([*lsb, *msb]));

return Ok(incoming_packet_data_size);
}

// Defer decoding until the correct number of bytes are received
if self.buffer.len() < incoming_packet_size {
warn!("Stream buffer size is less than size of packet");
fn validate_packet_in_buffer(
&mut self,
packet_data_size: usize,
framing_index: usize,
) -> Result<(), StreamBufferError> {
if self.buffer.len() < PACKET_HEADER_SIZE + packet_data_size {
return Err(StreamBufferError::IncompletePacket {
buffer_size: self.buffer.len(),
packet_size: incoming_packet_size,
packet_size: packet_data_size,
});
}

// Get packet data, excluding magic bytes
let packet: Vec<u8> = self.buffer[4..incoming_packet_size].to_vec();

// Packet is malformed if the start of another packet occurs within the
// defined limits of the current packet
let malformed_packet_detector_index = packet.iter().position(|&b| b == 0x94);
let packet_data_start_index = framing_index + PACKET_HEADER_SIZE;

trace!(
"Validating bytes in range [{}, {})",
packet_data_start_index,
packet_data_start_index + packet_data_size
);

// Packet is malformed if the start of another packet occurs within the defined limits of the current packet
let malformed_packet_detector_index = self
.buffer
.iter()
.enumerate()
// Only want to check within the range of the current packet's data (not header)
.filter(|&(i, _)| {
packet_data_start_index <= i && i < packet_data_start_index + packet_data_size
})
.position(|(_, b)| *b == 0x94)
// `position` returns the index from the filtered array, need to re-normalize to the original buffer
.map(|idx| idx + packet_data_start_index);

let malformed_packet_detector_byte = if let Some(index) = malformed_packet_detector_index {
packet.get(index + 1)
trace!("Found 0x94 at index {}", index);
self.buffer.get(index + 1)
} else {
None
};
@@ -220,27 +279,45 @@ impl StreamBuffer {
});
}

// Remove current packet from buffer based on start location of next packet
self.buffer = self.buffer[incoming_packet_size..].to_vec();

// Attempt to decode the current packet
let decoded_packet = protobufs::FromRadio::decode(packet.as_slice())?;

Ok(decoded_packet)
Ok(())
}

// All valid packets start with the sequence [0x94 0xc3 size_msb size_lsb], where
// size_msb and size_lsb collectively give the size of the incoming packet
// Note that the maximum packet size currently stands at 240 bytes, meaning an MSB is not needed
fn get_framing_index(&mut self) -> Result<usize, StreamBufferError> {
match self.buffer.iter().position(|&b| b == 0x94) {
Some(idx) => Ok(idx),
None => {
warn!("Could not find index of 0x94, purging buffer");
self.buffer.clear(); // Clear buffer since no packets exist
Err(StreamBufferError::MissingHeaderByte)
}
fn extract_packet_from_buffer(
&mut self,
packet_data_size: usize,
framing_index: usize,
) -> Result<Vec<u8>, StreamBufferError> {
if self.buffer.len() < packet_data_size {
return Err(StreamBufferError::IncompletePacket {
buffer_size: self.buffer.len(),
packet_size: packet_data_size,
});
}

let packet_size = PACKET_HEADER_SIZE + packet_data_size;

// Extract packet with header before removing header
let mut packet_data_with_header: Vec<u8> =
self.buffer.drain(framing_index..packet_size).collect();

trace!(
"Extracted packet data with header of length {:?} from buffer: {:?}",
packet_data_with_header.len(),
packet_data_with_header
);

// Remove header bytes
let packet_data: Vec<u8> = packet_data_with_header
.drain(PACKET_HEADER_SIZE..)
.collect();

trace!(
"Extracted packet data of length {:?} from buffer: {:?}",
packet_data.len(),
packet_data
);

Ok(packet_data)
}
}

@@ -376,4 +453,9 @@ mod tests {

assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(valid_packet));
}

// #[tokio::test]
// async fn should_handle_incomplete_header_at_start_of_buffer() {
// // TODO
// }
}

0 comments on commit 7cd485e

Please sign in to comment.