diff --git a/Cargo.toml b/Cargo.toml index acfcb4c..458aad7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ pin-project-lite = "0.2" rand = { version = "0.8.5", optional = true } regress = "0.10" restate-sdk-macros = { version = "0.3.0", path = "macros" } -restate-sdk-shared-core = "0.1.0" +restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", branch = "issues/cancel-invocation-get-invocation-id" } serde = "1.0" serde_json = "1.0" thiserror = "1.0.63" diff --git a/src/context/mod.rs b/src/context/mod.rs index 0726c0b..72c1f27 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -9,7 +9,7 @@ use std::time::Duration; mod request; mod run; -pub use request::{Request, RequestTarget}; +pub use request::{CallFuture, InvocationHandle, Request, RequestTarget}; pub use run::{RunClosure, RunFuture, RunRetryPolicy}; pub type HeaderMap = http::HeaderMap; @@ -228,6 +228,11 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { Request::new(self.inner_context(), request_target, req) } + /// Create an [`InvocationHandle`] from an invocation id. + fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle + 'ctx { + self.inner_context().invocation_handle(invocation_id) + } + /// Create a service client. The service client is generated by the [`restate_sdk_macros::service`] macro with the same name of the trait suffixed with `Client`. /// /// ```rust,no_run diff --git a/src/context/request.rs b/src/context/request.rs index d0104f5..0bf9196 100644 --- a/src/context/request.rs +++ b/src/context/request.rs @@ -87,7 +87,7 @@ impl<'a, Req, Res> Request<'a, Req, Res> { } /// Call a service. This returns a future encapsulating the response. - pub fn call(self) -> impl Future> + Send + pub fn call(self) -> impl CallFuture> + Send where Req: Serialize + 'static, Res: Deserialize + 'static, @@ -96,7 +96,7 @@ impl<'a, Req, Res> Request<'a, Req, Res> { } /// Send the request to the service, without waiting for the response. - pub fn send(self) + pub fn send(self) -> impl InvocationHandle where Req: Serialize + 'static, { @@ -104,10 +104,17 @@ impl<'a, Req, Res> Request<'a, Req, Res> { } /// Schedule the request to the service, without waiting for the response. - pub fn send_with_delay(self, duration: Duration) + pub fn send_with_delay(self, duration: Duration) -> impl InvocationHandle where Req: Serialize + 'static, { self.ctx.send(self.request_target, self.req, Some(duration)) } } + +pub trait InvocationHandle { + fn invocation_id(&self) -> impl Future> + Send; + fn cancel(&self); +} + +pub trait CallFuture: Future + InvocationHandle {} diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 33de131..2354580 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -1,4 +1,6 @@ -use crate::context::{Request, RequestTarget, RunClosure, RunRetryPolicy}; +use crate::context::{ + CallFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunRetryPolicy, +}; use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture; use crate::endpoint::futures::intercept_error::InterceptErrorFuture; use crate::endpoint::futures::trap::TrapFuture; @@ -10,7 +12,8 @@ use futures::future::Either; use futures::{FutureExt, TryFutureExt}; use pin_project_lite::pin_project; use restate_sdk_shared_core::{ - CoreVM, Failure, NonEmptyValue, RetryPolicy, RunEnterResult, RunExitResult, TakeOutputResult, + AsyncResultHandle, CancelInvocationTarget, CoreVM, Failure, GetInvocationIdTarget, + NonEmptyValue, RetryPolicy, RunEnterResult, RunExitResult, SendHandle, TakeOutputResult, Target, Value, VM, }; use std::borrow::Cow; @@ -215,6 +218,10 @@ impl ContextInternal { variant: "state_keys", syscall: "get_state", }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "get_state", + }), Err(e) => Err(e), }); @@ -230,11 +237,15 @@ impl ContextInternal { .map(|res| match res { Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { variant: "empty", - syscall: "get_state", + syscall: "get_keys", }), Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { variant: "success", - syscall: "get_state", + syscall: "get_keys", + }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "get_keys", }), Ok(Value::Failure(f)) => Ok(Err(f.into())), Ok(Value::StateKeys(s)) => Ok(Ok(s)), @@ -289,6 +300,10 @@ impl ContextInternal { variant: "state_keys", syscall: "sleep", }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "sleep", + }), }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) @@ -302,7 +317,7 @@ impl ContextInternal { &self, request_target: RequestTarget, req: Req, - ) -> impl Future> + Send + Sync { + ) -> impl CallFuture> + Send + Sync { let mut inner_lock = must_lock!(self.inner); let input = match Req::serialize(&req) { @@ -322,31 +337,17 @@ impl ContextInternal { let maybe_handle = inner_lock.vm.sys_call(request_target.into(), input); drop(inner_lock); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "call", - }), - Ok(Value::Success(mut s)) => { - let t = Res::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "call", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "call", - }), - Err(e) => Err(e), - }); + let call_future_impl = CallFutureImpl { + poll_future: VmAsyncResultPollFuture::new( + Cow::Borrowed(&self.inner), + maybe_handle.clone(), + ), + res: PhantomData, + ctx: self.clone(), + call_handle: maybe_handle.ok(), + }; - Either::Left(InterceptErrorFuture::new( - self.clone(), - poll_future.map_err(Error), - )) + Either::Left(InterceptErrorFuture::new(self.clone(), call_future_impl)) } pub fn send( @@ -354,12 +355,17 @@ impl ContextInternal { request_target: RequestTarget, req: Req, delay: Option, - ) { + ) -> impl InvocationHandle { let mut inner_lock = must_lock!(self.inner); match Req::serialize(&req) { Ok(t) => { - let _ = inner_lock.vm.sys_send(request_target.into(), t, delay); + let result = inner_lock.vm.sys_send(request_target.into(), t, delay); + drop(inner_lock); + SendRequestHandle { + ctx: self.clone(), + send_handle: result.ok(), + } } Err(e) => { inner_lock.fail( @@ -369,8 +375,19 @@ impl ContextInternal { } .into(), ); + SendRequestHandle { + ctx: self.clone(), + send_handle: None, + } } - }; + } + } + + pub fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle { + InvocationIdBackedInvocationHandle { + ctx: self.clone(), + invocation_id, + } } pub fn awakeable( @@ -409,6 +426,10 @@ impl ContextInternal { variant: "state_keys", syscall: "awakeable", }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "awakeable", + }), Err(e) => Err(e), }); @@ -468,6 +489,10 @@ impl ContextInternal { variant: "state_keys", syscall: "promise", }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "promise", + }), Err(e) => Err(e), }); @@ -495,6 +520,10 @@ impl ContextInternal { variant: "state_keys", syscall: "peek_promise", }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "peek_promise", + }), Err(e) => Err(e), }); @@ -766,9 +795,218 @@ where syscall: "run", } .into()), + Value::InvocationId(_) => { + Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "run", + } + .into()) + } }); } } } } } + +struct SendRequestHandle { + ctx: ContextInternal, + send_handle: Option, +} + +impl InvocationHandle for SendRequestHandle { + fn invocation_id(&self) -> impl Future> + Send { + if let Some(ref send_handle) = self.send_handle { + let maybe_handle = { + must_lock!(self.ctx.inner) + .vm + .sys_get_call_invocation_id(GetInvocationIdTarget::SendEntry(*send_handle)) + }; + + let poll_future = VmAsyncResultPollFuture::new( + Cow::Borrowed(&self.ctx.inner), + maybe_handle, + ) + .map(|res| match res { + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::InvocationId(s)) => Ok(Ok(s)), + Err(e) => Err(e), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "get_call_invocation_id", + }), + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "void", + syscall: "get_call_invocation_id", + }), + Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "success", + syscall: "get_call_invocation_id", + }), + }); + + Either::Left(InterceptErrorFuture::new( + self.ctx.clone(), + poll_future.map_err(Error), + )) + } else { + // If the send didn't succeed, trap the execution + Either::Right(TrapFuture::default()) + } + } + + fn cancel(&self) { + if let Some(ref send_handle) = self.send_handle { + let mut inner_lock = must_lock!(self.ctx.inner); + let _ = inner_lock + .vm + .sys_cancel_invocation(CancelInvocationTarget::SendEntry(*send_handle)); + } + // If the send didn't succeed, then simply ignore the cancel + } +} + +pin_project! { + struct CallFutureImpl { + #[pin] + poll_future: VmAsyncResultPollFuture, + res: PhantomData R>, + ctx: ContextInternal, + call_handle: Option, + } +} + +impl Future for CallFutureImpl { + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + this.poll_future + .poll(cx) + .map(|res| match res { + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", + syscall: "call", + }), + Ok(Value::Success(mut s)) => { + let t = Res::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "call", + err: Box::new(e), + })?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "call", + }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "call", + }), + Err(e) => Err(e), + }) + .map(|res| res.map_err(Error)) + } +} + +impl InvocationHandle for CallFutureImpl { + fn invocation_id(&self) -> impl Future> + Send { + if let Some(ref call_handle) = self.call_handle { + let maybe_handle = { + must_lock!(self.ctx.inner) + .vm + .sys_get_call_invocation_id(GetInvocationIdTarget::CallEntry(*call_handle)) + }; + + let poll_future = VmAsyncResultPollFuture::new( + Cow::Borrowed(&self.ctx.inner), + maybe_handle, + ) + .map(|res| match res { + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::InvocationId(s)) => Ok(Ok(s)), + Err(e) => Err(e), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "get_call_invocation_id", + }), + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "void", + syscall: "get_call_invocation_id", + }), + Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "success", + syscall: "get_call_invocation_id", + }), + }); + + Either::Left(InterceptErrorFuture::new( + self.ctx.clone(), + poll_future.map_err(Error), + )) + } else { + // If the send didn't succeed, trap the execution + Either::Right(TrapFuture::default()) + } + } + + fn cancel(&self) { + if let Some(ref call_handle) = self.call_handle { + let mut inner_lock = must_lock!(self.ctx.inner); + let _ = inner_lock + .vm + .sys_cancel_invocation(CancelInvocationTarget::CallEntry(*call_handle)); + } + // If the send didn't succeed, then simply ignore the cancel + } +} + +impl CallFuture, Error>> + for CallFutureImpl +{ +} + +impl InvocationHandle for Either { + fn invocation_id(&self) -> impl Future> + Send { + match self { + Either::Left(l) => Either::Left(l.invocation_id()), + Either::Right(r) => Either::Right(r.invocation_id()), + } + } + + fn cancel(&self) { + match self { + Either::Left(l) => l.cancel(), + Either::Right(r) => r.cancel(), + } + } +} + +impl CallFuture for Either +where + A: CallFuture, + B: CallFuture, +{ +} + +struct InvocationIdBackedInvocationHandle { + ctx: ContextInternal, + invocation_id: String, +} + +impl InvocationHandle for InvocationIdBackedInvocationHandle { + fn invocation_id(&self) -> impl Future> + Send { + ready(Ok(self.invocation_id.clone())) + } + + fn cancel(&self) { + let mut inner_lock = must_lock!(self.ctx.inner); + let _ = inner_lock + .vm + .sys_cancel_invocation(CancelInvocationTarget::InvocationId( + self.invocation_id.clone(), + )); + } +} diff --git a/src/endpoint/futures/intercept_error.rs b/src/endpoint/futures/intercept_error.rs index b486fbf..606df95 100644 --- a/src/endpoint/futures/intercept_error.rs +++ b/src/endpoint/futures/intercept_error.rs @@ -1,5 +1,6 @@ -use crate::context::{RunFuture, RunRetryPolicy}; +use crate::context::{CallFuture, InvocationHandle, RunFuture, RunRetryPolicy}; use crate::endpoint::{ContextInternal, Error}; +use crate::errors::TerminalError; use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; @@ -58,3 +59,15 @@ where self } } + +impl InvocationHandle for InterceptErrorFuture { + fn invocation_id(&self) -> impl Future> + Send { + self.fut.invocation_id() + } + + fn cancel(&self) { + self.fut.cancel() + } +} + +impl CallFuture for InterceptErrorFuture where F: CallFuture> {} diff --git a/src/endpoint/futures/trap.rs b/src/endpoint/futures/trap.rs index 9b0269d..b614de1 100644 --- a/src/endpoint/futures/trap.rs +++ b/src/endpoint/futures/trap.rs @@ -1,3 +1,5 @@ +use crate::context::{CallFuture, InvocationHandle}; +use crate::errors::TerminalError; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; @@ -20,3 +22,13 @@ impl Future for TrapFuture { Poll::Pending } } + +impl InvocationHandle for TrapFuture { + fn invocation_id(&self) -> impl Future> + Send { + TrapFuture::default() + } + + fn cancel(&self) {} +} + +impl CallFuture for TrapFuture {} diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index c539328..464e2a2 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -176,8 +176,8 @@ impl Default for Builder { Self { svcs: Default::default(), discovery: crate::discovery::Endpoint { - max_protocol_version: 2, - min_protocol_version: 2, + max_protocol_version: 3, + min_protocol_version: 3, protocol_mode: Some(crate::discovery::ProtocolMode::BidiStream), services: vec![], }, diff --git a/src/lib.rs b/src/lib.rs index 91b6bd0..6a32159 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -194,9 +194,10 @@ pub mod prelude { pub use crate::http_server::HttpServer; pub use crate::context::{ - Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, - ContextSideEffects, ContextTimers, ContextWriteState, HeaderMap, ObjectContext, Request, - RunFuture, RunRetryPolicy, SharedObjectContext, SharedWorkflowContext, WorkflowContext, + CallFuture, Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, + ContextSideEffects, ContextTimers, ContextWriteState, HeaderMap, InvocationHandle, + ObjectContext, Request, RunFuture, RunRetryPolicy, SharedObjectContext, + SharedWorkflowContext, WorkflowContext, }; pub use crate::endpoint::Endpoint; pub use crate::errors::{HandlerError, HandlerResult, TerminalError}; diff --git a/test-services/Dockerfile b/test-services/Dockerfile index c041674..50e4ff7 100644 --- a/test-services/Dockerfile +++ b/test-services/Dockerfile @@ -7,5 +7,6 @@ RUN cargo build -p test-services RUN cp ./target/debug/test-services /bin/server ENV RUST_LOG="debug,restate_shared_core=trace" +ENV RUST_BACKTRACE=1 CMD ["/bin/server"] \ No newline at end of file diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs index 36954f6..d202b2c 100644 --- a/test-services/src/proxy.rs +++ b/test-services/src/proxy.rs @@ -46,7 +46,7 @@ pub(crate) trait Proxy { #[name = "call"] async fn call(req: Json) -> HandlerResult>>; #[name = "oneWayCall"] - async fn one_way_call(req: Json) -> HandlerResult<()>; + async fn one_way_call(req: Json) -> HandlerResult; #[name = "manyCalls"] async fn many_calls(req: Json>) -> HandlerResult<()>; } @@ -70,16 +70,16 @@ impl Proxy for ProxyImpl { &self, ctx: Context<'_>, Json(req): Json, - ) -> HandlerResult<()> { + ) -> HandlerResult { let request = ctx.request::<_, ()>(req.to_target(), req.message); - if let Some(delay_millis) = req.delay_millis { - request.send_with_delay(Duration::from_millis(delay_millis)); + let invocation_id = if let Some(delay_millis) = req.delay_millis { + request.send_with_delay(Duration::from_millis(delay_millis)).invocation_id().await? } else { - request.send(); - } + request.send().invocation_id().await? + }; - Ok(()) + Ok(invocation_id) } async fn many_calls( diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index a1d84a1..672b0f0 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -5,6 +5,7 @@ use futures::FutureExt; use restate_sdk::prelude::*; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::convert::Infallible; use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use std::time::Duration; @@ -62,6 +63,8 @@ pub(crate) trait TestUtilsService { async fn count_executed_side_effects(increments: u32) -> HandlerResult; #[name = "getEnvVariable"] async fn get_env_variable(env: String) -> HandlerResult; + #[name = "cancelInvocation"] + async fn cancel_invocation(invocation_id: String) -> Result<(), Infallible>; #[name = "interpretCommands"] async fn interpret_commands(req: Json) -> HandlerResult<()>; } @@ -155,6 +158,12 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(std::env::var(env).ok().unwrap_or_default()) } + async fn cancel_invocation(&self, ctx: Context<'_>, invocation_id: String) -> Result<(), Infallible> { + ctx.invocation_handle(invocation_id) + .cancel(); + Ok(()) + } + async fn interpret_commands( &self, context: Context<'_>,