From 30417f7b020a55bed56ce32217f403be0a2e916b Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 1 Oct 2024 18:16:39 +0200 Subject: [PATCH] Add version checks --- src/vm/errors.rs | 26 +++++++++++++++++++++++++- src/vm/mod.rs | 29 ++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/vm/errors.rs b/src/vm/errors.rs index d08bca4..6cc45d1 100644 --- a/src/vm/errors.rs +++ b/src/vm/errors.rs @@ -1,5 +1,5 @@ use crate::service_protocol::{DecodingError, MessageType, UnsupportedVersionError}; -use crate::Error; +use crate::{Error, Version}; use std::borrow::Cow; use std::fmt; @@ -62,6 +62,7 @@ pub mod codes { pub const JOURNAL_MISMATCH: InvocationErrorCode = InvocationErrorCode(570); pub const PROTOCOL_VIOLATION: InvocationErrorCode = InvocationErrorCode(571); pub const AWAITING_TWO_ASYNC_RESULTS: InvocationErrorCode = InvocationErrorCode(572); + pub const UNSUPPORTED_FEATURE: InvocationErrorCode = InvocationErrorCode(573); } // Const errors @@ -204,6 +205,28 @@ pub struct EmptyGetCallInvocationId; #[error("Cannot decode get call invocation id: {0}")] pub struct DecodeGetCallInvocationIdUtf8(#[from] pub(crate) std::string::FromUtf8Error); +#[derive(Debug, thiserror::Error)] +#[error("Feature {feature} is not supported by the negotiated protocol version '{current_version}', the minimum required version is '{minimum_required_version}'")] +pub struct UnsupportedFeatureForNegotiatedVersion { + feature: &'static str, + current_version: Version, + minimum_required_version: Version, +} + +impl UnsupportedFeatureForNegotiatedVersion { + pub fn new( + feature: &'static str, + current_version: Version, + minimum_required_version: Version, + ) -> Self { + Self { + feature, + current_version, + minimum_required_version, + } + } +} + // Conversions to VMError trait WithInvocationErrorCode { @@ -244,3 +267,4 @@ impl_error_code!(DecodeStateKeysUtf8, PROTOCOL_VIOLATION); impl_error_code!(EmptyStateKeys, PROTOCOL_VIOLATION); impl_error_code!(EmptyGetCallInvocationId, PROTOCOL_VIOLATION); impl_error_code!(DecodeGetCallInvocationIdUtf8, PROTOCOL_VIOLATION); +impl_error_code!(UnsupportedFeatureForNegotiatedVersion, UNSUPPORTED_FEATURE); diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 45a836c..51df4e0 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -11,7 +11,7 @@ use crate::service_protocol::messages::{ }; use crate::service_protocol::{Decoder, RawMessage, Version}; use crate::vm::context::{EagerGetState, EagerGetStateKeys}; -use crate::vm::errors::UnexpectedStateError; +use crate::vm::errors::{UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion}; use crate::vm::transitions::*; use crate::{ AsyncResultCombinator, AsyncResultHandle, CancelInvocationTarget, Error, GetInvocationIdTarget, @@ -85,6 +85,25 @@ impl CoreVM { "" } } + + fn verify_feature_support( + &mut self, + feature: &'static str, + minimum_required_protocol: Version, + ) -> VMResult<()> { + if self.version < minimum_required_protocol { + return self.do_transition(HitError { + error: UnsupportedFeatureForNegotiatedVersion::new( + feature, + self.version, + minimum_required_protocol, + ) + .into(), + next_retry_delay: None, + }); + } + Ok(()) + } } impl fmt::Debug for CoreVM { @@ -414,6 +433,9 @@ impl super::VM for CoreVM { ret )] fn sys_call(&mut self, target: Target, input: Bytes) -> VMResult { + if target.idempotency_key.is_some() { + self.verify_feature_support("attach idempotency key to one way call", Version::V3)?; + } self.do_transition(SysCompletableEntry( "SysCall", CallEntryMessage { @@ -439,6 +461,9 @@ impl super::VM for CoreVM { input: Bytes, delay: Option, ) -> VMResult { + if target.idempotency_key.is_some() { + self.verify_feature_support("attach idempotency key to one way call", Version::V3)?; + } self.do_transition(SysNonCompletableEntry( "SysOneWayCall", OneWayCallEntryMessage { @@ -592,6 +617,7 @@ impl super::VM for CoreVM { &mut self, call: GetInvocationIdTarget, ) -> VMResult { + self.verify_feature_support("get call invocation id", Version::V3)?; self.do_transition(SysCompletableEntry( "SysGetCallInvocationId", GetCallInvocationIdEntryMessage { @@ -606,6 +632,7 @@ impl super::VM for CoreVM { #[instrument(level = "debug", ret)] fn sys_cancel_invocation(&mut self, target: CancelInvocationTarget) -> VMResult<()> { + self.verify_feature_support("cancel invocation", Version::V3)?; self.do_transition(SysNonCompletableEntry( "SysCancelInvocation", CancelInvocationEntryMessage {