Skip to content

Commit

Permalink
Client returns CombineablePromise (#237)
Browse files Browse the repository at this point in the history
* Join RestateGrpcContextImpl and RpcContextImpl into a single implementation, RestateContextImpl. This simplifies the implementation, allowing to share private methods.
* Mark invoke return promise as combineable.
  • Loading branch information
slinkydeveloper authored Jan 30, 2024
1 parent 56a00bf commit 33ad948
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 155 deletions.
3 changes: 1 addition & 2 deletions src/embedded/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
*/

import { RpcContext, useContext } from "../restate_context";
import { RpcContextImpl } from "../restate_context_impl";
import { GrpcServiceMethod, HostedGrpcServiceMethod } from "../types/grpc";

export function wrapHandler<I, O>(
handler: (ctx: RpcContext, input: I) => Promise<O>
): HostedGrpcServiceMethod<I, O> {
const localMethod = (instance: unknown, input: I): Promise<O> => {
const ctx = new RpcContextImpl(useContext(instance));
const ctx = useContext(instance) as unknown as RpcContext;
return handler(ctx, input);
};

Expand Down
34 changes: 17 additions & 17 deletions src/restate_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import { RetrySettings } from "./utils/public_utils";
import { Client, SendClient } from "./types/router";
import { RestateGrpcContextImpl } from "./restate_context_impl";
import { RestateContextImpl } from "./restate_context_impl";

/**
* A promise that can be combined using Promise combinators in RestateContext.
Expand Down Expand Up @@ -238,9 +238,10 @@ export const CombineablePromise = {
return Promise.all(values);
}

return (
values[0].__restate_context as RestateGrpcContextImpl
).createCombinator(Promise.all.bind(Promise), values) as Promise<{
return (values[0].__restate_context as RestateContextImpl).createCombinator(
Promise.all.bind(Promise),
values
) as Promise<{
-readonly [P in keyof T]: Awaited<T[P]>;
}>;
},
Expand All @@ -261,11 +262,10 @@ export const CombineablePromise = {
return Promise.race(values);
}

return (
values[0].__restate_context as RestateGrpcContextImpl
).createCombinator(Promise.race.bind(Promise), values) as Promise<
Awaited<T[number]>
>;
return (values[0].__restate_context as RestateContextImpl).createCombinator(
Promise.race.bind(Promise),
values
) as Promise<Awaited<T[number]>>;
},

/**
Expand All @@ -285,11 +285,10 @@ export const CombineablePromise = {
return Promise.any(values);
}

return (
values[0].__restate_context as RestateGrpcContextImpl
).createCombinator(Promise.any.bind(Promise), values) as Promise<
Awaited<T[number]>
>;
return (values[0].__restate_context as RestateContextImpl).createCombinator(
Promise.any.bind(Promise),
values
) as Promise<Awaited<T[number]>>;
},

/**
Expand All @@ -310,9 +309,10 @@ export const CombineablePromise = {
return Promise.allSettled(values);
}

return (
values[0].__restate_context as RestateGrpcContextImpl
).createCombinator(Promise.allSettled.bind(Promise), values) as Promise<{
return (values[0].__restate_context as RestateContextImpl).createCombinator(
Promise.allSettled.bind(Promise),
values
) as Promise<{
-readonly [P in keyof T]: PromiseSettledResult<Awaited<T[P]>>;
}>;
},
Expand Down
204 changes: 85 additions & 119 deletions src/restate_context_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,12 @@ export interface CallContext {
delay?: number;
}

export type InternalCombineablePromise<T> = CombineablePromise<T> & {
journalIndex: number;
};
export type InternalCombineablePromise<T> = CombineablePromise<T> &
WrappedPromise<T> & {
journalIndex: number;
};

export class RestateGrpcContextImpl implements RestateGrpcContext {
export class RestateContextImpl implements RestateGrpcContext, RpcContext {
// here, we capture the context information for actions on the Restate context that
// are executed within other actions, such as
// ctx.oneWayCall( () => client.foo(bar) );
Expand Down Expand Up @@ -153,13 +154,20 @@ export class RestateGrpcContextImpl implements RestateGrpcContext {
);
}

// --- Calls, background calls, etc

public request(
service: string,
method: string,
data: Uint8Array
): Promise<Uint8Array> {
if (this.isInOneWayCall()) {
return this.invokeOneWay(service, method, data);
return this.invokeOneWay(
service,
method,
data,
this.getOneWayCallDelay()
);
} else {
return this.invoke(service, method, data);
}
Expand All @@ -170,26 +178,29 @@ export class RestateGrpcContextImpl implements RestateGrpcContext {
service: string,
method: string,
data: Uint8Array
): Promise<Uint8Array> {
): InternalCombineablePromise<Uint8Array> {
this.checkState("invoke");

const msg = InvokeEntryMessage.create({
serviceName: service,
methodName: method,
parameter: Buffer.from(data),
});
return this.stateMachine
.handleUserCodeMessage(INVOKE_ENTRY_MESSAGE_TYPE, msg)
.transform((v) => v as Uint8Array);
return this.markCombineablePromise(
this.stateMachine
.handleUserCodeMessage(INVOKE_ENTRY_MESSAGE_TYPE, msg)
.transform((v) => v as Uint8Array)
);
}

private async invokeOneWay(
service: string,
method: string,
data: Uint8Array
data: Uint8Array,
delay?: number
): Promise<Uint8Array> {
const delay = this.getOneWayCallDelay();
const invokeTime = delay > 0 ? Date.now() + delay : undefined;
const actualDelay = delay || 0;
const invokeTime = actualDelay > 0 ? Date.now() + actualDelay : undefined;
const msg = BackgroundInvokeEntryMessage.create({
serviceName: service,
methodName: method,
Expand All @@ -211,7 +222,7 @@ export class RestateGrpcContextImpl implements RestateGrpcContext {
): Promise<void> {
this.checkState("oneWayCall");

return RestateGrpcContextImpl.callContext.run(
return RestateContextImpl.callContext.run(
{ type: CallContexType.OneWayCall },
call
);
Expand All @@ -226,14 +237,65 @@ export class RestateGrpcContextImpl implements RestateGrpcContext {
this.checkState("delayedCall");

// Delayed call is a one way call with a delay
return RestateGrpcContextImpl.callContext.run(
return RestateContextImpl.callContext.run(
{ type: CallContexType.OneWayCall, delay: delayMillis },
call
);
}

rpc<M>({ path }: ServiceApi<M>): Client<M> {
const clientProxy = new Proxy(
{},
{
get: (_target, prop) => {
const route = prop as string;
return async (...args: unknown[]) => {
const request = requestFromArgs(args);
const requestBytes = RpcRequest.encode(request).finish();
return this.invoke(path, route, requestBytes).transform(
(responseBytes) => RpcResponse.decode(responseBytes).response
);
};
},
}
);

return clientProxy as Client<M>;
}

public send<M>(options: ServiceApi): SendClient<M> {
return this.sendDelayed(options, 0);
}

public sendDelayed<M>(
{ path }: ServiceApi,
delayMillis: number
): SendClient<M> {
const clientProxy = new Proxy(
{},
{
get: (_target, prop) => {
const route = prop as string;
return (...args: unknown[]) => {
const request = requestFromArgs(args);
const requestBytes = RpcRequest.encode(request).finish();
this.invokeOneWay(path, route, requestBytes, delayMillis);
};
},
}
);

return clientProxy as SendClient<M>;
}

// --- Methods exposed by respective interfaces to interact with other APIs

grpcChannel(): RestateGrpcChannel {
return this;
}

rpcGateway(): RpcGateway {
return new RpcContextImpl(this);
return this;
}

// DON'T make this function async!!!
Expand Down Expand Up @@ -272,7 +334,7 @@ export class RestateGrpcContextImpl implements RestateGrpcContext {

let sideEffectResult: T;
try {
sideEffectResult = await RestateGrpcContextImpl.callContext.run(
sideEffectResult = await RestateContextImpl.callContext.run(
{ type: CallContexType.SideEffect },
fn
);
Expand Down Expand Up @@ -447,18 +509,18 @@ export class RestateGrpcContextImpl implements RestateGrpcContext {
// -- Various private methods

private isInSideEffect(): boolean {
const context = RestateGrpcContextImpl.callContext.getStore();
const context = RestateContextImpl.callContext.getStore();
return context?.type === CallContexType.SideEffect;
}

private isInOneWayCall(): boolean {
const context = RestateGrpcContextImpl.callContext.getStore();
const context = RestateContextImpl.callContext.getStore();
return context?.type === CallContexType.OneWayCall;
}

private getOneWayCallDelay(): number {
const context = RestateGrpcContextImpl.callContext.getStore();
return context?.delay || 0;
private getOneWayCallDelay(): number | undefined {
const context = RestateContextImpl.callContext.getStore();
return context?.delay;
}

private checkNotExecutingSideEffect() {
Expand All @@ -472,7 +534,7 @@ export class RestateGrpcContextImpl implements RestateGrpcContext {
}

private checkState(callType: string): void {
const context = RestateGrpcContextImpl.callContext.getStore();
const context = RestateContextImpl.callContext.getStore();
if (!context) {
this.checkNotExecutingSideEffect();
return;
Expand All @@ -496,7 +558,7 @@ export class RestateGrpcContextImpl implements RestateGrpcContext {
}

private markCombineablePromise<T>(
p: Promise<T>
p: WrappedPromise<T>
): InternalCombineablePromise<T> {
const journalIndex = this.stateMachine.getUserCodeJournalIndex();
const orTimeout = (millis: number): Promise<T> => {
Expand Down Expand Up @@ -606,99 +668,3 @@ async function executeWithRetries<T>(
);
}
}

export class RpcContextImpl implements RpcContext {
constructor(
private readonly ctx: RestateGrpcContext,
public readonly id: Buffer = ctx.id,
public readonly rand: Rand = ctx.rand,
public readonly console: Console = ctx.console,
public readonly serviceName: string = ctx.serviceName
) {}

public rpc<M>({ path }: ServiceApi): Client<M> {
const clientProxy = new Proxy(
{},
{
get: (_target, prop) => {
const route = prop as string;
return async (...args: unknown[]) => {
const request = requestFromArgs(args);
const requestBytes = RpcRequest.encode(request).finish();
const responseBytes = await this.ctx.request(
path,
route,
requestBytes
);
const response = RpcResponse.decode(responseBytes);
return response.response;
};
},
}
);

return clientProxy as Client<M>;
}

public send<M>(options: ServiceApi): SendClient<M> {
return this.sendDelayed(options, 0);
}

public sendDelayed<M>(
{ path }: ServiceApi,
delayMillis: number
): SendClient<M> {
const clientProxy = new Proxy(
{},
{
get: (_target, prop) => {
const route = prop as string;
return (...args: unknown[]) => {
const request = requestFromArgs(args);
const requestBytes = RpcRequest.encode(request).finish();
const sender = () => this.ctx.request(path, route, requestBytes);
if (delayMillis === undefined || delayMillis === 0) {
this.ctx.oneWayCall(sender);
} else {
this.ctx.delayedCall(sender, delayMillis);
}
};
},
}
);

return clientProxy as SendClient<M>;
}

public get<T>(name: string): Promise<T | null> {
return this.ctx.get(name);
}
public set<T>(name: string, value: T): void {
this.ctx.set(name, value);
}
public clear(name: string): void {
this.ctx.clear(name);
}
public sideEffect<T>(
fn: () => Promise<T>,
retryPolicy?: RetrySettings
): Promise<T> {
return this.ctx.sideEffect(fn, retryPolicy);
}
public awakeable<T>(): { id: string; promise: CombineablePromise<T> } {
return this.ctx.awakeable();
}
public resolveAwakeable<T>(id: string, payload: T): void {
this.ctx.resolveAwakeable(id, payload);
}
public rejectAwakeable(id: string, reason: string): void {
this.ctx.rejectAwakeable(id, reason);
}
public sleep(millis: number): CombineablePromise<void> {
return this.ctx.sleep(millis);
}

grpcChannel(): RestateGrpcChannel {
return this.ctx;
}
}
Loading

0 comments on commit 33ad948

Please sign in to comment.