diff --git a/volo-thrift/src/codec/default/ttheader.rs b/volo-thrift/src/codec/default/ttheader.rs index 5b882362..1f6e7161 100644 --- a/volo-thrift/src/codec/default/ttheader.rs +++ b/volo-thrift/src/codec/default/ttheader.rs @@ -324,7 +324,7 @@ pub(crate) fn encode( } Role::Server => { metainfo.get_all_backward_transients().is_some() - || cx.encode_conn_reset().unwrap_or(false) + || cx.encode_conn_reset() || cx.stats().biz_error().is_some() } }; @@ -375,7 +375,7 @@ pub(crate) fn encode( string_kv_len += 1; } } - if cx.encode_conn_reset().unwrap_or(false) { + if cx.encode_conn_reset() { dst.put_u16(5); dst.put_slice("crrst".as_bytes()); dst.put_u16(1); @@ -582,8 +582,7 @@ pub(crate) fn encode_size(cx: &mut Cx) -> Result { - metainfo.get_all_backward_transients().is_some() - || thrift_cx.encode_conn_reset().unwrap_or(false) + metainfo.get_all_backward_transients().is_some() || thrift_cx.encode_conn_reset() } }; @@ -624,7 +623,7 @@ pub(crate) fn encode_size(cx: &mut Cx) -> Result + Send + 'static { - fn encode_conn_reset(&self) -> Option; + fn encode_conn_reset(&self) -> bool; fn set_conn_reset_by_ttheader(&mut self, reset: bool); fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier); fn seq_id(&self) -> i32; @@ -307,8 +307,8 @@ pub trait ThriftContext: volo::context::Context + Send + 'stati impl ThriftContext for ClientContext { #[inline] - fn encode_conn_reset(&self) -> Option { - None + fn encode_conn_reset(&self) -> bool { + false } #[inline] @@ -342,12 +342,14 @@ impl ThriftContext for ClientContext { impl ThriftContext for ServerContext { #[inline] - fn encode_conn_reset(&self) -> Option { - Some(self.transport.is_conn_reset()) + fn encode_conn_reset(&self) -> bool { + self.transport.is_conn_reset() } #[inline] - fn set_conn_reset_by_ttheader(&mut self, _reset: bool) {} + fn set_conn_reset_by_ttheader(&mut self, reset: bool) { + self.transport.set_conn_reset(reset) + } #[inline] fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier) { diff --git a/volo-thrift/src/transport/multiplex/server.rs b/volo-thrift/src/transport/multiplex/server.rs index 71b6f601..bd6b96e0 100644 --- a/volo-thrift/src/transport/multiplex/server.rs +++ b/volo-thrift/src/transport/multiplex/server.rs @@ -12,7 +12,7 @@ use volo::{context::Context, net::Address, volo_unreachable}; use crate::{ codec::{Decoder, Encoder}, - context::ServerContext, + context::{ServerContext, ThriftContext as _}, protocol::TMessageType, server_error_to_application_exception, thrift_exception_to_application_exception, DummyMessage, EntryMessage, ServerError, ThriftMessage, @@ -40,7 +40,8 @@ pub async fn serve( // mpsc channel used to send responses to the loop let (send_tx, mut send_rx) = mpsc::channel(CHANNEL_SIZE); - let (error_send_tx, mut error_send_rx) = mpsc::channel(1); + let (error_send_tx, mut error_send_rx) = + mpsc::channel::<(ServerContext, ThriftMessage)>(1); tokio::spawn({ let peer_addr = peer_addr.clone(); @@ -70,6 +71,9 @@ pub async fn serve( return; } stat_tracer.iter().for_each(|f| f(&cx)); + if cx.encode_conn_reset() { + return; + } } None => { // log it @@ -85,6 +89,7 @@ pub async fn serve( error_msg = error_send_rx.recv() => { match error_msg { Some((mut cx, msg)) => { + cx.set_conn_reset_by_ttheader(true); if let Err(e) = encoder .encode::(&mut cx, msg) .await @@ -185,11 +190,11 @@ pub async fn serve( metainfo::METAINFO .scope(RefCell::new(mi), async move { cx.stats.record_process_start_at(); - let resp = svc.call(&mut cx, req).await; + let resp = svc.call(&mut cx, req).await.map_err(Into::into); cx.stats.record_process_end_at(); if exit_mark.load(Ordering::Relaxed) { - cx.transport.set_conn_reset(true); + cx.set_conn_reset_by_ttheader(true); } let req_msg_type = cx.req_msg_type.expect("`req_msg_type` should be set."); @@ -200,9 +205,7 @@ pub async fn serve( }); let msg = ThriftMessage::mk_server_resp( &cx, - resp.map_err(|e| { - server_error_to_application_exception(e.into()) - }), + resp.map_err(server_error_to_application_exception), ); let mi = metainfo::METAINFO.with(|m| m.take()); let _ = send_tx.send((mi, cx, msg)).await; diff --git a/volo-thrift/src/transport/pingpong/server.rs b/volo-thrift/src/transport/pingpong/server.rs index 886bfac7..d7797449 100644 --- a/volo-thrift/src/transport/pingpong/server.rs +++ b/volo-thrift/src/transport/pingpong/server.rs @@ -12,7 +12,7 @@ use volo::{net::Address, volo_unreachable}; use crate::{ codec::{Decoder, Encoder}, - context::{ServerContext, SERVER_CONTEXT_CACHE}, + context::{ServerContext, ThriftContext, SERVER_CONTEXT_CACHE}, protocol::TMessageType, server_error_to_application_exception, thrift_exception_to_application_exception, tracing::SpanProvider, @@ -81,11 +81,11 @@ pub async fn serve( match msg { Ok(Some(ThriftMessage { data: Ok(req), .. })) => { cx.stats.record_process_start_at(); - let resp = service.call(&mut cx, req).await; + let resp = service.call(&mut cx, req).await.map_err(Into::into); cx.stats.record_process_end_at(); if exit_mark.load(Ordering::Relaxed) { - cx.transport.set_conn_reset(true); + cx.set_conn_reset_by_ttheader(true); } let req_msg_type = @@ -98,9 +98,7 @@ pub async fn serve( }); let msg = ThriftMessage::mk_server_resp( &cx, - resp.map_err(|e| { - server_error_to_application_exception(e.into()) - }), + resp.map_err(server_error_to_application_exception), ); if let Err(e) = async { let result = encoder.encode(&mut cx, msg).await; @@ -119,6 +117,9 @@ pub async fn serve( return Err(()); } } + if cx.transport.is_conn_reset() { + return Err(()); + } } Ok(Some(ThriftMessage { data: Err(_), .. })) => { volo_unreachable!(); @@ -138,6 +139,7 @@ pub async fn serve( e, cx, peer_addr ); cx.msg_type = Some(TMessageType::Exception); + cx.set_conn_reset_by_ttheader(true); if !matches!(e, ThriftException::Transport(_)) { let msg = ThriftMessage::mk_server_resp( &cx,