diff --git a/CHANGELOG.md b/CHANGELOG.md index 3248cff..6915805 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ * Drop connection if client overflows concurrent streams number multiple times +* Drop connection number of resets more than 50% of total requests + ## [0.4.2] - 2023-10-09 * Add client streams helper methods diff --git a/src/connection.rs b/src/connection.rs index a696145..cb05253 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -41,6 +41,9 @@ struct ConnectionState { active_local_streams: Cell, readiness: RefCell>>, + rst_count: Cell, + total_count: Cell, + // Local config local_config: Config, // Maximum number of locally initiated streams @@ -94,6 +97,8 @@ impl Connection { streams: RefCell::new(HashMap::default()), active_remote_streams: Cell::new(0), active_local_streams: Cell::new(0), + rst_count: Cell::new(0), + total_count: Cell::new(0), readiness: RefCell::new(VecDeque::new()), next_stream_id: Cell::new(StreamId::new(1)), local_config: config, @@ -463,6 +468,7 @@ impl RecvHalfConnection { } else { let stream = StreamRef::new(id, true, Connection(self.0.clone())); self.0.next_stream_id.set(id); + self.0.total_count.set(self.0.total_count.get() + 1); self.0.streams.borrow_mut().insert(id, stream.clone()); self.0 .active_remote_streams @@ -491,7 +497,6 @@ impl RecvHalfConnection { )); Ok(None) } else { - println!("66664"); Err(Either::Left(ConnectionError::InvalidStreamId( "Received data", ))) @@ -632,6 +637,17 @@ impl RecvHalfConnection { } } + fn update_rst_count(&self) -> Result<(), Either> { + let count = self.0.rst_count.get() + 1; + let total_count = self.0.total_count.get(); + if total_count >= 10 && count >= total_count >> 1 { + Err(Either::Left(ConnectionError::ConcurrencyOverflow)) + } else { + self.0.rst_count.set(count); + Ok(()) + } + } + pub(crate) fn recv_rst_stream( &self, frm: frame::Reset, @@ -642,13 +658,16 @@ impl RecvHalfConnection { Err(Either::Left(ConnectionError::UnknownStream("RST_STREAM"))) } else if let Some(stream) = self.query(frm.stream_id()) { stream.recv_rst_stream(&frm); + self.update_rst_count()?; + Err(Either::Right(StreamErrorInner::new( stream, StreamError::Reset(frm.reason()), ))) } else if self.0.local_reset_ids.borrow().contains(&frm.stream_id()) { - Ok(()) + self.update_rst_count() } else { + self.update_rst_count()?; Err(Either::Left(ConnectionError::UnknownStream("RST_STREAM"))) } } @@ -745,7 +764,6 @@ impl fmt::Debug for Connection { async fn delay_drop_task(state: Connection) { state.set_flags(ConnectionFlags::DELAY_DROP_TASK_STARTED); - #[allow(clippy::while_let_loop)] loop { let next = if let Some(item) = state.0.local_reset_queue.borrow().front() { item.1 - now() diff --git a/src/stream.rs b/src/stream.rs index f27cd17..b603df7 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -213,11 +213,10 @@ impl StreamState { // stream is closed if reason.is_some() { log::trace!("{:?} is closed with local reset, dropping stream", self.id); - self.con.drop_stream(self.id); } else { log::trace!("{:?} both sides are closed, dropping stream", self.id); - self.con.drop_stream(self.id); } + self.con.drop_stream(self.id); } } } diff --git a/tests/connection.rs b/tests/connection.rs index f6bc9b7..6ca6a74 100644 --- a/tests/connection.rs +++ b/tests/connection.rs @@ -316,3 +316,119 @@ async fn test_goaway_on_overflow() { assert_eq!(res.reason(), Reason::FLOW_CONTROL_ERROR); assert!(io.recv(&codec).await.unwrap().is_none()); } + +#[ntex::test] +async fn test_goaway_on_reset() { + let srv = start_server(); + let addr = srv.addr(); + + let io = connect(addr).await; + let codec = Codec::default(); + let _ = io.with_write_buf(|buf| buf.extend_from_slice(&PREFACE)); + + let settings = frame::Settings::default(); + io.encode(settings.into(), &codec).unwrap(); + + // settings & window + let _ = io.recv(&codec).await; + let _ = io.recv(&codec).await; + let _ = io.recv(&codec).await; + + let mut id = frame::StreamId::CLIENT; + let pseudo = frame::PseudoHeaders { + method: Some(Method::GET), + scheme: Some("HTTPS".into()), + authority: Some("localhost".into()), + path: Some("/".into()), + ..Default::default() + }; + for _ in 0..5 { + let hdrs = frame::Headers::new(id, pseudo.clone(), HeaderMap::new(), true); + id = id.next_id().unwrap(); + io.send(hdrs.into(), &codec).await.unwrap(); + io.recv(&codec).await.unwrap().unwrap(); // headers + io.recv(&codec).await.unwrap().unwrap(); // data + io.recv(&codec).await.unwrap().unwrap(); // data eof + } + + for _ in 0..4 { + let rst = frame::Reset::new(id, Reason::NO_ERROR); + let hdrs = frame::Headers::new(id, pseudo.clone(), HeaderMap::new(), false); + id = id.next_id().unwrap(); + io.encode(hdrs.into(), &codec).unwrap(); + io.send(rst.into(), &codec).await.unwrap(); + io.recv(&codec).await.unwrap().unwrap(); // headers + } + let rst = frame::Reset::new(id, Reason::NO_ERROR); + let hdrs = frame::Headers::new(id, pseudo.clone(), HeaderMap::new(), false); + io.encode(hdrs.into(), &codec).unwrap(); + io.send(rst.into(), &codec).await.unwrap(); + let res = if let frame::Frame::GoAway(rst) = io.recv(&codec).await.unwrap().unwrap() { + rst + } else { + panic!() + }; + assert_eq!(res.reason(), Reason::FLOW_CONTROL_ERROR); + assert!(io.recv(&codec).await.unwrap().is_none()); +} + +#[ntex::test] +async fn test_goaway_on_reset2() { + let srv = start_server(); + let addr = srv.addr(); + + let io = connect(addr).await; + let codec = Codec::default(); + let _ = io.with_write_buf(|buf| buf.extend_from_slice(&PREFACE)); + + let settings = frame::Settings::default(); + io.encode(settings.into(), &codec).unwrap(); + + // settings & window + let _ = io.recv(&codec).await; + let _ = io.recv(&codec).await; + let _ = io.recv(&codec).await; + + let mut id = frame::StreamId::CLIENT; + let pseudo = frame::PseudoHeaders { + method: Some(Method::GET), + scheme: Some("HTTPS".into()), + authority: Some("localhost".into()), + path: Some("/".into()), + ..Default::default() + }; + for _ in 0..5 { + let hdrs = frame::Headers::new(id, pseudo.clone(), HeaderMap::new(), true); + id = id.next_id().unwrap(); + io.send(hdrs.into(), &codec).await.unwrap(); + io.recv(&codec).await.unwrap().unwrap(); // headers + io.recv(&codec).await.unwrap().unwrap(); // data + io.recv(&codec).await.unwrap().unwrap(); // data eof + } + + for _ in 0..4 { + let rst = frame::Reset::new(id, Reason::NO_ERROR); + let hdrs = frame::Headers::new(id, pseudo.clone(), HeaderMap::new(), true); + id = id.next_id().unwrap(); + io.encode(hdrs.into(), &codec).unwrap(); + io.send(rst.into(), &codec).await.unwrap(); + io.recv(&codec).await.unwrap().unwrap(); // headers + io.recv(&codec).await.unwrap().unwrap(); // data + io.recv(&codec).await.unwrap().unwrap(); // data eof + } + let rst = frame::Reset::new(id, Reason::NO_ERROR); + let hdrs = frame::Headers::new(id, pseudo.clone(), HeaderMap::new(), true); + io.encode(hdrs.into(), &codec).unwrap(); + io.send(rst.into(), &codec).await.unwrap(); + io.recv(&codec).await.unwrap().unwrap(); // headers + io.recv(&codec).await.unwrap().unwrap(); // data + io.recv(&codec).await.unwrap().unwrap(); // data eof + + let res = if let frame::Frame::GoAway(rst) = io.recv(&codec).await.unwrap().unwrap() { + rst + } else { + panic!() + }; + assert_eq!(res.reason(), Reason::FLOW_CONTROL_ERROR); + assert!(io.recv(&codec).await.unwrap().is_none()); +}