Skip to content

Commit

Permalink
feat(volo-thrift): close connection when encounter error (#482)
Browse files Browse the repository at this point in the history
  • Loading branch information
PureWhiteWu authored Aug 7, 2024
1 parent ea932cf commit 0c15227
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 24 deletions.
9 changes: 4 additions & 5 deletions volo-thrift/src/codec/default/ttheader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ pub(crate) fn encode<Cx: ThriftContext>(
}
Role::Server => {
metainfo.get_all_backward_transients().is_some()
|| cx.encode_conn_reset().unwrap_or(false)
|| cx.encode_conn_reset()
|| cx.stats().biz_error().is_some()
}
};
Expand Down Expand Up @@ -375,7 +375,7 @@ pub(crate) fn encode<Cx: ThriftContext>(
string_kv_len += 1;
}
}
if cx.encode_conn_reset().unwrap_or(false) {
if cx.encode_conn_reset() {
dst.put_u16(5);
dst.put_slice("crrst".as_bytes());
dst.put_u16(1);
Expand Down Expand Up @@ -582,8 +582,7 @@ pub(crate) fn encode_size<Cx: ThriftContext>(cx: &mut Cx) -> Result<usize, Thrif
metainfo.get_all_persistents().is_some() || metainfo.get_all_transients().is_some()
}
Role::Server => {
metainfo.get_all_backward_transients().is_some()
|| thrift_cx.encode_conn_reset().unwrap_or(false)
metainfo.get_all_backward_transients().is_some() || thrift_cx.encode_conn_reset()
}
};

Expand Down Expand Up @@ -624,7 +623,7 @@ pub(crate) fn encode_size<Cx: ThriftContext>(cx: &mut Cx) -> Result<usize, Thrif
len += value.as_bytes().len();
}
}
if thrift_cx.encode_conn_reset().unwrap_or(false) {
if thrift_cx.encode_conn_reset() {
len += 2;
len += "crrst".as_bytes().len();
len += 2;
Expand Down
14 changes: 8 additions & 6 deletions volo-thrift/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ impl std::ops::DerefMut for ServerContext {
}

pub trait ThriftContext: volo::context::Context<Config = Config> + Send + 'static {
fn encode_conn_reset(&self) -> Option<bool>;
fn encode_conn_reset(&self) -> bool;
fn set_conn_reset_by_ttheader(&mut self, reset: bool);
fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier);
fn seq_id(&self) -> i32;
Expand All @@ -307,8 +307,8 @@ pub trait ThriftContext: volo::context::Context<Config = Config> + Send + 'stati

impl ThriftContext for ClientContext {
#[inline]
fn encode_conn_reset(&self) -> Option<bool> {
None
fn encode_conn_reset(&self) -> bool {
false
}

#[inline]
Expand Down Expand Up @@ -342,12 +342,14 @@ impl ThriftContext for ClientContext {

impl ThriftContext for ServerContext {
#[inline]
fn encode_conn_reset(&self) -> Option<bool> {
Some(self.transport.is_conn_reset())
fn encode_conn_reset(&self) -> bool {
self.transport.is_conn_reset()
}

#[inline]
fn set_conn_reset_by_ttheader(&mut self, _reset: bool) {}
fn set_conn_reset_by_ttheader(&mut self, reset: bool) {
self.transport.set_conn_reset(reset)
}

#[inline]
fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier) {
Expand Down
17 changes: 10 additions & 7 deletions volo-thrift/src/transport/multiplex/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use volo::{context::Context, net::Address, volo_unreachable};

use crate::{
codec::{Decoder, Encoder},
context::ServerContext,
context::{ServerContext, ThriftContext as _},
protocol::TMessageType,
server_error_to_application_exception, thrift_exception_to_application_exception, DummyMessage,
EntryMessage, ServerError, ThriftMessage,
Expand Down Expand Up @@ -40,7 +40,8 @@ pub async fn serve<Svc, Req, Resp, E, D>(

// mpsc channel used to send responses to the loop
let (send_tx, mut send_rx) = mpsc::channel(CHANNEL_SIZE);
let (error_send_tx, mut error_send_rx) = mpsc::channel(1);
let (error_send_tx, mut error_send_rx) =
mpsc::channel::<(ServerContext, ThriftMessage<DummyMessage>)>(1);

tokio::spawn({
let peer_addr = peer_addr.clone();
Expand Down Expand Up @@ -70,6 +71,9 @@ pub async fn serve<Svc, Req, Resp, E, D>(
return;
}
stat_tracer.iter().for_each(|f| f(&cx));
if cx.encode_conn_reset() {
return;
}
}
None => {
// log it
Expand All @@ -85,6 +89,7 @@ pub async fn serve<Svc, Req, Resp, E, D>(
error_msg = error_send_rx.recv() => {
match error_msg {
Some((mut cx, msg)) => {
cx.set_conn_reset_by_ttheader(true);
if let Err(e) = encoder
.encode::<DummyMessage, ServerContext>(&mut cx, msg)
.await
Expand Down Expand Up @@ -185,11 +190,11 @@ pub async fn serve<Svc, Req, Resp, E, D>(
metainfo::METAINFO
.scope(RefCell::new(mi), async move {
cx.stats.record_process_start_at();
let resp = svc.call(&mut cx, req).await;
let resp = svc.call(&mut cx, req).await.map_err(Into::into);
cx.stats.record_process_end_at();

if exit_mark.load(Ordering::Relaxed) {
cx.transport.set_conn_reset(true);
cx.set_conn_reset_by_ttheader(true);
}
let req_msg_type =
cx.req_msg_type.expect("`req_msg_type` should be set.");
Expand All @@ -200,9 +205,7 @@ pub async fn serve<Svc, Req, Resp, E, D>(
});
let msg = ThriftMessage::mk_server_resp(
&cx,
resp.map_err(|e| {
server_error_to_application_exception(e.into())
}),
resp.map_err(server_error_to_application_exception),
);
let mi = metainfo::METAINFO.with(|m| m.take());
let _ = send_tx.send((mi, cx, msg)).await;
Expand Down
14 changes: 8 additions & 6 deletions volo-thrift/src/transport/pingpong/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use volo::{net::Address, volo_unreachable};

use crate::{
codec::{Decoder, Encoder},
context::{ServerContext, SERVER_CONTEXT_CACHE},
context::{ServerContext, ThriftContext, SERVER_CONTEXT_CACHE},
protocol::TMessageType,
server_error_to_application_exception, thrift_exception_to_application_exception,
tracing::SpanProvider,
Expand Down Expand Up @@ -81,11 +81,11 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
match msg {
Ok(Some(ThriftMessage { data: Ok(req), .. })) => {
cx.stats.record_process_start_at();
let resp = service.call(&mut cx, req).await;
let resp = service.call(&mut cx, req).await.map_err(Into::into);
cx.stats.record_process_end_at();

if exit_mark.load(Ordering::Relaxed) {
cx.transport.set_conn_reset(true);
cx.set_conn_reset_by_ttheader(true);
}

let req_msg_type =
Expand All @@ -98,9 +98,7 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
});
let msg = ThriftMessage::mk_server_resp(
&cx,
resp.map_err(|e| {
server_error_to_application_exception(e.into())
}),
resp.map_err(server_error_to_application_exception),
);
if let Err(e) = async {
let result = encoder.encode(&mut cx, msg).await;
Expand All @@ -119,6 +117,9 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
return Err(());
}
}
if cx.transport.is_conn_reset() {
return Err(());
}
}
Ok(Some(ThriftMessage { data: Err(_), .. })) => {
volo_unreachable!();
Expand All @@ -138,6 +139,7 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
e, cx, peer_addr
);
cx.msg_type = Some(TMessageType::Exception);
cx.set_conn_reset_by_ttheader(true);
if !matches!(e, ThriftException::Transport(_)) {
let msg = ThriftMessage::mk_server_resp(
&cx,
Expand Down

0 comments on commit 0c15227

Please sign in to comment.