diff --git a/src/local_state_store.ts b/src/local_state_store.ts index f57620ec..880f5271 100644 --- a/src/local_state_store.ts +++ b/src/local_state_store.ts @@ -13,6 +13,7 @@ import { ClearStateEntryMessage, GetStateEntryMessage, GetStateKeysEntryMessage, + GetStateKeysEntryMessage_StateKeys, SetStateEntryMessage, StartMessage_StateEntry, } from "./generated/proto/protocol"; @@ -28,44 +29,39 @@ export class LocalStateStore { ); } - public get(key: string): GetStateEntryMessage { - const present = this.state.has(key.toString()); - if (!present && this.isPartial) { - // Partial eager state, so retrieve state from the runtime - return GetStateEntryMessage.create({ key: Buffer.from(key) }); - } else if (!present) { - // Complete eager state, so state entry is null - return GetStateEntryMessage.create({ - key: Buffer.from(key), - empty: Empty.create({}), - }); + // Returns true if completed + public tryCompleteGet(key: string, msg: GetStateEntryMessage): boolean { + const stateEntry = this.state.get(key); + if (stateEntry === undefined) { + if (this.isPartial) { + // Partial eager state, so retrieve state from the runtime + return false; + } else { + // Complete eager state, so state entry is null + msg.empty = Empty.create({}); + return true; + } } - const stateEntry = this.state.get(key.toString()); if (stateEntry instanceof Buffer) { - return GetStateEntryMessage.create({ - key: Buffer.from(key), - value: stateEntry, - }); + msg.value = stateEntry; } else { // stateEntry is Empty - return GetStateEntryMessage.create({ - key: Buffer.from(key), - empty: stateEntry, - }); + msg.empty = stateEntry; } + return true; } - public getStateKeys(): GetStateKeysEntryMessage { + // Returns true if completed + public tryCompletedGetStateKeys(msg: GetStateKeysEntryMessage): boolean { if (this.isPartial) { - return {}; + return false; } - return GetStateKeysEntryMessage.create({ - value: { - keys: Array.from(this.state.keys()).map((b) => Buffer.from(b)), - }, + msg.value = GetStateKeysEntryMessage_StateKeys.create({ + keys: Array.from(this.state.keys()).map((b) => Buffer.from(b)), }); + return true; } public set(key: string, value: T): SetStateEntryMessage { diff --git a/src/restate_context_impl.ts b/src/restate_context_impl.ts index 80e67143..d3a16355 100644 --- a/src/restate_context_impl.ts +++ b/src/restate_context_impl.ts @@ -24,6 +24,8 @@ import { BackgroundInvokeEntryMessage, CompleteAwakeableEntryMessage, DeepPartial, + GetStateEntryMessage, + GetStateKeysEntryMessage, GetStateKeysEntryMessage_StateKeys, InvokeEntryMessage, SleepEntryMessage, @@ -112,21 +114,24 @@ export class RestateContextImpl implements RestateGrpcContext, RpcContext { this.checkState("get state"); // Create the message and let the state machine process it - const msg = this.stateMachine.localStateStore.get(name); + const msg = GetStateEntryMessage.create({ key: Buffer.from(name) }); + const completed = this.stateMachine.localStateStore.tryCompleteGet( + name, + msg + ); const getState = async (): Promise => { const result = await this.stateMachine.handleUserCodeMessage( GET_STATE_ENTRY_MESSAGE_TYPE, - msg + msg, + completed ); - // TODO WHERE's the completed flag? - // If the GetState message did not have a value or empty, // then we went to the runtime to get the value. // When we get the response, we set it in the localStateStore, // to answer subsequent requests - if (msg.value === undefined && msg.empty === undefined) { + if (!completed) { this.stateMachine.localStateStore.add(name, result as Buffer | Empty); } @@ -145,12 +150,15 @@ export class RestateContextImpl implements RestateGrpcContext, RpcContext { this.checkState("state keys"); // Create the message and let the state machine process it - const msg = this.stateMachine.localStateStore.getStateKeys(); + const msg = GetStateKeysEntryMessage.create({}); + const completed = + this.stateMachine.localStateStore.tryCompletedGetStateKeys(msg); const getStateKeys = async (): Promise> => { const result = await this.stateMachine.handleUserCodeMessage( GET_STATE_KEYS_ENTRY_MESSAGE_TYPE, - msg + msg, + completed ); return (result as GetStateKeysEntryMessage_StateKeys).keys.map((b) => diff --git a/test/protoutils.ts b/test/protoutils.ts index 52097c16..37fbc63a 100644 --- a/test/protoutils.ts +++ b/test/protoutils.ts @@ -157,7 +157,8 @@ export function getStateMessage( GetStateEntryMessage.create({ key: Buffer.from(key), empty: Empty.create({}), - }) + }), + true ); } else if (value !== undefined) { return new Message( @@ -165,7 +166,8 @@ export function getStateMessage( GetStateEntryMessage.create({ key: Buffer.from(key), value: Buffer.from(jsonSerialize(value)), - }) + }), + true ); } else if (failure !== undefined) { return new Message( @@ -173,21 +175,27 @@ export function getStateMessage( GetStateEntryMessage.create({ key: Buffer.from(key), failure: failure, - }) + }), + true ); } else { return new Message( GET_STATE_ENTRY_MESSAGE_TYPE, GetStateEntryMessage.create({ key: Buffer.from(key), - }) + }), + false ); } } export function getStateKeysMessage(value?: Array): Message { if (value === undefined) { - return new Message(GET_STATE_KEYS_ENTRY_MESSAGE_TYPE, {}); + return new Message( + GET_STATE_KEYS_ENTRY_MESSAGE_TYPE, + GetStateKeysEntryMessage.create({}), + false + ); } else { return new Message( GET_STATE_KEYS_ENTRY_MESSAGE_TYPE, @@ -195,7 +203,8 @@ export function getStateKeysMessage(value?: Array): Message { value: { keys: value.map((b) => Buffer.from(b)), }, - }) + }), + true ); } }