Skip to content

Commit

Permalink
Added replay aware tracing filter
Browse files Browse the repository at this point in the history
  • Loading branch information
h7kanna committed Oct 30, 2024
1 parent 4c828ed commit 54c6172
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 3 deletions.
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,16 @@ license = "MIT"
repository = "https://github.com/restatedev/sdk-rust"
rust-version = "1.76.0"

[[example]]
name = "tracing"
path = "examples/tracing.rs"
required-features = ["tracing-subscriber"]

[features]
default = ["http_server", "rand", "uuid"]
hyper = ["dep:hyper", "http-body-util", "restate-sdk-shared-core/http"]
http_server = ["hyper", "hyper/server", "hyper/http2", "hyper-util", "tokio/net", "tokio/signal", "tokio/macros"]
tracing-subscriber = ["dep:tracing-subscriber"]

[dependencies]
bytes = "1.6.1"
Expand All @@ -30,11 +36,12 @@ thiserror = "1.0.63"
tokio = { version = "1", default-features = false, features = ["sync"] }
tower-service = "0.3"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["registry"], optional = true }
uuid = { version = "1.10.0", optional = true }

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
tracing-subscriber = "0.3"
tracing-subscriber = { version = "0.3", features = ["env-filter", "registry"] }
trybuild = "1.0"
reqwest = { version = "0.12", features = ["json"] }
rand = "0.8.5"
Expand Down
65 changes: 65 additions & 0 deletions examples/tracing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use restate_sdk::prelude::*;
use std::convert::Infallible;
use std::time::Duration;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer};

// Service definition for the greeter used by this example. The trait is handed to the
// `#[restate_sdk::service]` attribute macro; `GreeterImpl` below provides the handler body.
#[restate_sdk::service]
trait Greeter {
// Takes a name and returns a greeting string; the handler is declared infallible.
async fn greet(name: String) -> Result<String, Infallible>;
}

/// Stateless implementation of the `Greeter` service used to demonstrate
/// replay-aware logging.
struct GreeterImpl;

impl Greeter for GreeterImpl {
    /// Calls the `Delayer` service twice (a short and a long delay) and logs around
    /// each call so the effect of the replay filter can be observed.
    async fn greet(&self, ctx: Context<'_>, name: String) -> Result<String, Infallible> {
        // More than suspension timeout to trigger replay
        let timeout = 60;

        info!("This will be logged on replay");
        let _ = ctx.service_client::<DelayerClient>().delay(1).call().await;

        info!("This will not be logged on replay");
        let _ = ctx
            .service_client::<DelayerClient>()
            .delay(timeout)
            .call()
            .await;

        info!("This will be logged on processing after suspension");
        let greeting = format!("Greetings {name} after {timeout} seconds");
        Ok(greeting)
    }
}

// Service definition for the helper service that `Greeter` calls. The trait is handed to
// the `#[restate_sdk::service]` attribute macro; `DelayerImpl` below provides the handler.
#[restate_sdk::service]
trait Delayer {
// Takes a delay in seconds and returns a confirmation string.
async fn delay(seconds: u64) -> Result<String, Infallible>;
}

/// Stateless implementation of the `Delayer` service.
struct DelayerImpl;

impl Delayer for DelayerImpl {
    /// Sleeps through the context for `seconds` seconds, logs, and reports the delay.
    async fn delay(&self, ctx: Context<'_>, seconds: u64) -> Result<String, Infallible> {
        let pause = Duration::from_secs(seconds);
        // The outcome of the sleep is intentionally ignored in this example.
        let _ = ctx.sleep(pause).await;
        info!("Delayed for {seconds} seconds");
        Ok(format!("Delayed {seconds}"))
    }
}

/// Installs a formatting layer guarded by an env filter and the replay-aware filter,
/// then serves both example services over HTTP on port 9080.
#[tokio::main]
async fn main() {
    // Fall back to `restate_sdk=info` when the default env filter cannot be built
    // from the environment.
    let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| "restate_sdk=info".into());

    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_filter(env_filter)
        .with_filter(restate_sdk::filter::ReplayAwareFilter);
    tracing_subscriber::registry().with(fmt_layer).init();

    let endpoint = Endpoint::builder()
        .bind(GreeterImpl.serve())
        .bind(DelayerImpl.serve())
        .build();

    HttpServer::new(endpoint)
        .listen_and_serve("0.0.0.0:9080".parse().unwrap())
        .await;
}
21 changes: 21 additions & 0 deletions src/endpoint/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ pub struct ContextInternalInner {
pub(crate) read: InputReceiver,
pub(crate) write: OutputSender,
pub(super) handler_state: HandlerStateNotifier,
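// Tracks which value of the span's `replaying` attribute still has to be recorded.
// `true` (the initial value): `replaying = true` has not been recorded yet; it is recorded,
// and the flag cleared, by `set_tracing_replaying_flag` while the VM is still replaying.
// `false`: `replaying = false` has not been recorded yet; it is recorded, and the flag set
// again, by the first `set_tracing_replaying_flag` call after the VM starts processing.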
pub(super) tracing_replaying_flag: bool,
}

impl ContextInternalInner {
Expand All @@ -42,6 +46,7 @@ impl ContextInternalInner {
read,
write,
handler_state,
tracing_replaying_flag: true,
}
}

Expand All @@ -50,6 +55,22 @@ impl ContextInternalInner {
.notify_error(e.0.to_string().into(), format!("{:#}", e.0).into(), None);
self.handler_state.mark_error(e);
}

/// Synchronises the `replaying` field of the current tracing span with the VM state.
///
/// `tracing_replaying_flag` is `true` while `replaying = true` still has to be recorded
/// and `false` while `replaying = false` still has to be recorded, so a record is only
/// needed when the flag equals the value about to be written.
///
/// NOTE(review): the value is recorded on `tracing::Span::current()`; this presumably
/// expects the current span to be the one declaring the `replaying` field — confirm.
pub(super) fn set_tracing_replaying_flag(&mut self) {
    let replaying = !self.vm.is_processing();
    if replaying == self.tracing_replaying_flag {
        tracing::Span::current().record("replaying", replaying);
        // Flip the flag so the opposite transition is the next one to be recorded.
        self.tracing_replaying_flag = !replaying;
    }
}
}

/// Internal context interface.
Expand Down
10 changes: 8 additions & 2 deletions src/endpoint/futures/async_result_poll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ impl Future for VmAsyncResultPollFuture {

// At this point let's try to take the async result
match inner_lock.vm.take_async_result(handle) {
Ok(Some(v)) => return Poll::Ready(Ok(v)),
Ok(Some(v)) => {
inner_lock.set_tracing_replaying_flag();
return Poll::Ready(Ok(v));
}
Ok(None) => {
drop(inner_lock);
self.state = Some(PollState::WaitingInput { ctx, handle });
Expand Down Expand Up @@ -121,7 +124,10 @@ impl Future for VmAsyncResultPollFuture {

// Now try to take async result again
match inner_lock.vm.take_async_result(handle) {
Ok(Some(v)) => return Poll::Ready(Ok(v)),
Ok(Some(v)) => {
inner_lock.set_tracing_replaying_flag();
return Poll::Ready(Ok(v));
}
Ok(None) => {
drop(inner_lock);
self.state = Some(PollState::WaitingInput { ctx, handle });
Expand Down
23 changes: 23 additions & 0 deletions src/endpoint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use std::future::poll_fn;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
#[cfg(feature = "tracing-subscriber")]
use tracing::{field, info_span, Instrument};

const DISCOVERY_CONTENT_TYPE: &str = "application/vnd.restate.endpointmanifest.v1+json";

Expand Down Expand Up @@ -344,6 +346,27 @@ impl BidiStreamRunner {
.get(&self.svc_name)
.expect("service must exist at this point");

#[cfg(feature = "tracing-subscriber")]
{
let span = info_span!(
"handle",
"rpc.system" = "restate",
"rpc.service" = self.svc_name,
"rpc.method" = self.handler_name,
"replaying" = field::Empty,
);
handle(
input_rx,
output_tx,
self.vm,
self.svc_name,
self.handler_name,
svc,
)
.instrument(span)
.await
}
#[cfg(not(feature = "tracing-subscriber"))]
handle(
input_rx,
output_tx,
Expand Down
80 changes: 80 additions & 0 deletions src/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
//! Replay aware tracing filter
//!
//! Use this filter to skip tracing events in the service/workflow while replaying.
//!
//! Example:
//! ```rust,no_run
//! use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer};
//! let replay_filter = restate_sdk::filter::ReplayAwareFilter;
//! tracing_subscriber::registry()
//! .with(tracing_subscriber::fmt::layer().with_filter(replay_filter))
//! .init();
//! ```
use std::fmt::Debug;
use tracing::{
field::{Field, Visit},
span::{Attributes, Record},
Event, Id, Metadata, Subscriber,
};
use tracing_subscriber::{
layer::{Context, Filter},
registry::LookupSpan,
Layer,
};

/// Span extension holding the value of the span's `replaying` field as last captured
/// by [`ReplayAwareFilter`].
#[derive(Debug)]
struct ReplayField(bool);

/// Field visitor that picks out the boolean `replaying` field and ignores everything else.
///
/// The wrapped value is left untouched unless a boolean field named `replaying` is visited.
struct ReplayFieldVisitor(bool);

impl Visit for ReplayFieldVisitor {
    /// Required by `Visit`; non-boolean fields carry no replay information.
    fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {}

    /// Stores `value` when the visited field is named `replaying`.
    fn record_bool(&mut self, field: &Field, value: bool) {
        if field.name() == "replaying" {
            self.0 = value;
        }
    }
}

/// Per-layer tracing filter that drops events emitted while a handler is replaying.
///
/// The filter tracks the boolean `replaying` field recorded on spans and disables
/// events whose span scope reports `replaying = true`. Attach it to a layer with
/// `Layer::with_filter`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ReplayAwareFilter;

impl<S: Subscriber + for<'lookup> LookupSpan<'lookup>> Filter<S> for ReplayAwareFilter {
    /// Never filters by metadata; the decision is made per event in `event_enabled`.
    fn enabled(&self, _meta: &Metadata<'_>, _cx: &Context<'_, S>) -> bool {
        true
    }

    /// Disables the event when any span in its scope has `replaying = true` stored.
    ///
    /// Events outside of any span, or inside spans that never recorded `replaying`,
    /// stay enabled.
    fn event_enabled(&self, event: &Event<'_>, cx: &Context<'_, S>) -> bool {
        let Some(scope) = cx.event_scope(event) else {
            return true;
        };
        // Inspect the whole scope rather than only the root: the span carrying the
        // `replaying` field may itself be nested under an application-level span.
        let replaying = scope.from_root().any(|span| {
            let extensions = span.extensions();
            extensions
                .get::<ReplayField>()
                .is_some_and(|field| field.0)
        });
        !replaying
    }

    /// Captures the initial `replaying` value (defaulting to `false`) of a new span.
    fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
        if let Some(span) = ctx.span(id) {
            let mut visitor = ReplayFieldVisitor(false);
            attrs.record(&mut visitor);
            // Extensions are shared by every layer of the registry and `insert` panics
            // when the type is already present (e.g. the filter is attached to two
            // layers), so overwrite instead.
            span.extensions_mut().replace(ReplayField(visitor.0));
        }
    }

    /// Updates the stored `replaying` value when a span records new values.
    fn on_record(&self, id: &Id, values: &Record<'_>, ctx: Context<'_, S>) {
        if let Some(span) = ctx.span(id) {
            let mut extensions = span.extensions_mut();
            // Seed the visitor with the stored value so that a record which does not
            // contain `replaying` leaves the flag unchanged instead of resetting it.
            let previous = extensions
                .get_mut::<ReplayField>()
                .map_or(false, |field| field.0);
            let mut visitor = ReplayFieldVisitor(previous);
            values.record(&mut visitor);
            extensions.replace(ReplayField(visitor.0));
        }
    }
}

/// Lets `ReplayAwareFilter` also be used as a `Layer`; no methods are overridden here,
/// so only the trait's default implementations apply.
impl<S> Layer<S> for ReplayAwareFilter where S: Subscriber {}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ pub mod errors;
pub mod http_server;
#[cfg(feature = "hyper")]
pub mod hyper;

Check warning on line 48 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Build and test (ubuntu-22.04)

Diff in /home/runner/work/sdk-rust/sdk-rust/src/lib.rs
#[cfg(feature = "tracing-subscriber")]
pub mod filter;
pub mod serde;

/// Entry-point macro to define a Restate [Service](https://docs.restate.dev/concepts/services#services-1).
Expand Down

0 comments on commit 54c6172

Please sign in to comment.