diff --git a/src/promise_combinator_tracker.ts b/src/promise_combinator_tracker.ts index 007264f9..d2396f0c 100644 --- a/src/promise_combinator_tracker.ts +++ b/src/promise_combinator_tracker.ts @@ -32,7 +32,7 @@ export function newJournalEntryPromiseId(entryIndex: number): PromiseId { */ function prepareReplayedPromiseCombinator( combinatorIndex: number, - combinatorConstructor: (promises: Promise[]) => Promise, + combinatorConstructor: (promises: PromiseLike[]) => Promise, promises: Array<{ id: PromiseId; promise: Promise }>, readReplayOrder: (combinatorIndex: number) => PromiseId[] ): WrappedPromise { @@ -133,7 +133,7 @@ export class PromiseCombinatorTracker { ) {} public createCombinatorInReplayMode( - combinatorConstructor: (promises: Promise[]) => Promise, + combinatorConstructor: (promises: PromiseLike[]) => Promise, promises: Array<{ id: PromiseId; promise: Promise }> ): WrappedPromise { const combinatorIndex = this.nextCombinatorIndex; diff --git a/src/restate_context.ts b/src/restate_context.ts index e6675830..16dbe693 100644 --- a/src/restate_context.ts +++ b/src/restate_context.ts @@ -189,6 +189,34 @@ export interface RestateBaseContext { * await ctx.sleep(1000); */ sleep(millis: number): CombineablePromise; + + // --- Promise combinators + + /** + * Creates a Promise that is resolved with an array of results when all of the provided Promises + * resolve, or rejected when any Promise is rejected. + * + * See {@link Promise.all} for more details. + * + * @param values An iterable of Promises. + * @returns A new Promise. + */ + all[] | []>( + values: T + ): Promise<{ -readonly [P in keyof T]: Awaited }>; + + /** + * Creates a Promise that is resolved or rejected when any of the provided Promises are resolved + * or rejected. + * + * See {@link Promise.race} for more details. + * + * @param values An iterable of Promises. + * @returns A new Promise. + */ + race[] | []>( + values: T + ): Promise>; } export interface Rand { diff --git a/src/restate_context_impl.ts b/src/restate_context_impl.ts index aaff7286..19152d95 100644 --- a/src/restate_context_impl.ts +++ b/src/restate_context_impl.ts @@ -62,6 +62,10 @@ import { Client, SendClient } from "./types/router"; import { RpcRequest, RpcResponse } from "./generated/proto/dynrpc"; import { requestFromArgs } from "./utils/assumpsions"; import { RandImpl } from "./utils/rand"; +import { + newJournalEntryPromiseId, + PromiseId, +} from "./promise_combinator_tracker"; export enum CallContexType { None, @@ -225,6 +229,10 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { ); } + rpcGateway(): RpcGateway { + return new RpcContextImpl(this); + } + // DON'T make this function async!!! // The reason is that we want the erros thrown by the initial checks to be propagated in the caller context, // and not in the promise context. To understand the semantic difference, make this function async and run the @@ -344,6 +352,8 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { ); } + // -- Awakeables + public awakeable(): { id: string; promise: CombineablePromise } { this.checkState("awakeable"); @@ -399,6 +409,44 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { ); } + // -- Combinators + + all[] | []>( + values: T + ): Promise<{ -readonly [P in keyof T]: Awaited }> { + return this.stateMachine.createCombinator( + Promise.all.bind(Promise), + this.extractPromisesWithIds(values) + ); + } + + race[] | []>( + values: T + ): Promise> { + return this.stateMachine.createCombinator( + Promise.race.bind(Promise), + this.extractPromisesWithIds(values) + ); + } + + private extractPromisesWithIds( + promises: Iterable> + ): Array<{ id: PromiseId; promise: Promise }> { + const outPromises = []; + + for (const promise of promises) { + const index = (promise as InternalCombineablePromise).journalIndex; + outPromises.push({ + id: newJournalEntryPromiseId(index), + promise: promise, + }); + } + + return outPromises; + } + + // -- Various private methods + private isInSideEffect(): boolean { const context = RestateGrpcContextImpl.callContext.getStore(); return context?.type === CallContexType.SideEffect; @@ -460,10 +508,6 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { }, }) as InternalCombineablePromise; } - - rpcGateway(): RpcGateway { - return new RpcContextImpl(this); - } } async function executeWithRetries( @@ -626,6 +670,18 @@ export class RpcContextImpl implements RpcContext { return this.ctx.sleep(millis); } + all[] | []>( + values: T + ): Promise<{ -readonly [P in keyof T]: Awaited }> { + return this.ctx.all(values); + } + + race[] | []>( + values: T + ): Promise> { + return this.ctx.race(values); + } + grpcChannel(): RestateGrpcChannel { return this.ctx; } diff --git a/src/state_machine.ts b/src/state_machine.ts index cf4be3ca..58d519e6 100644 --- a/src/state_machine.ts +++ b/src/state_machine.ts @@ -194,16 +194,30 @@ export class StateMachine implements RestateStreamConsumer { // -- Methods related to combinators to wire up promise combinator API with PromiseCombinatorTracker - public createCombinator() { - // TODO create combinator - // if replay { combinator in replay mode } else { combinator in processing mode } + public createCombinator( + combinatorConstructor: (promises: PromiseLike[]) => Promise, + promises: Array<{ id: PromiseId; promise: Promise }> + ) { + if (this.stateMachineClosed) { + return WRAPPED_PROMISE_PENDING as WrappedPromise; + } - // We need to wrap deeply again to schedule suspension here! - return wrapDeeply(Promise.resolve("TODO"), () => { - if (this.journal.isUnResolved(0 /* TODO */)) { - this.scheduleSuspension(); - } - }); + if (this.journal.isProcessing()) { + return wrapDeeply( + this.promiseCombinatorTracker.createCombinatorInProcessingMode( + combinatorConstructor, + promises + ), + () => { + this.scheduleSuspension(); + } + ); + } else { + return this.promiseCombinatorTracker.createCombinatorInReplayMode( + combinatorConstructor, + promises + ); + } } readCombinatorOrderEntry(combinatorId: number): PromiseId[] {