diff --git a/src/restate_context_impl.ts b/src/restate_context_impl.ts index e22dcfcb..a128d786 100644 --- a/src/restate_context_impl.ts +++ b/src/restate_context_impl.ts @@ -311,7 +311,7 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { const msg = AwakeableEntryMessage.create(); const promise = this.stateMachine .handleUserCodeMessage(AWAKEABLE_ENTRY_MESSAGE_TYPE, msg) - .then((result: Buffer | void) => { + .then2((result: Buffer | void) => { if (!(result instanceof Buffer)) { // This should either be a filled buffer or an empty buffer but never anything else. throw RetryableError.internal( diff --git a/src/state_machine.ts b/src/state_machine.ts index 2b706e19..6f1d96fb 100644 --- a/src/state_machine.ts +++ b/src/state_machine.ts @@ -114,11 +114,11 @@ export class StateMachine implements RestateStreamConsumer { completedFlag?: boolean, protocolVersion?: number, requiresAckFlag?: boolean - ): Promise { + ): WrappedPromise { // if the state machine is already closed, return a promise that never // completes, so that the user code does not resume if (this.stateMachineClosed) { - return new CompletablePromise().promise; + return wrapDeeply(new CompletablePromise().promise); } const promise = this.journal.handleUserSideMessage(messageType, message); @@ -151,17 +151,14 @@ export class StateMachine implements RestateStreamConsumer { ); } - if (!p.SUSPENSION_TRIGGERS.includes(messageType)) { - return promise; - } - - const proxified: Promise = wrapDeeply(promise, () => { + return wrapDeeply(promise, () => { + if (!p.SUSPENSION_TRIGGERS.includes(messageType)) { + return; + } if (this.journal.isUnResolved(journalIndex)) { this.scheduleSuspension(); } }); - - return proxified; } /** @@ -499,9 +496,40 @@ export class StateMachine implements RestateStreamConsumer { * Returns a promise that wraps the original promise and calls cb() at the first time * this promise or any nested promise that is chained to it is awaited. (then-ed) */ -const wrapDeeply = (promise: Promise, cb: () => void): Promise => { + +/* eslint-disable @typescript-eslint/no-explicit-any */ +export type WrappedPromise = Promise & { + then2: ( + onfulfilled?: + | ((value: T) => TResult1 | PromiseLike) + | null + | undefined, + onrejected?: + | ((reason: any) => TResult2 | PromiseLike) + | null + | undefined + ) => Promise; +}; + +const wrapDeeply = ( + promise: Promise, + cb?: () => void +): WrappedPromise => { /* eslint-disable @typescript-eslint/no-explicit-any */ return { + then2: function ( + onfulfilled?: + | ((value: T) => TResult1 | PromiseLike) + | null + | undefined, + onrejected?: + | ((reason: any) => TResult2 | PromiseLike) + | null + | undefined + ): Promise { + return wrapDeeply(promise.then(onfulfilled, onrejected), cb); + }, + then: function ( onfulfilled?: | ((value: T) => TResult1 | PromiseLike) @@ -512,7 +540,9 @@ const wrapDeeply = (promise: Promise, cb: () => void): Promise => { | null | undefined ): Promise { - cb(); + if (cb !== undefined) { + cb(); + } return promise.then(onfulfilled, onrejected); }, catch: function (