Skip to content

Commit

Permalink
Add ctx.stateKeys()
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper committed Feb 8, 2024
1 parent 86a228b commit d7357af
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 17 deletions.
39 changes: 36 additions & 3 deletions src/journal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
*/

import * as p from "./types/protocol";
import { Failure } from "./generated/proto/protocol";
import {
Failure,
GetStateKeysEntryMessage_StateKeys,
} from "./generated/proto/protocol";
import {
AWAKEABLE_ENTRY_MESSAGE_TYPE,
AwakeableEntryMessage,
Expand All @@ -21,7 +24,9 @@ import {
CompletionMessage,
EntryAckMessage,
GET_STATE_ENTRY_MESSAGE_TYPE,
GET_STATE_KEYS_ENTRY_MESSAGE_TYPE,
GetStateEntryMessage,
GetStateKeysEntryMessage,
INVOKE_ENTRY_MESSAGE_TYPE,
InvokeEntryMessage,
OUTPUT_STREAM_ENTRY_MESSAGE_TYPE,
Expand Down Expand Up @@ -133,6 +138,16 @@ export class Journal<I, O> {
return this.appendJournalEntry(messageType, message);
}
}
case p.GET_STATE_KEYS_ENTRY_MESSAGE_TYPE: {
const getStateMsg = message as GetStateKeysEntryMessage;
if (getStateMsg.value !== undefined) {
// State was eagerly filled by the local state store
return Promise.resolve(getStateMsg.value);
} else {
// Need to retrieve state by going to the runtime.
return this.appendJournalEntry(messageType, message);
}
}
default: {
return this.appendJournalEntry(messageType, message);
}
Expand Down Expand Up @@ -169,8 +184,16 @@ export class Journal<I, O> {
}

if (m.value !== undefined) {
journalEntry.completablePromise.resolve(m.value);
this.pendingJournalEntries.delete(m.entryIndex);
if (journalEntry.messageType === GET_STATE_KEYS_ENTRY_MESSAGE_TYPE) {
// In case of get state keys we expect the parsed message
journalEntry.completablePromise.resolve(
GetStateKeysEntryMessage_StateKeys.decode(m.value)
);
this.pendingJournalEntries.delete(m.entryIndex);
} else {
journalEntry.completablePromise.resolve(m.value);
this.pendingJournalEntries.delete(m.entryIndex);
}
} else if (m.failure !== undefined) {
// we do all completions with Terminal Errors, because failures triggered by those exceptions
// when the bubble up would otherwise lead to re-tries, deterministic replay, re-throwing, and
Expand Down Expand Up @@ -256,6 +279,16 @@ export class Journal<I, O> {
);
break;
}
case GET_STATE_KEYS_ENTRY_MESSAGE_TYPE: {
const getStateMsg = replayMessage.message as GetStateKeysEntryMessage;
this.resolveResult(
journalIndex,
journalEntry,
getStateMsg.value,
getStateMsg.failure
);
break;
}
case INVOKE_ENTRY_MESSAGE_TYPE: {
const invokeMsg = replayMessage.message as InvokeEntryMessage;
this.resolveResult(
Expand Down
14 changes: 8 additions & 6 deletions src/local_state_store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

import {
ClearStateEntryMessage,
GetStateEntryMessage, GetStateKeysEntryMessage, SetStateEntryMessage,
GetStateEntryMessage,
GetStateKeysEntryMessage,
SetStateEntryMessage,
StartMessage_StateEntry,
} from "./generated/proto/protocol";
import { Empty } from "./generated/google/protobuf/empty";
Expand Down Expand Up @@ -56,14 +58,14 @@ export class LocalStateStore {

public getStateKeys(): GetStateKeysEntryMessage {
if (this.isPartial) {
return {}
return {};
}

return {
return GetStateKeysEntryMessage.create({
value: {
keys: Array.from(this.state.keys()).map(b => Buffer.from(b))
}
}
keys: Array.from(this.state.keys()).map((b) => Buffer.from(b)),
},
});
}

public set<T>(key: string, value: T): SetStateEntryMessage {
Expand Down
2 changes: 1 addition & 1 deletion src/restate_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export interface RestateBaseContext {
*/
get<T>(name: string): Promise<T | null>;

stateKeys(): Promise<Array<string>>
stateKeys(): Promise<Array<string>>;

/**
* Set/store state in the Restate runtime.
Expand Down
20 changes: 13 additions & 7 deletions src/restate_context_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ import {
AwakeableEntryMessage,
BackgroundInvokeEntryMessage,
CompleteAwakeableEntryMessage,
DeepPartial, GetStateKeysEntryMessage_StateKeys,
DeepPartial,
GetStateKeysEntryMessage_StateKeys,
InvokeEntryMessage,
SleepEntryMessage,
} from "./generated/proto/protocol";
Expand All @@ -33,7 +34,8 @@ import {
BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE,
CLEAR_STATE_ENTRY_MESSAGE_TYPE,
COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE,
GET_STATE_ENTRY_MESSAGE_TYPE, GET_STATE_KEYS_ENTRY_MESSAGE_TYPE,
GET_STATE_ENTRY_MESSAGE_TYPE,
GET_STATE_KEYS_ENTRY_MESSAGE_TYPE,
INVOKE_ENTRY_MESSAGE_TYPE,
SET_STATE_ENTRY_MESSAGE_TYPE,
SIDE_EFFECT_ENTRY_MESSAGE_TYPE,
Expand Down Expand Up @@ -118,6 +120,8 @@ export class RestateContextImpl implements RestateGrpcContext, RpcContext {
msg
);

// TODO WHERE's the completed flag?

// If the GetState message did not have a value or empty,
// then we went to the runtime to get the value.
// When we get the response, we set it in the localStateStore,
Expand All @@ -143,15 +147,17 @@ export class RestateContextImpl implements RestateGrpcContext, RpcContext {
// Create the message and let the state machine process it
const msg = this.stateMachine.localStateStore.getStateKeys();

const getState = async (): Promise<Array<string>> => {
const getStateKeys = async (): Promise<Array<string>> => {
const result = await this.stateMachine.handleUserCodeMessage(
GET_STATE_KEYS_ENTRY_MESSAGE_TYPE,
msg
GET_STATE_KEYS_ENTRY_MESSAGE_TYPE,
msg
);

return (result as GetStateKeysEntryMessage_StateKeys).keys.map(b => b.toString());
return (result as GetStateKeysEntryMessage_StateKeys).keys.map((b) =>
b.toString()
);
};
return getState();
return getStateKeys();
}

public set<T>(name: string, value: T): void {
Expand Down
2 changes: 2 additions & 0 deletions src/types/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
ENTRY_ACK_MESSAGE_TYPE,
ERROR_MESSAGE_TYPE,
GET_STATE_ENTRY_MESSAGE_TYPE,
GET_STATE_KEYS_ENTRY_MESSAGE_TYPE,
KNOWN_MESSAGE_TYPES,
POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE,
ProtocolMessage,
Expand Down Expand Up @@ -48,6 +49,7 @@ class MessageType {
return (
messageType === POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE ||
messageType === GET_STATE_ENTRY_MESSAGE_TYPE ||
messageType === GET_STATE_KEYS_ENTRY_MESSAGE_TYPE ||
messageType === SLEEP_ENTRY_MESSAGE_TYPE ||
messageType === AWAKEABLE_ENTRY_MESSAGE_TYPE
);
Expand Down
17 changes: 17 additions & 0 deletions test/protoutils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ import {
EndMessage,
AWAKEABLE_IDENTIFIER_PREFIX,
COMBINATOR_ENTRY_MESSAGE,
GET_STATE_KEYS_ENTRY_MESSAGE_TYPE,
GetStateKeysEntryMessage,
} from "../src/types/protocol";
import { Message } from "../src/types/types";
import { TestRequest, TestResponse } from "../src/generated/proto/test";
Expand Down Expand Up @@ -183,6 +185,21 @@ export function getStateMessage<T>(
}
}

export function getStateKeysMessage(value?: Array<string>): Message {
if (value === undefined) {
return new Message(GET_STATE_KEYS_ENTRY_MESSAGE_TYPE, {});
} else {
return new Message(
GET_STATE_KEYS_ENTRY_MESSAGE_TYPE,
GetStateKeysEntryMessage.create({
value: {
keys: value.map((b) => Buffer.from(b)),
},
})
);
}
}

export function setStateMessage<T>(key: string, value: T): Message {
return new Message(
SET_STATE_ENTRY_MESSAGE_TYPE,
Expand Down
103 changes: 103 additions & 0 deletions test/state_keys.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
*
* This file is part of the Restate SDK for Node.js/TypeScript,
* which is released under the MIT license.
*
* You can find a copy of the license in file LICENSE in the root
* directory of this repository or package, or at
* https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
*/

import { describe, expect } from "@jest/globals";
import * as restate from "../src/public_api";
import { TestDriver } from "./testdriver";
import {
completionMessage,
END_MESSAGE,
getStateKeysMessage,
greetRequest,
greetResponse,
inputMessage,
keyVal,
outputMessage,
startMessage,
suspensionMessage,
} from "./protoutils";
import { TestGreeter, TestResponse } from "../src/generated/proto/test";
import { GetStateKeysEntryMessage_StateKeys } from "../src/generated/proto/protocol";

const INPUT_MESSAGE = inputMessage(greetRequest(""));

function stateKeys(...keys: Array<string>): GetStateKeysEntryMessage_StateKeys {
return {
keys: keys.map((b) => Buffer.from(b)),
};
}

class ListKeys implements TestGreeter {
async greet(): Promise<TestResponse> {
const ctx = restate.useContext(this);

return {
greeting: (await ctx.stateKeys()).join(","),
};
}
}

describe("ListKeys", () => {
it("with partial state suspends", async () => {
const result = await new TestDriver(new ListKeys(), [
startMessage(1, true, [keyVal("A", "1")]),
INPUT_MESSAGE,
]).run();

expect(result).toStrictEqual([
getStateKeysMessage(),
suspensionMessage([1]),
]);
});

it("with partial state", async () => {
const result = await new TestDriver(new ListKeys(), [
startMessage(1, true, [keyVal("A", "1")]),
INPUT_MESSAGE,
completionMessage(
1,
GetStateKeysEntryMessage_StateKeys.encode(stateKeys("B", "C")).finish()
),
]).run();

expect(result).toStrictEqual([
getStateKeysMessage(),
outputMessage(greetResponse("B,C")),
END_MESSAGE,
]);
});

it("with complete state", async () => {
const result = await new TestDriver(new ListKeys(), [
startMessage(1, false, [keyVal("A", "1")]),
INPUT_MESSAGE,
]).run();

expect(result).toStrictEqual([
getStateKeysMessage(["A"]),
outputMessage(greetResponse("A")),
END_MESSAGE,
]);
});

it("replay", async () => {
const result = await new TestDriver(new ListKeys(), [
startMessage(1, true, [keyVal("A", "1")]),
INPUT_MESSAGE,
getStateKeysMessage(["A", "B", "C"]),
]).run();

expect(result).toStrictEqual([
outputMessage(greetResponse("A,B,C")),
END_MESSAGE,
]);
});
});

0 comments on commit d7357af

Please sign in to comment.