Skip to content

Commit

Permalink
Fix anti-deadlock PTO during handshake (Tencent#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaofei0800 authored Jan 12, 2024
1 parent ba0de61 commit 414f8ac
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 11 deletions.
30 changes: 30 additions & 0 deletions src/connection/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4746,6 +4746,36 @@ pub(crate) mod tests {
Ok(())
}

#[test]
fn handshake_with_antiamplification_deadlock() -> Result<()> {
    // Regression test: when part of the server's first flight is lost, the
    // client must keep its loss detection timer armed so that the handshake
    // cannot stall forever.
    let mut test_pair = TestPair::new_with_test_config()?;

    // The client emits its Initial flight and the server consumes it.
    let client_initial = TestPair::conn_packets_out(&mut test_pair.client)?;
    TestPair::conn_packets_in(&mut test_pair.server, client_initial)?;

    // The server replies with its Initial and Handshake flight. Only the
    // first packet is kept, so everything after it is treated as dropped.
    let mut server_flight = TestPair::conn_packets_out(&mut test_pair.server)?;
    server_flight.truncate(1);

    // The client processes the surviving packet; the TLS handshake must
    // still be incomplete at this point.
    TestPair::conn_packets_in(&mut test_pair.client, server_flight)?;
    assert!(!test_pair.client.tls_session.is_completed());

    // The client sends its response (ACK and PADDING) and is left waiting
    // for the dropped data to be retransmitted.
    let _ = TestPair::conn_packets_out(&mut test_pair.client)?;

    // The connection must report a pending timeout, and specifically the
    // `LossDetection` timer must be armed; otherwise the handshake deadlocks.
    assert!(test_pair.client.timeout().is_some());
    assert!(test_pair.client.timers.get(Timer::LossDetection).is_some());

    // TODO: complete the remaining part after supporting anti-amplification in server side.

    Ok(())
}

#[test]
fn handshake_with_alpn_mismatched() -> Result<()> {
let mut client_config = TestPair::new_test_config(false)?;
Expand Down
55 changes: 44 additions & 11 deletions src/connection/recovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ pub struct Recovery {
/// declared lost. The size does not include IP or UDP overhead.
pub bytes_in_flight: usize,

/// Number of ack-eliciting packets in flight.
pub ack_eliciting_in_flight: u64,

/// RTT estimation for the corresponding path.
pub rtt: RttEstimator,

Expand All @@ -104,6 +107,7 @@ impl Recovery {
pkt_thresh: INITIAL_PACKET_THRESHOLD,
time_thresh: INITIAL_TIME_THRESHOLD,
bytes_in_flight: 0,
ack_eliciting_in_flight: 0,
rtt: RttEstimator::new(conf.initial_rtt),
congestion: congestion_control::build_congestion_controller(conf),
trace_id: String::from(""),
Expand Down Expand Up @@ -167,6 +171,8 @@ impl Recovery {
if ack_eliciting {
space.time_of_last_sent_ack_eliciting_pkt = Some(now);
space.loss_probes = space.loss_probes.saturating_sub(1);
space.ack_eliciting_in_flight += 1;
self.ack_eliciting_in_flight += 1;
}

space.bytes_in_flight += sent_size;
Expand Down Expand Up @@ -295,6 +301,13 @@ impl Recovery {
space.bytes_in_flight =
space.bytes_in_flight.saturating_sub(sent_pkt.sent_size);
self.bytes_in_flight = self.bytes_in_flight.saturating_sub(sent_pkt.sent_size);

if sent_pkt.ack_eliciting {
space.ack_eliciting_in_flight =
space.ack_eliciting_in_flight.saturating_sub(1);
self.ack_eliciting_in_flight =
self.ack_eliciting_in_flight.saturating_sub(1);
}
}

// Process each acked packet in congestion controller and update delivery
Expand Down Expand Up @@ -397,6 +410,13 @@ impl Recovery {
lost_bytes += unacked.sent_size;
space.bytes_in_flight = space.bytes_in_flight.saturating_sub(unacked.sent_size);
self.bytes_in_flight = self.bytes_in_flight.saturating_sub(unacked.sent_size);

if unacked.ack_eliciting {
space.ack_eliciting_in_flight =
space.ack_eliciting_in_flight.saturating_sub(1);
self.ack_eliciting_in_flight =
self.ack_eliciting_in_flight.saturating_sub(1);
}
}
latest_lost_packet = Some(unacked.clone());
trace!(
Expand Down Expand Up @@ -484,7 +504,7 @@ impl Recovery {

// TODO: The server's timer is not set if nothing can be sent.

if self.bytes_in_flight == 0 && handshake_status.peer_verified_address {
if self.ack_eliciting_in_flight == 0 && handshake_status.peer_verified_address {
// Nothing is in flight whose loss could be detected, so no timer is set.
// However, the client needs to arm the timer if the
// server might be blocked by the anti-amplification limit.
Expand Down Expand Up @@ -530,7 +550,7 @@ impl Recovery {
}

// PTO timer mode (REVISIT)
let sid = if self.bytes_in_flight > 0 {
let sid = if self.ack_eliciting_in_flight > 0 {
// Send new data if available, else retransmit old data. If neither
// is available, send a single PING frame.
let (_, e) = self.get_pto_time_and_space(space_id, spaces, handshake_status, now);
Expand Down Expand Up @@ -642,8 +662,8 @@ impl Recovery {
) -> (Option<Instant>, SpaceId) {
let mut duration = self.calculate_pto();

// Arm PTO from now when there are no inflight packets.
if self.bytes_in_flight == 0 {
// Arm PTO from now when there are no ack-eliciting packets inflight.
if self.ack_eliciting_in_flight == 0 {
if handshake_status.derived_handshake_keys {
return (Some(now + duration), SpaceId::Handshake);
} else {
Expand All @@ -665,7 +685,7 @@ impl Recovery {
Some(space) => space,
None => continue,
};
if space.bytes_in_flight == 0 {
if space.ack_eliciting_in_flight == 0 {
continue;
}

Expand Down Expand Up @@ -720,6 +740,7 @@ impl Recovery {
space.loss_time = None;
space.loss_probes = 0;
space.bytes_in_flight = 0;
space.ack_eliciting_in_flight = 0;
self.set_loss_detection_timer(space_id, spaces, handshake_status, now);
}

Expand All @@ -728,12 +749,16 @@ impl Recovery {
/// When Initial or Handshake keys are discarded, packets sent in that
/// space no longer count toward bytes in flight.
///
/// Every packet in `space.sent` that is still outstanding (marked
/// `in_flight` and neither acked nor declared lost) is removed from the
/// path-level accounting exactly once: its size is subtracted from
/// `bytes_in_flight`, and if it was ack-eliciting the path-level
/// `ack_eliciting_in_flight` counter is decremented as well.
///
/// Only the counters on `self` are touched; the per-space counters in
/// `space` are left unchanged by this method.
fn remove_from_bytes_in_flight(&mut self, space: &PacketNumSpace) {
    // Packets that were already acked or declared lost have been removed
    // from the counters before, so only the outstanding ones are visited.
    let outstanding = space
        .sent
        .iter()
        .filter(|p| p.in_flight && p.time_acked.is_none() && p.time_lost.is_none());

    for pkt in outstanding {
        // Saturating arithmetic keeps the counters at zero instead of
        // underflowing if the accounting ever gets out of sync.
        self.bytes_in_flight = self.bytes_in_flight.saturating_sub(pkt.sent_size);
        if pkt.ack_eliciting {
            self.ack_eliciting_in_flight = self.ack_eliciting_in_flight.saturating_sub(1);
        }
    }
}

/// Update maximum datagram size
Expand Down Expand Up @@ -830,7 +855,9 @@ mod tests {
recovery.on_packet_sent(sent_pkt2, space_id, &mut spaces, status, now);
assert_eq!(spaces.get(space_id).unwrap().sent.len(), 3);
assert_eq!(spaces.get(space_id).unwrap().bytes_in_flight, 3003);
assert_eq!(spaces.get(space_id).unwrap().ack_eliciting_in_flight, 3);
assert_eq!(recovery.bytes_in_flight, 3003);
assert_eq!(recovery.ack_eliciting_in_flight, 3);

// Advance ticks and fake receiving of ack
now += Duration::from_millis(100);
Expand All @@ -839,13 +866,17 @@ mod tests {
acked.insert(2..3);
recovery.on_ack_received(&acked, 0, SpaceId::Handshake, &mut spaces, status, now)?;
assert_eq!(spaces.get(space_id).unwrap().sent.len(), 2);
assert_eq!(spaces.get(space_id).unwrap().ack_eliciting_in_flight, 1);
assert_eq!(recovery.ack_eliciting_in_flight, 1);

// Advance ticks until loss timeout
now = recovery.loss_detection_timer().unwrap();
let (lost_pkts, lost_bytes) =
recovery.on_loss_detection_timeout(SpaceId::Handshake, &mut spaces, status, now);
assert_eq!(lost_pkts, 1);
assert_eq!(lost_bytes, 1001);
assert_eq!(spaces.get(space_id).unwrap().ack_eliciting_in_flight, 0);
assert_eq!(recovery.ack_eliciting_in_flight, 0);

Ok(())
}
Expand Down Expand Up @@ -1026,7 +1057,9 @@ mod tests {
recovery.on_pkt_num_space_discarded(space_id, &mut spaces, status, now);
assert_eq!(spaces.get(space_id).unwrap().sent.len(), 0);
assert_eq!(spaces.get(space_id).unwrap().bytes_in_flight, 0);
assert_eq!(spaces.get(space_id).unwrap().ack_eliciting_in_flight, 0);
assert_eq!(recovery.bytes_in_flight, 1003);
assert_eq!(recovery.ack_eliciting_in_flight, 1);

Ok(())
}
Expand Down
4 changes: 4 additions & 0 deletions src/connection/space.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ pub struct PacketNumSpace {
/// number space.
pub bytes_in_flight: usize,

/// Number of ack-eliciting packets in flight.
pub ack_eliciting_in_flight: u64,

/// Packet number space for application data
pub is_data: bool,

Expand Down Expand Up @@ -146,6 +149,7 @@ impl PacketNumSpace {
largest_acked_pkt: std::u64::MAX,
loss_probes: 0,
bytes_in_flight: 0,
ack_eliciting_in_flight: 0,
is_data: id != SpaceId::Initial && id != SpaceId::Handshake,
reinject: ReinjectQueue::default(),
}
Expand Down

0 comments on commit 414f8ac

Please sign in to comment.