From 54c6172008023ec8cd936f9eff2b972bd809b3f9 Mon Sep 17 00:00:00 2001 From: Harsha Teja Kanna Date: Wed, 30 Oct 2024 00:29:36 -0400 Subject: [PATCH] Added replay aware tracing filter --- Cargo.toml | 9 ++- examples/tracing.rs | 65 ++++++++++++++++++ src/endpoint/context.rs | 21 ++++++ src/endpoint/futures/async_result_poll.rs | 10 ++- src/endpoint/mod.rs | 23 +++++++ src/filter.rs | 80 +++++++++++++++++++++++ src/lib.rs | 2 + 7 files changed, 207 insertions(+), 3 deletions(-) create mode 100644 examples/tracing.rs create mode 100644 src/filter.rs diff --git a/Cargo.toml b/Cargo.toml index acfcb4c..00d99d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,10 +7,16 @@ license = "MIT" repository = "https://github.com/restatedev/sdk-rust" rust-version = "1.76.0" +[[example]] +name = "tracing" +path = "examples/tracing.rs" +required-features = ["tracing-subscriber"] + [features] default = ["http_server", "rand", "uuid"] hyper = ["dep:hyper", "http-body-util", "restate-sdk-shared-core/http"] http_server = ["hyper", "hyper/server", "hyper/http2", "hyper-util", "tokio/net", "tokio/signal", "tokio/macros"] +tracing-subscriber = ["dep:tracing-subscriber"] [dependencies] bytes = "1.6.1" @@ -30,11 +36,12 @@ thiserror = "1.0.63" tokio = { version = "1", default-features = false, features = ["sync"] } tower-service = "0.3" tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["registry"], optional = true } uuid = { version = "1.10.0", optional = true } [dev-dependencies] tokio = { version = "1", features = ["full"] } -tracing-subscriber = "0.3" +tracing-subscriber = { version = "0.3", features = ["env-filter", "registry"] } trybuild = "1.0" reqwest = { version = "0.12", features = ["json"] } rand = "0.8.5" diff --git a/examples/tracing.rs b/examples/tracing.rs new file mode 100644 index 0000000..68051b9 --- /dev/null +++ b/examples/tracing.rs @@ -0,0 +1,65 @@ +use restate_sdk::prelude::*; +use std::convert::Infallible; +use std::time::Duration; +use tracing::info; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer}; + +#[restate_sdk::service] +trait Greeter { + async fn greet(name: String) -> Result; +} + +struct GreeterImpl; + +impl Greeter for GreeterImpl { + async fn greet(&self, ctx: Context<'_>, name: String) -> Result { + let timeout = 60; // More than suspension timeout to trigger replay + info!("This will be logged on replay"); + _ = ctx.service_client::().delay(1).call().await; + info!("This will not be logged on replay"); + _ = ctx + .service_client::() + .delay(timeout) + .call() + .await; + info!("This will be logged on processing after suspension"); + Ok(format!("Greetings {name} after {timeout} seconds")) + } +} + +#[restate_sdk::service] +trait Delayer { + async fn delay(seconds: u64) -> Result; +} + +struct DelayerImpl; + +impl Delayer for DelayerImpl { + async fn delay(&self, ctx: Context<'_>, seconds: u64) -> Result { + _ = ctx.sleep(Duration::from_secs(seconds)).await; + info!("Delayed for {seconds} seconds"); + Ok(format!("Delayed {seconds}")) + } +} + +#[tokio::main] +async fn main() { + let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "restate_sdk=info".into()); + let replay_filter = restate_sdk::filter::ReplayAwareFilter; + tracing_subscriber::registry() + .with( + tracing_subscriber::fmt::layer() + .with_filter(env_filter) + .with_filter(replay_filter), + ) + .init(); + HttpServer::new( + Endpoint::builder() + .bind(GreeterImpl.serve()) + .bind(DelayerImpl.serve()) + .build(), + ) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 33de131..8d22ff9 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -28,6 +28,10 @@ pub struct ContextInternalInner { pub(crate) read: InputReceiver, pub(crate) write: OutputSender, pub(super) handler_state: HandlerStateNotifier, + // Flag to indicate whether span replay attribute should be set + // When replaying this is set on the sys call + // When not replaying this is reset on the sys call that transitioned the state + pub(super) tracing_replaying_flag: bool, } impl ContextInternalInner { @@ -42,6 +46,7 @@ impl ContextInternalInner { read, write, handler_state, + tracing_replaying_flag: true, } } @@ -50,6 +55,22 @@ impl ContextInternalInner { .notify_error(e.0.to_string().into(), format!("{:#}", e.0).into(), None); self.handler_state.mark_error(e); } + + pub(super) fn set_tracing_replaying_flag(&mut self) { + if !self.vm.is_processing() { + // Replay record is not yet set in the span + if self.tracing_replaying_flag { + tracing::Span::current().record("replaying", true); + self.tracing_replaying_flag = false; + } + } else { + // Replay record is not yet reset in the span + if !self.tracing_replaying_flag { + tracing::Span::current().record("replaying", false); + self.tracing_replaying_flag = true; + } + } + } } /// Internal context interface. diff --git a/src/endpoint/futures/async_result_poll.rs b/src/endpoint/futures/async_result_poll.rs index 8f6ef5d..6cdd848 100644 --- a/src/endpoint/futures/async_result_poll.rs +++ b/src/endpoint/futures/async_result_poll.rs @@ -83,7 +83,10 @@ impl Future for VmAsyncResultPollFuture { // At this point let's try to take the async result match inner_lock.vm.take_async_result(handle) { - Ok(Some(v)) => return Poll::Ready(Ok(v)), + Ok(Some(v)) => { + inner_lock.set_tracing_replaying_flag(); + return Poll::Ready(Ok(v)); + } Ok(None) => { drop(inner_lock); self.state = Some(PollState::WaitingInput { ctx, handle }); @@ -121,7 +124,10 @@ impl Future for VmAsyncResultPollFuture { // Now try to take async result again match inner_lock.vm.take_async_result(handle) { - Ok(Some(v)) => return Poll::Ready(Ok(v)), + Ok(Some(v)) => { + inner_lock.set_tracing_replaying_flag(); + return Poll::Ready(Ok(v)); + } Ok(None) => { drop(inner_lock); self.state = Some(PollState::WaitingInput { ctx, handle }); diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 90f907a..fe293d8 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -18,6 +18,8 @@ use std::future::poll_fn; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +#[cfg(feature = "tracing-subscriber")] +use tracing::{field, info_span, Instrument}; const DISCOVERY_CONTENT_TYPE: &str = "application/vnd.restate.endpointmanifest.v1+json"; @@ -344,6 +346,27 @@ impl BidiStreamRunner { .get(&self.svc_name) .expect("service must exist at this point"); + #[cfg(feature = "tracing-subscriber")] + { + let span = info_span!( + "handle", + "rpc.system" = "restate", + "rpc.service" = self.svc_name, + "rpc.method" = self.handler_name, + "replaying" = field::Empty, + ); + handle( + input_rx, + output_tx, + self.vm, + self.svc_name, + self.handler_name, + svc, + ) + .instrument(span) + .await + } + #[cfg(not(feature = "tracing-subscriber"))] handle( input_rx, output_tx, diff --git a/src/filter.rs b/src/filter.rs new file mode 100644 index 0000000..c43aa2c --- /dev/null +++ b/src/filter.rs @@ -0,0 +1,80 @@ +//! Replay aware tracing filter +//! +//! Use this filter to skip tracing events in the service/workflow while replaying. +//! +//! Example: +//! ```rust,no_run +//! use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer}; +//! let replay_filter = restate_sdk::filter::ReplayAwareFilter; +//! tracing_subscriber::registry() +//! .with(tracing_subscriber::fmt::layer().with_filter(replay_filter)) +//! .init(); +//! ``` +use std::fmt::Debug; +use tracing::{ + field::{Field, Visit}, + span::{Attributes, Record}, + Event, Id, Metadata, Subscriber, +}; +use tracing_subscriber::{ + layer::{Context, Filter}, + registry::LookupSpan, + Layer, +}; + +#[derive(Debug)] +struct ReplayField(bool); + +struct ReplayFieldVisitor(bool); + +impl Visit for ReplayFieldVisitor { + fn record_bool(&mut self, field: &Field, value: bool) { + if field.name().eq("replaying") { + self.0 = value; + } + } + + fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {} +} + +pub struct ReplayAwareFilter; + +impl LookupSpan<'lookup>> Filter for ReplayAwareFilter { + fn enabled(&self, _meta: &Metadata<'_>, _cx: &Context<'_, S>) -> bool { + true + } + + fn event_enabled(&self, event: &Event<'_>, cx: &Context<'_, S>) -> bool { + if let Some(scope) = cx.event_scope(event) { + if let Some(span) = scope.from_root().next() { + let extensions = span.extensions(); + if let Some(replay) = extensions.get::() { + return !replay.0; + } + } + true + } else { + true + } + } + + fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) { + if let Some(span) = ctx.span(id) { + let mut visitor = ReplayFieldVisitor(false); + attrs.record(&mut visitor); + let mut extensions = span.extensions_mut(); + extensions.insert::(ReplayField(visitor.0)); + } + } + + fn on_record(&self, id: &Id, values: &Record<'_>, ctx: Context<'_, S>) { + if let Some(span) = ctx.span(id) { + let mut visitor = ReplayFieldVisitor(false); + values.record(&mut visitor); + let mut extensions = span.extensions_mut(); + extensions.replace::(ReplayField(visitor.0)); + } + } +} + +impl Layer for ReplayAwareFilter {} diff --git a/src/lib.rs b/src/lib.rs index 91b6bd0..7dcc14c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,6 +46,8 @@ pub mod errors; pub mod http_server; #[cfg(feature = "hyper")] pub mod hyper; +#[cfg(feature = "tracing-subscriber")] +pub mod filter; pub mod serde; /// Entry-point macro to define a Restate [Service](https://docs.restate.dev/concepts/services#services-1).