diff --git a/neqo-transport/src/addr_valid.rs b/neqo-transport/src/addr_valid.rs index a8fcd76ab9..bff239e597 100644 --- a/neqo-transport/src/addr_valid.rs +++ b/neqo-transport/src/addr_valid.rs @@ -16,7 +16,7 @@ use crate::cid::ConnectionId; use crate::packet::PacketBuilder; use crate::recovery::RecoveryToken; use crate::stats::FrameStats; -use crate::Res; +use crate::{Error, Res}; use smallvec::SmallVec; use std::convert::TryFrom; @@ -355,10 +355,11 @@ impl NewTokenState { builder: &mut PacketBuilder, tokens: &mut Vec, stats: &mut FrameStats, - ) { + ) -> Res<()> { if let Self::Server(ref mut sender) = self { - sender.write_frames(builder, tokens, stats); + sender.write_frames(builder, tokens, stats)?; } + Ok(()) } /// If this a server, buffer a NEW_TOKEN for sending. @@ -429,18 +430,22 @@ impl NewTokenSender { builder: &mut PacketBuilder, tokens: &mut Vec, stats: &mut FrameStats, - ) { + ) -> Res<()> { for t in self.tokens.iter_mut() { if t.needs_sending && t.len() <= builder.remaining() { t.needs_sending = false; builder.encode_varint(crate::frame::FRAME_TYPE_NEW_TOKEN); builder.encode_vvec(&t.token); + if builder.len() > builder.limit() { + return Err(Error::InternalError(7)); + } tokens.push(RecoveryToken::NewToken(t.seqno)); stats.new_token += 1; } } + Ok(()) } pub fn lost(&mut self, seqno: usize) { diff --git a/neqo-transport/src/cid.rs b/neqo-transport/src/cid.rs index 478ff00392..fc858ce064 100644 --- a/neqo-transport/src/cid.rs +++ b/neqo-transport/src/cid.rs @@ -497,10 +497,10 @@ impl ConnectionIdManager { entry: &ConnectionIdEntry<[u8; 16]>, builder: &mut PacketBuilder, stats: &mut FrameStats, - ) -> bool { + ) -> Res { let len = 1 + Encoder::varint_len(entry.seqno) + 1 + 1 + entry.cid.len() + 16; if builder.remaining() < len { - return false; + return Ok(false); } builder.encode_varint(FRAME_TYPE_NEW_CONNECTION_ID); @@ -508,9 +508,12 @@ impl ConnectionIdManager { builder.encode_varint(0u64); builder.encode_vec(1, &entry.cid); builder.encode(&entry.srt); + if builder.len() > builder.limit() { + return Err(Error::InternalError(8)); + } stats.new_connection_id += 1; - true + Ok(true) } pub fn write_frames( @@ -518,14 +521,14 @@ impl ConnectionIdManager { builder: &mut PacketBuilder, tokens: &mut Vec, stats: &mut FrameStats, - ) { + ) -> Res<()> { if self.generator.deref().borrow().generates_empty_cids() { debug_assert_eq!(self.generator.borrow_mut().generate_cid().unwrap().len(), 0); - return; + return Ok(()); } while let Some(entry) = self.lost_new_connection_id.pop() { - if self.write_entry(&entry, builder, stats) { + if self.write_entry(&entry, builder, stats)? { tokens.push(RecoveryToken::NewConnectionId(entry)); } else { // This shouldn't happen often. @@ -550,10 +553,11 @@ impl ConnectionIdManager { .add_local(ConnectionIdEntry::new(seqno, cid.clone(), ())); let entry = ConnectionIdEntry::new(seqno, cid, srt); - debug_assert!(self.write_entry(&entry, builder, stats)); + debug_assert!(self.write_entry(&entry, builder, stats)?); tokens.push(RecoveryToken::NewConnectionId(entry)); } } + Ok(()) } pub fn lost(&mut self, entry: &ConnectionIdEntry<[u8; 16]>) { diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 9acea2254c..1fe571ba9d 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1633,7 +1633,7 @@ impl Connection { address_validation: &AddressValidationInfo, quic_version: QuicVersion, grease_quic_bit: bool, - ) -> (PacketType, PacketBuilder) { + ) -> Res<(PacketType, PacketBuilder)> { let pt = PacketType::from(cspace); let mut builder = if pt == PacketType::Short { qdebug!("Building Short dcid {}", path.remote_cid()); @@ -1656,17 +1656,17 @@ impl Connection { }; builder.scramble(grease_quic_bit); if pt == PacketType::Initial { - builder.initial_token(address_validation.token()); + builder.initial_token(address_validation.token())?; } - (pt, builder) + Ok((pt, builder)) } fn add_packet_number( builder: &mut PacketBuilder, tx: &CryptoDxState, largest_acknowledged: Option, - ) -> PacketNumber { + ) -> Res { // Get the packet number and work out how long it is. let pn = tx.next_pn(); let unacked_range = if let Some(la) = largest_acknowledged { @@ -1680,8 +1680,8 @@ impl Connection { - usize::try_from(unacked_range.leading_zeros() / 8).unwrap(); // pn_len can't be zero (unacked_range is > 0) // TODO(mt) also use `4*path CWND/path MTU` to set a minimum length. - builder.pn(pn, pn_len); - pn + builder.pn(pn, pn_len)?; + Ok(pn) } fn can_grease_quic_bit(&self) -> bool { @@ -1715,13 +1715,19 @@ impl Connection { &AddressValidationInfo::None, version, grease_quic_bit, - ); + )?; builder.set_limit(min(path.amplification_limit(), path.mtu()) - tx.expansion()); + if builder.limit() > 2048 { + return Err(Error::InternalError(9)); + } + if builder.len() > builder.limit() { + return Err(Error::InternalError(25)); + } let _ = Self::add_packet_number( &mut builder, tx, self.loss_recovery.largest_acknowledged_pn(*space), - ); + )?; // ConnectionError::Application is only allowed at 1RTT. let sanitized = if *space == PNSpace::ApplicationData { @@ -1733,6 +1739,9 @@ impl Connection { .as_ref() .unwrap_or(&close) .write_frame(&mut builder); + if builder.len() > builder.limit() { + return Err(Error::InternalError(10)); + } encoder = builder.build(tx)?; } @@ -1750,14 +1759,14 @@ impl Connection { builder: &mut PacketBuilder, mut pad: bool, now: Instant, - ) -> (Vec, bool, bool) { + ) -> Res<(Vec, bool, bool)> { let mut tokens = Vec::new(); let stats = &mut self.stats.borrow_mut().frame_tx; let primary = path.borrow().is_primary(); let mut ack_eliciting = false; let ack_token = if primary { - self.acks.write_frame(space, now, builder, stats) + self.acks.write_frame(space, now, builder, stats)? } else { None }; @@ -1770,7 +1779,7 @@ impl Connection { // The probing code needs to know so it can track that. if path .borrow_mut() - .write_frames(builder, stats, full_mtu, now) + .write_frames(builder, stats, full_mtu, now)? { pad = true; ack_eliciting = true; @@ -1782,18 +1791,18 @@ impl Connection { if let Some(t) = ack_token { tokens.push(t); } - return (tokens, false, false); + return Ok((tokens, false, false)); } if primary { if space == PNSpace::ApplicationData && self.role == Role::Server { - if let Some(t) = self.state_signaling.write_done(builder) { + if let Some(t) = self.state_signaling.write_done(builder)? { tokens.push(t); stats.handshake_done += 1; } } - if let Some(t) = self.crypto.streams.write_frame(space, builder) { + if let Some(t) = self.crypto.streams.write_frame(space, builder)? { tokens.push(t); stats.crypto += 1; } @@ -1801,12 +1810,13 @@ impl Connection { if space == PNSpace::ApplicationData { self.flow_mgr .borrow_mut() - .write_frames(builder, &mut tokens, stats); + .write_frames(builder, &mut tokens, stats)?; - self.send_streams.write_frames(builder, &mut tokens, stats); - self.new_token.write_frames(builder, &mut tokens, stats); - self.cid_manager.write_frames(builder, &mut tokens, stats); - self.paths.write_frames(builder, &mut tokens, stats); + self.send_streams + .write_frames(builder, &mut tokens, stats)?; + self.new_token.write_frames(builder, &mut tokens, stats)?; + self.cid_manager.write_frames(builder, &mut tokens, stats)?; + self.paths.write_frames(builder, &mut tokens, stats)?; } } @@ -1816,6 +1826,9 @@ impl Connection { // Nothing ack-eliciting and we need to probe; send PING. debug_assert_ne!(builder.remaining(), 0); builder.encode_varint(crate::frame::FRAME_TYPE_PING); + if builder.len() > builder.limit() { + return Err(Error::InternalError(11)); + } stats.ping += 1; stats.all += 1; ack_eliciting = true; @@ -1829,7 +1842,7 @@ impl Connection { // And avoid padding if we don't have a full MTU available. pad &= ack_eliciting && space == PNSpace::ApplicationData && full_mtu; if pad { - builder.pad(); + builder.pad()?; stats.padding += 1; stats.all += 1; } @@ -1838,7 +1851,7 @@ impl Connection { tokens.push(t); } stats.all += tokens.len(); - (tokens, ack_eliciting, pad) + Ok((tokens, ack_eliciting, pad)) } /// Build a datagram, possibly from multiple packets (for different PN @@ -1877,12 +1890,12 @@ impl Connection { &self.address_validation, version, grease_quic_bit, - ); + )?; let pn = Self::add_packet_number( &mut builder, tx, self.loss_recovery.largest_acknowledged_pn(*space), - ); + )?; let payload_start = builder.len(); // Work out if we have space left. @@ -1894,10 +1907,16 @@ impl Connection { } let limit = profile.limit() - aead_expansion; builder.set_limit(limit); + if builder.limit() > 2048 { + return Err(Error::InternalError(12)); + } + if builder.len() > builder.limit() { + return Err(Error::InternalError(13)); + } // Add frames to the packet. let (tokens, ack_eliciting, padded) = - self.write_frames(path, *space, &profile, &mut builder, needs_padding, now); + self.write_frames(path, *space, &profile, &mut builder, needs_padding, now)?; if builder.packet_empty() { // Nothing to include in this packet. diff --git a/neqo-transport/src/connection/state.rs b/neqo-transport/src/connection/state.rs index c1622079f6..8cf2bae3c4 100644 --- a/neqo-transport/src/connection/state.rs +++ b/neqo-transport/src/connection/state.rs @@ -17,7 +17,7 @@ use crate::frame::{ use crate::packet::PacketBuilder; use crate::path::PathRef; use crate::recovery::RecoveryToken; -use crate::{ConnectionError, Error}; +use crate::{ConnectionError, Error, Res}; #[derive(Clone, Debug, PartialEq, Eq)] /// The state of the Connection. @@ -185,13 +185,16 @@ impl StateSignaling { *self = Self::HandshakeDone } - pub fn write_done(&mut self, builder: &mut PacketBuilder) -> Option { + pub fn write_done(&mut self, builder: &mut PacketBuilder) -> Res> { if matches!(self, Self::HandshakeDone) && builder.remaining() >= 1 { *self = Self::Idle; builder.encode_varint(FRAME_TYPE_HANDSHAKE_DONE); - Some(RecoveryToken::HandshakeDone) + if builder.len() > builder.limit() { + return Err(Error::InternalError(14)); + } + Ok(Some(RecoveryToken::HandshakeDone)) } else { - None + Ok(None) } } diff --git a/neqo-transport/src/connection/tests/idle.rs b/neqo-transport/src/connection/tests/idle.rs index ec62af93f4..cb88bcc57c 100644 --- a/neqo-transport/src/connection/tests/idle.rs +++ b/neqo-transport/src/connection/tests/idle.rs @@ -253,12 +253,14 @@ fn idle_caching() { let crypto = server .crypto .streams - .write_frame(PNSpace::Initial, &mut builder); + .write_frame(PNSpace::Initial, &mut builder) + .unwrap(); assert!(crypto.is_some()); let crypto = server .crypto .streams - .write_frame(PNSpace::Initial, &mut builder); + .write_frame(PNSpace::Initial, &mut builder) + .unwrap(); assert!(crypto.is_none()); let dgram = server.process_output(middle).dgram(); diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index 6f3823d6db..148c555661 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -1241,14 +1241,14 @@ impl CryptoStreams { &mut self, space: PNSpace, builder: &mut PacketBuilder, - ) -> Option { + ) -> Res> { let cs = self.get_mut(space).unwrap(); if let Some((offset, data)) = cs.tx.next_bytes() { let mut header_len = 1 + Encoder::varint_len(offset) + 1; // Don't bother if there isn't room for the header and some data. if builder.remaining() < header_len + 1 { - return None; + return Ok(None); } // Calculate length of data based on the minimum of: // - available data @@ -1261,16 +1261,20 @@ impl CryptoStreams { builder.encode_varint(crate::frame::FRAME_TYPE_CRYPTO); builder.encode_varint(offset); builder.encode_vvec(&data[..length]); + if builder.len() > builder.limit() { + return Err(Error::InternalError(15)); + } + cs.tx.mark_as_sent(offset, length); qdebug!("CRYPTO for {} offset={}, len={}", space, offset, length); - Some(RecoveryToken::Crypto(CryptoRecoveryToken { + Ok(Some(RecoveryToken::Crypto(CryptoRecoveryToken { space, offset, length, - })) + }))) } else { - None + Ok(None) } } } diff --git a/neqo-transport/src/flow_mgr.rs b/neqo-transport/src/flow_mgr.rs index fe3b14a061..1888950751 100644 --- a/neqo-transport/src/flow_mgr.rs +++ b/neqo-transport/src/flow_mgr.rs @@ -20,7 +20,7 @@ use crate::recv_stream::RecvStreams; use crate::send_stream::SendStreams; use crate::stats::FrameStats; use crate::stream_id::{StreamId, StreamIndex, StreamIndexes, StreamType}; -use crate::AppError; +use crate::{AppError, Error, Res}; type FlowFrame = Frame<'static>; pub type FlowControlRecoveryToken = FlowFrame; @@ -277,7 +277,7 @@ impl FlowMgr { builder: &mut PacketBuilder, tokens: &mut Vec, stats: &mut FrameStats, - ) { + ) -> Res<()> { while let Some(frame) = self.peek() { // All these frames are bags of varints, so we can just extract the // varints and use common code for writing. @@ -348,11 +348,15 @@ impl FlowMgr { for v in values { builder.encode_varint(v); } + if builder.len() > builder.limit() { + return Err(Error::InternalError(16)); + } tokens.push(RecoveryToken::Flow(self.next().unwrap())); } else { - return; + return Ok(()); } } + Ok(()) } } diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 8e3a5287b0..6ad1e3b264 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -231,6 +231,10 @@ impl PacketBuilder { self.limit = limit; } + pub fn limit(&mut self) -> usize { + self.limit + } + /// How many bytes remain against the size limit for the builder. #[must_use] pub fn remaining(&self) -> usize { @@ -238,8 +242,14 @@ impl PacketBuilder { } /// Pad with "PADDING" frames. - pub fn pad(&mut self) { + pub fn pad(&mut self) -> Res<()> { self.encoder.pad_to(self.limit, 0); + if self.len() > self.limit { + qwarn!("Packet contents are more than the limit"); + debug_assert!(false); + return Err(Error::InternalError(17)); + } + Ok(()) } /// Add unpredictable values for unprotected parts of the packet. @@ -252,18 +262,25 @@ impl PacketBuilder { /// For an Initial packet, encode the token. /// If you fail to do this, then you will not get a valid packet. - pub fn initial_token(&mut self, token: &[u8]) { + pub fn initial_token(&mut self, token: &[u8]) -> Res<()> { debug_assert_eq!( self.encoder[self.header.start] & 0xb0, PACKET_BIT_LONG | PACKET_TYPE_INITIAL << 4 ); self.encoder.encode_vvec(token); + + if self.len() > self.limit { + qwarn!("Packet contents are more than the limit"); + debug_assert!(false); + return Err(Error::InternalError(18)); + } + Ok(()) } /// Add a packet number of the given size. /// For a long header packet, this also inserts a dummy length. /// The length is filled in after calling `build`. - pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) { + pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) -> Res<()> { // Reserve space for a length in long headers. if self.is_long() { self.offsets.len = self.encoder.len(); @@ -282,6 +299,13 @@ impl PacketBuilder { self.encoder[self.header.start] |= u8::try_from(pn_len - 1).unwrap(); self.header.end = self.encoder.len(); self.pn = pn; + + if self.len() > self.limit { + qwarn!("Packet contents are more than the limit"); + debug_assert!(false); + return Err(Error::InternalError(19)); + } + Ok(()) } fn write_len(&mut self, expansion: usize) { @@ -854,8 +878,8 @@ mod tests { &ConnectionId::from(&[][..]), &ConnectionId::from(SERVER_CID), ); - builder.initial_token(&[]); - builder.pn(1, 2); + builder.initial_token(&[]).unwrap(); + builder.pn(1, 2).unwrap(); builder.encode(&SAMPLE_INITIAL_PAYLOAD); let packet = builder.build(&mut prot).expect("build"); assert_eq!(&packet[..], SAMPLE_INITIAL); @@ -916,7 +940,7 @@ mod tests { fixture_init(); let mut builder = PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID)); - builder.pn(0, 1); + builder.pn(0, 1).unwrap(); builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling. let packet = builder .build(&mut CryptoDxState::test_default()) @@ -932,7 +956,7 @@ mod tests { let mut builder = PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID)); builder.scramble(true); - builder.pn(0, 1); + builder.pn(0, 1).unwrap(); firsts.push(builder[0]); } let is_set = |bit| move |v| v & bit == bit; @@ -995,14 +1019,14 @@ mod tests { &ConnectionId::from(SERVER_CID), &ConnectionId::from(CLIENT_CID), ); - builder.pn(0, 1); + builder.pn(0, 1).unwrap(); builder.encode(&[0; 3]); let encoder = builder.build(&mut prot).expect("build"); assert_eq!(encoder.len(), 45); let first = encoder.clone(); let mut builder = PacketBuilder::short(encoder, false, &ConnectionId::from(SERVER_CID)); - builder.pn(1, 3); + builder.pn(1, 3).unwrap(); builder.encode(&[0]); // Minimal size (packet number is big enough). let encoder = builder.build(&mut prot).expect("build"); assert_eq!( @@ -1029,7 +1053,7 @@ mod tests { &ConnectionId::from(&[][..]), &ConnectionId::from(&[][..]), ); - builder.pn(0, 1); + builder.pn(0, 1).unwrap(); builder.encode(&[1, 2, 3]); let packet = builder.build(&mut CryptoDxState::test_default()).unwrap(); assert_eq!(&packet[..], EXPECTED); @@ -1048,7 +1072,7 @@ mod tests { &ConnectionId::from(&[][..]), &ConnectionId::from(&[][..]), ); - builder.pn(0, 1); + builder.pn(0, 1).unwrap(); builder.scramble(true); if (builder[0] & PACKET_BIT_FIXED_QUIC) == 0 { found_unset = true; @@ -1069,8 +1093,8 @@ mod tests { &ConnectionId::from(&[][..]), &ConnectionId::from(SERVER_CID), ); - builder.initial_token(&[]); - builder.pn(1, 2); + builder.initial_token(&[]).unwrap(); + builder.pn(1, 2).unwrap(); let encoder = builder.abort(); assert!(encoder.is_empty()); } diff --git a/neqo-transport/src/path.rs b/neqo-transport/src/path.rs index 9bf0fc0c76..3a40da0844 100644 --- a/neqo-transport/src/path.rs +++ b/neqo-transport/src/path.rs @@ -21,6 +21,7 @@ use crate::frame::{ use crate::packet::PacketBuilder; use crate::recovery::RecoveryToken; use crate::stats::FrameStats; +use crate::{Error, Res}; use neqo_common::{hex, qdebug, qinfo, qtrace, Datagram, Encoder}; use neqo_crypto::random; @@ -312,7 +313,7 @@ impl Paths { builder: &mut PacketBuilder, tokens: &mut Vec, stats: &mut FrameStats, - ) { + ) -> Res<()> { while let Some(seqno) = self.to_retire.pop() { if builder.remaining() < 1 + Encoder::varint_len(seqno) { self.to_retire.push(seqno); @@ -320,9 +321,13 @@ impl Paths { } builder.encode_varint(FRAME_TYPE_RETIRE_CONNECTION_ID); builder.encode_varint(seqno); + if builder.len() > builder.limit() { + return Err(Error::InternalError(20)); + } tokens.push(RecoveryToken::RetireConnectionId(seqno)); stats.retire_connection_id += 1; } + Ok(()) } pub fn lost_retire_cid(&mut self, lost: u64) { @@ -599,9 +604,9 @@ impl Path { stats: &mut FrameStats, mtu: bool, // Whether the packet we're writing into will be a full MTU. now: Instant, - ) -> bool { + ) -> Res { if builder.remaining() < 9 { - return false; + return Ok(false); } // Send PATH_RESPONSE. @@ -609,6 +614,9 @@ impl Path { qtrace!([self], "Responding to path challenge {}", hex(&challenge)); builder.encode_varint(FRAME_TYPE_PATH_RESPONSE); builder.encode(&challenge[..]); + if builder.len() > builder.limit() { + return Err(Error::InternalError(21)); + } // These frames are not retransmitted in the usual fashion. // There is no token, therefore we need to count `all` specially. @@ -616,7 +624,7 @@ impl Path { stats.all += 1; if builder.remaining() < 9 { - return true; + return Ok(true); } true } else { @@ -629,6 +637,9 @@ impl Path { let data = <[u8; 8]>::try_from(&random(8)[..]).unwrap(); builder.encode_varint(FRAME_TYPE_PATH_CHALLENGE); builder.encode(&data); + if builder.len() > builder.limit() { + return Err(Error::InternalError(22)); + } // As above, no recovery token. stats.path_challenge += 1; @@ -640,9 +651,9 @@ impl Path { mtu, sent: now, }; - true + Ok(true) } else { - resp_sent + Ok(resp_sent) } } diff --git a/neqo-transport/src/send_stream.rs b/neqo-transport/src/send_stream.rs index b6b9eea5f5..c41f8a4a98 100644 --- a/neqo-transport/src/send_stream.rs +++ b/neqo-transport/src/send_stream.rs @@ -536,7 +536,7 @@ impl SendStream { (length, false) } - pub fn write_frame(&mut self, builder: &mut PacketBuilder) -> Option { + pub fn write_frame(&mut self, builder: &mut PacketBuilder) -> Res> { let id = self.stream_id; let final_size = self.final_size(); if let Some((offset, data)) = self.next_bytes() { @@ -549,14 +549,14 @@ impl SendStream { }; if overhead > builder.remaining() { qtrace!("SendStream::write_frame no space for header"); - return None; + return Ok(None); } let (length, fill) = Self::length_and_fill(data.len(), builder.remaining() - overhead); let fin = final_size.map_or(false, |fs| fs == offset + u64::try_from(length).unwrap()); if length == 0 && !fin { qtrace!("SendStream::write_frame no data, no fin"); - return None; + return Ok(None); } // Write the stream out. @@ -571,15 +571,19 @@ impl SendStream { builder.encode_vvec(&data[..length]); } + if builder.len() > builder.limit() { + return Err(Error::InternalError(23)); + } + self.mark_as_sent(offset, length, fin); - Some(RecoveryToken::Stream(StreamRecoveryToken { + Ok(Some(RecoveryToken::Stream(StreamRecoveryToken { id, offset, length, fin, - })) + }))) } else { - None + Ok(None) } } @@ -868,13 +872,14 @@ impl SendStreams { builder: &mut PacketBuilder, tokens: &mut Vec, stats: &mut FrameStats, - ) { + ) -> Res<()> { for (_, stream) in self { - if let Some(t) = stream.write_frame(builder) { + if let Some(t) = stream.write_frame(builder)? { tokens.push(t); stats.stream += 1; } } + Ok(()) } } @@ -1315,7 +1320,8 @@ mod tests { // Write a small frame: no fin. let written = builder.len(); builder.set_limit(written + 6); - ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()) + .unwrap(); assert_eq!(builder.len(), written + 6); assert_eq!(tokens.len(), 1); let f1_token = tokens.remove(0); @@ -1324,7 +1330,8 @@ mod tests { // Write the rest: fin. let written = builder.len(); builder.set_limit(written + 200); - ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()) + .unwrap(); assert_eq!(builder.len(), written + 10); assert_eq!(tokens.len(), 1); let f2_token = tokens.remove(0); @@ -1332,7 +1339,8 @@ mod tests { // Should be no more data to frame. let written = builder.len(); - ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()) + .unwrap(); assert_eq!(builder.len(), written); assert!(tokens.is_empty()); @@ -1346,7 +1354,8 @@ mod tests { // Next frame should not set fin even though stream has fin but frame // does not include end of stream let written = builder.len(); - ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()) + .unwrap(); assert_eq!(builder.len(), written + 7); // Needs a length this time. assert_eq!(tokens.len(), 1); let f4_token = tokens.remove(0); @@ -1361,7 +1370,8 @@ mod tests { // Next frame should set fin because it includes end of stream let written = builder.len(); - ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()) + .unwrap(); assert_eq!(builder.len(), written + 10); assert_eq!(tokens.len(), 1); let f5_token = tokens.remove(0); @@ -1384,19 +1394,22 @@ mod tests { let mut tokens = Vec::new(); let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); - ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()) + .unwrap(); let f1_token = tokens.remove(0); assert!(matches!(&f1_token, RecoveryToken::Stream(x) if x.offset == 0)); assert!(matches!(&f1_token, RecoveryToken::Stream(x) if x.length == 10)); assert!(matches!(&f1_token, RecoveryToken::Stream(x) if !x.fin)); // Should be no more data to frame - ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()) + .unwrap(); assert!(tokens.is_empty()); ss.get_mut(StreamId::from(0)).unwrap().close(); - ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()) + .unwrap(); let f2_token = tokens.remove(0); assert!(matches!(&f2_token, RecoveryToken::Stream(x) if x.offset == 10)); assert!(matches!(&f2_token, RecoveryToken::Stream(x) if x.length == 0)); @@ -1410,7 +1423,8 @@ mod tests { } // Next frame should set fin - ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()) + .unwrap(); let f3_token = tokens.remove(0); assert!(matches!(&f3_token, RecoveryToken::Stream(x) if x.offset == 10)); assert!(matches!(&f3_token, RecoveryToken::Stream(x) if x.length == 0)); @@ -1424,7 +1438,8 @@ mod tests { } // Next frame should set fin and include all data - ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()) + .unwrap(); let f4_token = tokens.remove(0); assert!(matches!(&f4_token, RecoveryToken::Stream(x) if x.offset == 0)); assert!(matches!(&f4_token, RecoveryToken::Stream(x) if x.length == 10)); @@ -1551,7 +1566,7 @@ mod tests { // No frame should be sent here. let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); - assert!(s.write_frame(&mut builder).is_none()); + assert!(s.write_frame(&mut builder).unwrap().is_none()); } /// Create a `SendStream` and force it into a state where it believes that @@ -1591,7 +1606,7 @@ mod tests { let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); let header_len = builder.len(); builder.set_limit(header_len + space); - let token = s.write_frame(&mut builder); + let token = s.write_frame(&mut builder).unwrap(); qtrace!("STREAM frame: {}", hex_with_len(&builder[header_len..])); token.is_some() } @@ -1685,7 +1700,7 @@ mod tests { let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); let header_len = builder.len(); builder.set_limit(header_len + DATA16384.len() + 2); - let token = s.write_frame(&mut builder); + let token = s.write_frame(&mut builder).unwrap(); assert!(token.is_some()); // Expect STREAM + FIN only. assert_eq!(&builder[header_len..header_len + 2], &[0b1001, 0]); @@ -1700,7 +1715,7 @@ mod tests { let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); let header_len = builder.len(); builder.set_limit(header_len + DATA16384.len() + 3); - let token = s.write_frame(&mut builder); + let token = s.write_frame(&mut builder).unwrap(); assert!(token.is_some()); // Expect STREAM + LEN + FIN. assert_eq!( @@ -1726,7 +1741,7 @@ mod tests { let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); let header_len = builder.len(); builder.set_limit(header_len + 66); - let token = s.write_frame(&mut builder); + let token = s.write_frame(&mut builder).unwrap(); assert!(token.is_some()); // Expect STREAM + FIN only. assert_eq!(&builder[header_len..header_len + 2], &[0b1001, 0]); @@ -1737,7 +1752,7 @@ mod tests { let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); let header_len = builder.len(); builder.set_limit(header_len + 67); - let token = s.write_frame(&mut builder); + let token = s.write_frame(&mut builder).unwrap(); assert!(token.is_some()); // Expect STREAM + LEN, not FIN. assert_eq!(&builder[header_len..header_len + 3], &[0b1010, 0, 63]); diff --git a/neqo-transport/src/tracking.rs b/neqo-transport/src/tracking.rs index ac01e83a1d..52e50e56ab 100644 --- a/neqo-transport/src/tracking.rs +++ b/neqo-transport/src/tracking.rs @@ -20,6 +20,7 @@ use neqo_crypto::{Epoch, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL}; use crate::packet::{PacketBuilder, PacketNumber, PacketType}; use crate::recovery::RecoveryToken; use crate::stats::FrameStats; +use crate::{Error, Res}; use smallvec::{smallvec, SmallVec}; @@ -636,9 +637,15 @@ impl AckTracker { now: Instant, builder: &mut PacketBuilder, stats: &mut FrameStats, - ) -> Option { - self.get_mut(pn_space) - .and_then(|space| space.write_frame(now, builder, stats)) + ) -> Res> { + let res = self + .get_mut(pn_space) + .and_then(|space| space.write_frame(now, builder, stats)); + + if builder.len() > builder.limit() { + return Err(Error::InternalError(24)); + } + Ok(res) } } @@ -848,12 +855,14 @@ mod tests { .set_received(*NOW, 0, true); // The reference time for `ack_time` has to be in the past or we filter out the timer. assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_some()); - let token = tracker.write_frame( - PNSpace::Initial, - *NOW, - &mut builder, - &mut FrameStats::default(), - ); + let token = tracker + .write_frame( + PNSpace::Initial, + *NOW, + &mut builder, + &mut FrameStats::default(), + ) + .unwrap(); assert!(token.is_some()); // Mark another packet as received so we have cause to send another ACK in that space. @@ -875,6 +884,7 @@ mod tests { &mut builder, &mut FrameStats::default() ) + .unwrap() .is_none()); if let RecoveryToken::Ack(tok) = token.unwrap() { tracker.acked(&tok); // Should be a noop. @@ -895,12 +905,14 @@ mod tests { let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); builder.set_limit(10); - let token = tracker.write_frame( - PNSpace::Initial, - *NOW, - &mut builder, - &mut FrameStats::default(), - ); + let token = tracker + .write_frame( + PNSpace::Initial, + *NOW, + &mut builder, + &mut FrameStats::default(), + ) + .unwrap(); assert!(token.is_none()); assert_eq!(builder.len(), 1); // Only the short packet header has been added. } @@ -921,12 +933,14 @@ mod tests { let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); builder.set_limit(32); - let token = tracker.write_frame( - PNSpace::Initial, - *NOW, - &mut builder, - &mut FrameStats::default(), - ); + let token = tracker + .write_frame( + PNSpace::Initial, + *NOW, + &mut builder, + &mut FrameStats::default(), + ) + .unwrap(); assert!(token.is_some()); let mut dec = builder.as_decoder();