Skip to content

Commit

Permalink
Partial state flag (#160)
Browse files Browse the repository at this point in the history
* Update contract
* Remove flag parsing, replace with StartMessage.partial_state field
  • Loading branch information
slinkydeveloper authored Sep 6, 2023
1 parent 411be14 commit bee23ae
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 45 deletions.
1 change: 1 addition & 0 deletions proto/protocol.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ message StartMessage {

// protolint:disable:next REPEATED_FIELD_NAMES_PLURALIZED
repeated StateEntry state_map = 4;
bool partial_state = 5;
}

// Type: 0x0000 + 1
Expand Down
12 changes: 3 additions & 9 deletions src/invocation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ export class InvocationBuilder<I, O> implements RestateStreamConsumer {
switch (this.state) {
case State.ExpectingStart:
checkState(State.ExpectingStart, START_MESSAGE_TYPE, m);
this.handleStartMessage(
m.message as StartMessage,
m.partialStateFlag || false
);
this.handleStartMessage(m.message as StartMessage);
this.state = State.ExpectingInput;
return false;

Expand Down Expand Up @@ -114,14 +111,11 @@ export class InvocationBuilder<I, O> implements RestateStreamConsumer {
return this.complete.promise;
}

private handleStartMessage(
m: StartMessage,
partialState: boolean
): InvocationBuilder<I, O> {
private handleStartMessage(m: StartMessage): InvocationBuilder<I, O> {
this.nbEntriesToReplay = m.knownEntries;
this.id = m.id;
this.debugId = m.debugId;
this.localStateStore = new LocalStateStore(partialState, m.stateMap);
this.localStateStore = new LocalStateStore(m.partialState, m.stateMap);
return this;
}

Expand Down
3 changes: 1 addition & 2 deletions src/io/decoder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ function decodeMessages(decoderState: DecoderState, out: Output): DecoderState {
message,
header.completedFlag,
header.protocolVersion,
header.requiresAckFlag,
header.partialStateFlag
header.requiresAckFlag
)
);
}
Expand Down
3 changes: 1 addition & 2 deletions src/io/encoder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ export function encodeMessage(msg: Message): Uint8Array {
bodyBuf.length,
msg.completed,
msg.protocolVersion, // only set for incoming start message
msg.requiresAck,
msg.partialStateFlag
msg.requiresAck
);
const headerBuf = Buffer.alloc(8);
const encoded = header.toU64be();
Expand Down
19 changes: 2 additions & 17 deletions src/types/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ export class Message {
readonly message: ProtocolMessage,
readonly completed?: boolean,
readonly protocolVersion?: number,
readonly requiresAck?: boolean,
readonly partialStateFlag?: boolean
readonly requiresAck?: boolean
) {}
}

Expand Down Expand Up @@ -54,10 +53,6 @@ class MessageType {
return messageType == START_MESSAGE_TYPE;
}

static hasPartialStateFlag(messageType: bigint): boolean {
return messageType == START_MESSAGE_TYPE;
}

static isCustom(messageTypeId: bigint): boolean {
return !KNOWN_MESSAGE_TYPES.has(messageTypeId);
}
Expand All @@ -71,7 +66,6 @@ const CUSTOM_MESSAGE_MASK = BigInt(0xfc00);
const COMPLETED_MASK = BigInt(0x0001_0000_0000);
const VERSION_MASK = BigInt(0x03ff_0000_0000);
const REQUIRES_ACK_MASK = BigInt(0x0001_0000_0000);
const PARTIAL_STATE_MASK = BigInt(0x0400_0000_0000);

// The header is exported but only for tests.
export class Header {
Expand Down Expand Up @@ -101,20 +95,14 @@ export class Header {
(value & REQUIRES_ACK_MASK) !== 0n
? true
: undefined;
const partialStateFlag =
MessageType.hasPartialStateFlag(messageType) &&
(value & PARTIAL_STATE_MASK) !== 0n
? true
: undefined;
const frameLength = Number(value & 0xffffffffn);

return new Header(
messageType,
frameLength,
completedFlag,
protocolVersion,
requiresAckFlag,
partialStateFlag
requiresAckFlag
);
}

Expand All @@ -129,9 +117,6 @@ export class Header {
if (this.requiresAckFlag) {
res = res | REQUIRES_ACK_MASK;
}
if (this.partialStateFlag) {
res = res | PARTIAL_STATE_MASK;
}
return res;
}
}
13 changes: 2 additions & 11 deletions test/protocol_stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ describe("Header", () => {

it("invoke_test", () => roundtripTest(newStart(1, 25)));

it("invoke_test_partial_state", () => roundtripTest(newStart(1, 25, true)));
it("invoke_test_complete_state", () =>
roundtripTest(newStart(1, 25, undefined)));

it("completion_test", () =>
roundtripTest(newHeader(COMPLETION_MESSAGE_TYPE, 22)));

Expand Down Expand Up @@ -187,18 +183,13 @@ function newHeader(messageTypeId: bigint, length: number): Header {
return new Header(messageTypeId, length);
}

function newStart(
protocolVersion: number,
length: number,
partialStateFlag?: boolean
): Header {
function newStart(protocolVersion: number, length: number): Header {
return new Header(
START_MESSAGE_TYPE,
length,
undefined,
protocolVersion,
undefined,
partialStateFlag
undefined
);
}

Expand Down
4 changes: 2 additions & 2 deletions test/protoutils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ export function startMessage(
debugId: "123",
knownEntries: knownEntries, // only used for the Lambda case. For bidi streaming, this will be imputed by the testdriver
stateMap: toStateEntries(state || []),
partialState: partialState !== false,
}),
undefined,
0,
undefined,
partialState === false ? undefined : true
undefined
);
}

Expand Down
4 changes: 2 additions & 2 deletions test/testdriver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ export class TestDriver<I, O> implements Connection {
debugId: startEntry.debugId,
knownEntries: endOfReplay - 1,
stateMap: startEntry.stateMap,
partialState: startEntry.partialState,
}),
msg.completed,
msg.protocolVersion,
msg.requiresAck,
msg.partialStateFlag
msg.requiresAck
);

const replayMessages = entries.slice(0, endOfReplay);
Expand Down

0 comments on commit bee23ae

Please sign in to comment.