Skip to content

Commit

Permalink
Add version checks
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper committed Oct 1, 2024
1 parent 7825a93 commit 30417f7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
26 changes: 25 additions & 1 deletion src/vm/errors.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::service_protocol::{DecodingError, MessageType, UnsupportedVersionError};
use crate::Error;
use crate::{Error, Version};
use std::borrow::Cow;
use std::fmt;

Expand Down Expand Up @@ -62,6 +62,7 @@ pub mod codes {
pub const JOURNAL_MISMATCH: InvocationErrorCode = InvocationErrorCode(570);
pub const PROTOCOL_VIOLATION: InvocationErrorCode = InvocationErrorCode(571);
pub const AWAITING_TWO_ASYNC_RESULTS: InvocationErrorCode = InvocationErrorCode(572);
pub const UNSUPPORTED_FEATURE: InvocationErrorCode = InvocationErrorCode(573);
}

// Const errors
Expand Down Expand Up @@ -204,6 +205,28 @@ pub struct EmptyGetCallInvocationId;
#[error("Cannot decode get call invocation id: {0}")]
pub struct DecodeGetCallInvocationIdUtf8(#[from] pub(crate) std::string::FromUtf8Error);

#[derive(Debug, thiserror::Error)]
#[error("Feature {feature} is not supported by the negotiated protocol version '{current_version}', the minimum required version is '{minimum_required_version}'")]
pub struct UnsupportedFeatureForNegotiatedVersion {
feature: &'static str,
current_version: Version,
minimum_required_version: Version,
}

impl UnsupportedFeatureForNegotiatedVersion {
pub fn new(
feature: &'static str,
current_version: Version,
minimum_required_version: Version,
) -> Self {
Self {
feature,
current_version,
minimum_required_version,
}
}
}

// Conversions to VMError

trait WithInvocationErrorCode {
Expand Down Expand Up @@ -244,3 +267,4 @@ impl_error_code!(DecodeStateKeysUtf8, PROTOCOL_VIOLATION);
impl_error_code!(EmptyStateKeys, PROTOCOL_VIOLATION);
impl_error_code!(EmptyGetCallInvocationId, PROTOCOL_VIOLATION);
impl_error_code!(DecodeGetCallInvocationIdUtf8, PROTOCOL_VIOLATION);
impl_error_code!(UnsupportedFeatureForNegotiatedVersion, UNSUPPORTED_FEATURE);
29 changes: 28 additions & 1 deletion src/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::service_protocol::messages::{
};
use crate::service_protocol::{Decoder, RawMessage, Version};
use crate::vm::context::{EagerGetState, EagerGetStateKeys};
use crate::vm::errors::UnexpectedStateError;
use crate::vm::errors::{UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion};
use crate::vm::transitions::*;
use crate::{
AsyncResultCombinator, AsyncResultHandle, CancelInvocationTarget, Error, GetInvocationIdTarget,
Expand Down Expand Up @@ -85,6 +85,25 @@ impl CoreVM {
""
}
}

fn verify_feature_support(
&mut self,
feature: &'static str,
minimum_required_protocol: Version,
) -> VMResult<()> {
if self.version < minimum_required_protocol {
return self.do_transition(HitError {
error: UnsupportedFeatureForNegotiatedVersion::new(
feature,
self.version,
minimum_required_protocol,
)
.into(),
next_retry_delay: None,
});
}
Ok(())
}
}

impl fmt::Debug for CoreVM {
Expand Down Expand Up @@ -414,6 +433,9 @@ impl super::VM for CoreVM {
ret
)]
fn sys_call(&mut self, target: Target, input: Bytes) -> VMResult<AsyncResultHandle> {
if target.idempotency_key.is_some() {
self.verify_feature_support("attach idempotency key to one way call", Version::V3)?;
}
self.do_transition(SysCompletableEntry(
"SysCall",
CallEntryMessage {
Expand All @@ -439,6 +461,9 @@ impl super::VM for CoreVM {
input: Bytes,
delay: Option<Duration>,
) -> VMResult<SendHandle> {
if target.idempotency_key.is_some() {
self.verify_feature_support("attach idempotency key to one way call", Version::V3)?;
}
self.do_transition(SysNonCompletableEntry(
"SysOneWayCall",
OneWayCallEntryMessage {
Expand Down Expand Up @@ -592,6 +617,7 @@ impl super::VM for CoreVM {
&mut self,
call: GetInvocationIdTarget,
) -> VMResult<AsyncResultHandle> {
self.verify_feature_support("get call invocation id", Version::V3)?;
self.do_transition(SysCompletableEntry(
"SysGetCallInvocationId",
GetCallInvocationIdEntryMessage {
Expand All @@ -606,6 +632,7 @@ impl super::VM for CoreVM {

#[instrument(level = "debug", ret)]
fn sys_cancel_invocation(&mut self, target: CancelInvocationTarget) -> VMResult<()> {
self.verify_feature_support("cancel invocation", Version::V3)?;
self.do_transition(SysNonCompletableEntry(
"SysCancelInvocation",
CancelInvocationEntryMessage {
Expand Down

0 comments on commit 30417f7

Please sign in to comment.