Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(volo-thrift): close connection when encounter error #482

Merged
merged 1 commit into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions volo-thrift/src/codec/default/ttheader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ pub(crate) fn encode<Cx: ThriftContext>(
}
Role::Server => {
metainfo.get_all_backward_transients().is_some()
|| cx.encode_conn_reset().unwrap_or(false)
|| cx.encode_conn_reset()
|| cx.stats().biz_error().is_some()
}
};
Expand Down Expand Up @@ -375,7 +375,7 @@ pub(crate) fn encode<Cx: ThriftContext>(
string_kv_len += 1;
}
}
if cx.encode_conn_reset().unwrap_or(false) {
if cx.encode_conn_reset() {
dst.put_u16(5);
dst.put_slice("crrst".as_bytes());
dst.put_u16(1);
Expand Down Expand Up @@ -582,8 +582,7 @@ pub(crate) fn encode_size<Cx: ThriftContext>(cx: &mut Cx) -> Result<usize, Thrif
metainfo.get_all_persistents().is_some() || metainfo.get_all_transients().is_some()
}
Role::Server => {
metainfo.get_all_backward_transients().is_some()
|| thrift_cx.encode_conn_reset().unwrap_or(false)
metainfo.get_all_backward_transients().is_some() || thrift_cx.encode_conn_reset()
}
};

Expand Down Expand Up @@ -624,7 +623,7 @@ pub(crate) fn encode_size<Cx: ThriftContext>(cx: &mut Cx) -> Result<usize, Thrif
len += value.as_bytes().len();
}
}
if thrift_cx.encode_conn_reset().unwrap_or(false) {
if thrift_cx.encode_conn_reset() {
len += 2;
len += "crrst".as_bytes().len();
len += 2;
Expand Down
14 changes: 8 additions & 6 deletions volo-thrift/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ impl std::ops::DerefMut for ServerContext {
}

pub trait ThriftContext: volo::context::Context<Config = Config> + Send + 'static {
fn encode_conn_reset(&self) -> Option<bool>;
fn encode_conn_reset(&self) -> bool;
fn set_conn_reset_by_ttheader(&mut self, reset: bool);
fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier);
fn seq_id(&self) -> i32;
Expand All @@ -307,8 +307,8 @@ pub trait ThriftContext: volo::context::Context<Config = Config> + Send + 'stati

impl ThriftContext for ClientContext {
#[inline]
fn encode_conn_reset(&self) -> Option<bool> {
None
fn encode_conn_reset(&self) -> bool {
false
}

#[inline]
Expand Down Expand Up @@ -342,12 +342,14 @@ impl ThriftContext for ClientContext {

impl ThriftContext for ServerContext {
#[inline]
fn encode_conn_reset(&self) -> Option<bool> {
Some(self.transport.is_conn_reset())
fn encode_conn_reset(&self) -> bool {
self.transport.is_conn_reset()
}

#[inline]
fn set_conn_reset_by_ttheader(&mut self, _reset: bool) {}
fn set_conn_reset_by_ttheader(&mut self, reset: bool) {
self.transport.set_conn_reset(reset)
}

#[inline]
fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier) {
Expand Down
17 changes: 10 additions & 7 deletions volo-thrift/src/transport/multiplex/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use volo::{context::Context, net::Address, volo_unreachable};

use crate::{
codec::{Decoder, Encoder},
context::ServerContext,
context::{ServerContext, ThriftContext as _},
protocol::TMessageType,
server_error_to_application_exception, thrift_exception_to_application_exception, DummyMessage,
EntryMessage, ServerError, ThriftMessage,
Expand Down Expand Up @@ -40,7 +40,8 @@ pub async fn serve<Svc, Req, Resp, E, D>(

// mpsc channel used to send responses to the loop
let (send_tx, mut send_rx) = mpsc::channel(CHANNEL_SIZE);
let (error_send_tx, mut error_send_rx) = mpsc::channel(1);
let (error_send_tx, mut error_send_rx) =
mpsc::channel::<(ServerContext, ThriftMessage<DummyMessage>)>(1);

tokio::spawn({
let peer_addr = peer_addr.clone();
Expand Down Expand Up @@ -70,6 +71,9 @@ pub async fn serve<Svc, Req, Resp, E, D>(
return;
}
stat_tracer.iter().for_each(|f| f(&cx));
if cx.encode_conn_reset() {
return;
}
}
None => {
// log it
Expand All @@ -85,6 +89,7 @@ pub async fn serve<Svc, Req, Resp, E, D>(
error_msg = error_send_rx.recv() => {
match error_msg {
Some((mut cx, msg)) => {
cx.set_conn_reset_by_ttheader(true);
if let Err(e) = encoder
.encode::<DummyMessage, ServerContext>(&mut cx, msg)
.await
Expand Down Expand Up @@ -185,11 +190,11 @@ pub async fn serve<Svc, Req, Resp, E, D>(
metainfo::METAINFO
.scope(RefCell::new(mi), async move {
cx.stats.record_process_start_at();
let resp = svc.call(&mut cx, req).await;
let resp = svc.call(&mut cx, req).await.map_err(Into::into);
cx.stats.record_process_end_at();

if exit_mark.load(Ordering::Relaxed) {
cx.transport.set_conn_reset(true);
cx.set_conn_reset_by_ttheader(true);
}
let req_msg_type =
cx.req_msg_type.expect("`req_msg_type` should be set.");
Expand All @@ -200,9 +205,7 @@ pub async fn serve<Svc, Req, Resp, E, D>(
});
let msg = ThriftMessage::mk_server_resp(
&cx,
resp.map_err(|e| {
server_error_to_application_exception(e.into())
}),
resp.map_err(server_error_to_application_exception),
);
let mi = metainfo::METAINFO.with(|m| m.take());
let _ = send_tx.send((mi, cx, msg)).await;
Expand Down
14 changes: 8 additions & 6 deletions volo-thrift/src/transport/pingpong/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use volo::{net::Address, volo_unreachable};

use crate::{
codec::{Decoder, Encoder},
context::{ServerContext, SERVER_CONTEXT_CACHE},
context::{ServerContext, ThriftContext, SERVER_CONTEXT_CACHE},
protocol::TMessageType,
server_error_to_application_exception, thrift_exception_to_application_exception,
tracing::SpanProvider,
Expand Down Expand Up @@ -81,11 +81,11 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
match msg {
Ok(Some(ThriftMessage { data: Ok(req), .. })) => {
cx.stats.record_process_start_at();
let resp = service.call(&mut cx, req).await;
let resp = service.call(&mut cx, req).await.map_err(Into::into);
cx.stats.record_process_end_at();

if exit_mark.load(Ordering::Relaxed) {
cx.transport.set_conn_reset(true);
cx.set_conn_reset_by_ttheader(true);
}

let req_msg_type =
Expand All @@ -98,9 +98,7 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
});
let msg = ThriftMessage::mk_server_resp(
&cx,
resp.map_err(|e| {
server_error_to_application_exception(e.into())
}),
resp.map_err(server_error_to_application_exception),
);
if let Err(e) = async {
let result = encoder.encode(&mut cx, msg).await;
Expand All @@ -119,6 +117,9 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
return Err(());
}
}
if cx.transport.is_conn_reset() {
return Err(());
}
}
Ok(Some(ThriftMessage { data: Err(_), .. })) => {
volo_unreachable!();
Expand All @@ -138,6 +139,7 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
e, cx, peer_addr
);
cx.msg_type = Some(TMessageType::Exception);
cx.set_conn_reset_by_ttheader(true);
if !matches!(e, ThriftException::Transport(_)) {
let msg = ThriftMessage::mk_server_resp(
&cx,
Expand Down
Loading