Skip to content

Commit

Permalink
feat(volo-thrift): close connection when encounter error
Browse files Browse the repository at this point in the history
  • Loading branch information
PureWhiteWu committed Aug 7, 2024
1 parent ea932cf commit 7e77528
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 19 deletions.
14 changes: 8 additions & 6 deletions volo-thrift/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ impl std::ops::DerefMut for ServerContext {
}

pub trait ThriftContext: volo::context::Context<Config = Config> + Send + 'static {
fn encode_conn_reset(&self) -> Option<bool>;
fn encode_conn_reset(&self) -> bool;
fn set_conn_reset_by_ttheader(&mut self, reset: bool);
fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier);
fn seq_id(&self) -> i32;
Expand All @@ -307,8 +307,8 @@ pub trait ThriftContext: volo::context::Context<Config = Config> + Send + 'stati

impl ThriftContext for ClientContext {
#[inline]
fn encode_conn_reset(&self) -> Option<bool> {
None
fn encode_conn_reset(&self) -> bool {
false
}

#[inline]
Expand Down Expand Up @@ -342,12 +342,14 @@ impl ThriftContext for ClientContext {

impl ThriftContext for ServerContext {
#[inline]
fn encode_conn_reset(&self) -> Option<bool> {
Some(self.transport.is_conn_reset())
fn encode_conn_reset(&self) -> bool {
self.transport.is_conn_reset()
}

#[inline]
fn set_conn_reset_by_ttheader(&mut self, _reset: bool) {}
fn set_conn_reset_by_ttheader(&mut self, reset: bool) {
self.transport.set_conn_reset(reset)
}

#[inline]
fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier) {
Expand Down
17 changes: 11 additions & 6 deletions volo-thrift/src/transport/multiplex/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use volo::{context::Context, net::Address, volo_unreachable};

use crate::{
codec::{Decoder, Encoder},
context::ServerContext,
context::{ServerContext, ThriftContext as _},
protocol::TMessageType,
server_error_to_application_exception, thrift_exception_to_application_exception, DummyMessage,
EntryMessage, ServerError, ThriftMessage,
Expand Down Expand Up @@ -40,7 +40,8 @@ pub async fn serve<Svc, Req, Resp, E, D>(

// mpsc channel used to send responses to the loop
let (send_tx, mut send_rx) = mpsc::channel(CHANNEL_SIZE);
let (error_send_tx, mut error_send_rx) = mpsc::channel(1);
let (error_send_tx, mut error_send_rx) =
mpsc::channel::<(ServerContext, ThriftMessage<DummyMessage>)>(1);

tokio::spawn({
let peer_addr = peer_addr.clone();
Expand Down Expand Up @@ -70,6 +71,9 @@ pub async fn serve<Svc, Req, Resp, E, D>(
return;
}
stat_tracer.iter().for_each(|f| f(&cx));
if cx.encode_conn_reset() {
return;
}
}
None => {
// log it
Expand All @@ -85,6 +89,7 @@ pub async fn serve<Svc, Req, Resp, E, D>(
error_msg = error_send_rx.recv() => {
match error_msg {
Some((mut cx, msg)) => {
cx.set_conn_reset_by_ttheader(true);
if let Err(e) = encoder
.encode::<DummyMessage, ServerContext>(&mut cx, msg)
.await
Expand Down Expand Up @@ -185,11 +190,11 @@ pub async fn serve<Svc, Req, Resp, E, D>(
metainfo::METAINFO
.scope(RefCell::new(mi), async move {
cx.stats.record_process_start_at();
let resp = svc.call(&mut cx, req).await;
let resp = svc.call(&mut cx, req).await.map_err(Into::into);
cx.stats.record_process_end_at();

if exit_mark.load(Ordering::Relaxed) {
cx.transport.set_conn_reset(true);
if exit_mark.load(Ordering::Relaxed) || matches!(resp, Err(ServerError::Application(_))) {
cx.set_conn_reset_by_ttheader(true);
}
let req_msg_type =
cx.req_msg_type.expect("`req_msg_type` should be set.");
Expand All @@ -201,7 +206,7 @@ pub async fn serve<Svc, Req, Resp, E, D>(
let msg = ThriftMessage::mk_server_resp(
&cx,
resp.map_err(|e| {
server_error_to_application_exception(e.into())
server_error_to_application_exception(e)
}),
);
let mi = metainfo::METAINFO.with(|m| m.take());
Expand Down
18 changes: 11 additions & 7 deletions volo-thrift/src/transport/pingpong/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use volo::{net::Address, volo_unreachable};

use crate::{
codec::{Decoder, Encoder},
context::{ServerContext, SERVER_CONTEXT_CACHE},
context::{ServerContext, ThriftContext, SERVER_CONTEXT_CACHE},
protocol::TMessageType,
server_error_to_application_exception, thrift_exception_to_application_exception,
tracing::SpanProvider,
Expand Down Expand Up @@ -81,11 +81,13 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
match msg {
Ok(Some(ThriftMessage { data: Ok(req), .. })) => {
cx.stats.record_process_start_at();
let resp = service.call(&mut cx, req).await;
let resp = service.call(&mut cx, req).await.map_err(Into::into);
cx.stats.record_process_end_at();

if exit_mark.load(Ordering::Relaxed) {
cx.transport.set_conn_reset(true);
if exit_mark.load(Ordering::Relaxed)
|| matches!(resp, Err(ServerError::Application(_)))
{
cx.set_conn_reset_by_ttheader(true);
}

let req_msg_type =
Expand All @@ -98,9 +100,7 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
});
let msg = ThriftMessage::mk_server_resp(
&cx,
resp.map_err(|e| {
server_error_to_application_exception(e.into())
}),
resp.map_err(|e| server_error_to_application_exception(e)),
);
if let Err(e) = async {
let result = encoder.encode(&mut cx, msg).await;
Expand All @@ -119,6 +119,9 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
return Err(());
}
}
if cx.transport.is_conn_reset() {
return Err(());
}
}
Ok(Some(ThriftMessage { data: Err(_), .. })) => {
volo_unreachable!();
Expand All @@ -138,6 +141,7 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
e, cx, peer_addr
);
cx.msg_type = Some(TMessageType::Exception);
cx.set_conn_reset_by_ttheader(true);
if !matches!(e, ThriftException::Transport(_)) {
let msg = ThriftMessage::mk_server_resp(
&cx,
Expand Down

0 comments on commit 7e77528

Please sign in to comment.