diff --git a/buf.lock b/buf.lock index 84e00285..c91b5810 100644 --- a/buf.lock +++ b/buf.lock @@ -1,8 +1,2 @@ # Generated by buf. DO NOT EDIT. version: v1 -deps: - - remote: buf.build - owner: restatedev - repository: proto - commit: 6ea2d15aed8f408590a1465844df5a8e - digest: shake256:e6599809ff13490a631f87d1a4b13ef1886d1bd1c0aa001ccb92806c0acc373d047a6ead761f8a21dfbd57a4fd9acd5915a52e47bd5b4e4a02dd1766f78511b3 diff --git a/buf.yaml b/buf.yaml index 8e7714ae..9d1844be 100644 --- a/buf.yaml +++ b/buf.yaml @@ -1,6 +1,6 @@ version: v1 -deps: - - buf.build/restatedev/proto +#deps: +# - buf.build/restatedev/proto build: excludes: - node_modules diff --git a/examples/example.ts b/examples/example.ts index 1002de7d..2425f826 100644 --- a/examples/example.ts +++ b/examples/example.ts @@ -25,14 +25,14 @@ const greeter = restate.router({ greet: async (ctx: restate.Context, name: string) => { // blocking RPC call to a keyed service (here supplying type and path separately) const countSoFar = await ctx - .rpc({ path: "counter" }) + .object({ path: "counter" }, name) .count(name); const message = `Hello ${name}, for the ${countSoFar + 1}th time!`; // sending messages to ourselves, immediately and delayed - ctx.send(greeterApi).logger(message); - ctx.sendDelayed(greeterApi, 100).logger("delayed " + message); + ctx.serviceSend(greeterApi).logger(message); + ctx.serviceSendDelayed(greeterApi, 100).logger("delayed " + message); return message; }, diff --git a/examples/handler_example.ts b/examples/handler_example.ts deleted file mode 100644 index dc13f95e..00000000 --- a/examples/handler_example.ts +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH - * - * This file is part of the Restate SDK for Node.js/TypeScript, - * which is released under the MIT license. - * - * You can find a copy of the license in file LICENSE in the root - * directory of this repository or package, or at - * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE - */ - -/* - * A simple example program showing how to let services listen to events produced - * by systems like Kafka. - */ - -import * as restate from "../src/public_api"; - -type UserProfile = { - id: string; - name: string; - email: string; -}; - -const profileService = restate.keyedRouter({ - registration: restate.keyedEventHandler( - async (ctx: restate.KeyedContext, event: restate.Event) => { - // store in state the user's information as coming from the registeration event - const { name } = event.json<{ name: string }>(); - ctx.set("name", name); - } - ), - - email: restate.keyedEventHandler( - async (ctx: restate.KeyedContext, event: restate.Event) => { - // store in state the user's information as coming from the email event - const { email } = event.json<{ email: string }>(); - ctx.set("email", email); - } - ), - - get: async (ctx: restate.KeyedContext, id: string): Promise => { - return { - id, - name: (await ctx.get("name")) ?? "", - email: (await ctx.get("email")) ?? "", - }; - }, -}); - -// restate server -restate.endpoint().bindKeyedRouter("profile", profileService).listen(9080); diff --git a/proto/discovery.proto b/proto/discovery.proto deleted file mode 100644 index 9bd04314..00000000 --- a/proto/discovery.proto +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH - * - * This file is part of the Restate SDK for Node.js/TypeScript, - * which is released under the MIT license. - * - * You can find a copy of the license in file LICENSE in the root - * directory of this repository or package, or at - * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE - */ - -syntax = "proto3"; - -package dev.restate.service.discovery; - -import "google/protobuf/descriptor.proto"; - -option java_package = "dev.restate.generated.service.discovery"; -option go_package = "restate.dev/sdk-go/pb/service/discovery"; - - -// --- Service discovery endpoint --- -// Request: POST /discover with application/proto containing ServiceDiscoveryRequest -// Response: application/proto containing ServiceDiscoveryResponse - -message ServiceDiscoveryRequest { -} - -enum ProtocolMode { - // protolint:disable:next ENUM_FIELD_NAMES_ZERO_VALUE_END_WITH - BIDI_STREAM = 0; - REQUEST_RESPONSE = 1; -} - -message ServiceDiscoveryResponse { - // List of all proto files used to define the services, including the dependencies. - google.protobuf.FileDescriptorSet files = 1; - - // List of services to register. This might be a subset of services defined in files. - repeated string services = 2; - - // Service-protocol version negotiation - uint32 min_protocol_version = 3; - uint32 max_protocol_version = 4; - - // Protocol mode negotiation - ProtocolMode protocol_mode = 5; -} diff --git a/proto/dynrpc.proto b/proto/dynrpc.proto deleted file mode 100644 index c6101067..00000000 --- a/proto/dynrpc.proto +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH - * - * This file is part of the Restate SDK for Node.js/TypeScript, - * which is released under the MIT license. - * - * You can find a copy of the license in file LICENSE in the root - * directory of this repository or package, or at - * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE - */ - -syntax = "proto3"; - -import "dev/restate/ext.proto"; -import "dev/restate/events.proto"; -import "google/protobuf/struct.proto"; -import "google/protobuf/empty.proto"; - -service RpcEndpoint { - option (dev.restate.ext.service_type) = KEYED; - - rpc call(RpcRequest) returns (RpcResponse) {}; - - rpc handle(KeyedEvent) returns (google.protobuf.Empty) {}; -} - -message KeyedEvent { - string key = 1 [(dev.restate.ext.field) = KEY]; - bytes payload = 3 [(dev.restate.ext.field) = EVENT_PAYLOAD]; - map attributes = 15 [(dev.restate.ext.field) = EVENT_METADATA]; -} - -service UnkeyedRpcEndpoint { - option (dev.restate.ext.service_type) = UNKEYED; - - rpc call(RpcRequest) returns (RpcResponse) {}; -} - -message RpcRequest { - string key = 1 [(dev.restate.ext.field) = KEY]; - google.protobuf.Value request = 2; - - // internal: see src/utils/assumptions.ts - int32 sender_assumes = 101; -} - -message RpcResponse { - google.protobuf.Value response = 1; -} diff --git a/src/context.ts b/src/context.ts index c64023ba..7575128b 100644 --- a/src/context.ts +++ b/src/context.ts @@ -249,7 +249,9 @@ export interface Context { * const result2 = await ctx.rpc(myApi).anotherAction(1337); * ``` */ - rpc(opts: ServiceApi): Client; + service(opts: ServiceApi): Client; + + object(opts: ServiceApi, key: string): Client; /** * Makes a type-safe one-way RPC to the specified target service. This method effectively behaves @@ -290,7 +292,8 @@ export interface Context { * ctx.send(myApi).anotherAction(1337); * ``` */ - send(opts: ServiceApi): SendClient; + objectSend(opts: ServiceApi, key: string): SendClient; + serviceSend(opts: ServiceApi): SendClient; /** * Makes a type-safe one-way RPC to the specified target service, after a delay specified by the @@ -337,7 +340,12 @@ export interface Context { * ctx.sendDelayed(myApi, 60_000).anotherAction(1337); * ``` */ - sendDelayed(opts: ServiceApi, delay: number): SendClient; + objectSendDelayed( + opts: ServiceApi, + delay: number, + key: string + ): SendClient; + serviceSendDelayed(opts: ServiceApi, delay: number): SendClient; } /** @@ -352,7 +360,9 @@ export interface Context { * This context can be used only within keyed services/routers. * */ -export interface KeyedContext extends Context, KeyValueStore {} +export interface KeyedContext extends Context, KeyValueStore { + key(): string; +} export interface Rand { /** diff --git a/src/context_impl.ts b/src/context_impl.ts index 91d98129..9b333bfd 100644 --- a/src/context_impl.ts +++ b/src/context_impl.ts @@ -57,11 +57,11 @@ import { RetrySettings, } from "./utils/public_utils"; import { Client, SendClient } from "./types/router"; -import { RpcRequest, RpcResponse } from "./generated/proto/dynrpc"; -import { requestFromArgs } from "./utils/assumptions"; import { RandImpl } from "./utils/rand"; import { newJournalEntryPromiseId } from "./promise_combinator_tracker"; import { WrappedPromise } from "./utils/promises"; +import { Buffer } from "node:buffer"; +import { deserializeJson, serializeJson } from "./utils/serde"; export enum CallContexType { None, @@ -97,11 +97,19 @@ export class ContextImpl implements KeyedContext { public readonly serviceName: string, public readonly console: Console, public readonly keyedContext: boolean, + public readonly keyedContextKey: string | undefined, // eslint-disable-next-line @typescript-eslint/no-explicit-any - private readonly stateMachine: StateMachine, + private readonly stateMachine: StateMachine, public readonly rand: Rand = new RandImpl(id) ) {} + public key(): string { + if (!this.keyedContextKey) { + throw new TerminalError("unexpected missing key"); + } + return this.keyedContextKey; + } + // DON'T make this function async!!! see sideEffect comment for details. public get(name: string): Promise { // Check if this is a valid action @@ -214,7 +222,8 @@ export class ContextImpl implements KeyedContext { private invoke( service: string, method: string, - data: Uint8Array + data: Uint8Array, + key?: string ): InternalCombineablePromise { this.checkState("invoke"); @@ -222,6 +231,7 @@ export class ContextImpl implements KeyedContext { serviceName: service, methodName: method, parameter: Buffer.from(data), + key, }); return this.markCombineablePromise( this.stateMachine @@ -234,7 +244,8 @@ export class ContextImpl implements KeyedContext { service: string, method: string, data: Uint8Array, - delay?: number + delay?: number, + key?: string ): Promise { const actualDelay = delay || 0; const invokeTime = actualDelay > 0 ? Date.now() + actualDelay : undefined; @@ -243,6 +254,7 @@ export class ContextImpl implements KeyedContext { methodName: method, parameter: Buffer.from(data), invokeTime: invokeTime, + key, }); await this.stateMachine.handleUserCodeMessage( @@ -280,17 +292,17 @@ export class ContextImpl implements KeyedContext { ); } - rpc({ path }: ServiceApi): Client { + service({ path }: ServiceApi): Client { const clientProxy = new Proxy( {}, { get: (_target, prop) => { const route = prop as string; - return async (...args: unknown[]) => { - const request = requestFromArgs(args); - const requestBytes = RpcRequest.encode(request).finish(); + return (...args: unknown[]) => { + const request = args.shift() ?? {}; + const requestBytes = serializeJson(request); return this.invoke(path, route, requestBytes).transform( - (responseBytes) => RpcResponse.decode(responseBytes).response + (responseBytes) => deserializeJson(responseBytes) ); }; }, @@ -300,11 +312,32 @@ export class ContextImpl implements KeyedContext { return clientProxy as Client; } - public send(options: ServiceApi): SendClient { - return this.sendDelayed(options, 0); + object({ path }: ServiceApi, key: string): Client { + const clientProxy = new Proxy( + {}, + { + get: (_target, prop) => { + const route = prop as string; + return (...args: unknown[]) => { + args.shift(); // TODO: FIXME: remove once we fix the types.ts + const request = args.shift() ?? {}; + const requestBytes = serializeJson(request); + return this.invoke(path, route, requestBytes, key).transform( + (responseBytes) => deserializeJson(responseBytes) + ); + }; + }, + } + ); + + return clientProxy as Client; } - public sendDelayed( + public serviceSend(options: ServiceApi): SendClient { + return this.serviceSendDelayed(options, 0); + } + + public serviceSendDelayed( { path }: ServiceApi, delayMillis: number ): SendClient { @@ -314,8 +347,8 @@ export class ContextImpl implements KeyedContext { get: (_target, prop) => { const route = prop as string; return (...args: unknown[]) => { - const request = requestFromArgs(args); - const requestBytes = RpcRequest.encode(request).finish(); + const request = args.shift() ?? null; + const requestBytes = serializeJson(request); this.invokeOneWay(path, route, requestBytes, delayMillis); }; }, @@ -325,6 +358,33 @@ export class ContextImpl implements KeyedContext { return clientProxy as SendClient; } + public objectSend(options: ServiceApi, key: string): SendClient { + return this.objectSendDelayed(options, 0, key); + } + + public objectSendDelayed( + { path }: ServiceApi, + delayMillis: number, + key: string + ): SendClient { + const clientProxy = new Proxy( + {}, + { + get: (_target, prop) => { + const route = prop as string; + return (...args: unknown[]) => { + args.shift(); // TODO: FIXME: remove once we fix the types.ts + const request = args.shift() ?? null; + const requestBytes = serializeJson(request); + this.invokeOneWay(path, route, requestBytes, delayMillis, key); + }; + }, + } + ); + + return clientProxy as SendClient; + } + // DON'T make this function async!!! // The reason is that we want the erros thrown by the initial checks to be propagated in the caller context, // and not in the promise context. To understand the semantic difference, make this function async and run the diff --git a/src/endpoint/endpoint_impl.ts b/src/endpoint/endpoint_impl.ts index ff75008b..d762a65d 100644 --- a/src/endpoint/endpoint_impl.ts +++ b/src/endpoint/endpoint_impl.ts @@ -11,72 +11,38 @@ /* eslint-disable @typescript-eslint/ban-types */ -import { - GrpcService, - GrpcServiceMethod, - HostedGrpcServiceMethod, - ProtoMetadata, -} from "../types/grpc"; -import { - DeepPartial, - ServiceDiscoveryResponse, -} from "../generated/proto/discovery"; -import { Event } from "../types/types"; -import { - FileDescriptorProto, - UninterpretedOption, -} from "../generated/google/protobuf/descriptor"; -import { Empty } from "../generated/google/protobuf/empty"; -import { - FileDescriptorProto as FileDescriptorProto1, - ServiceDescriptorProto as ServiceDescriptorProto1, - MethodDescriptorProto as MethodDescriptorProto1, -} from "ts-proto-descriptors"; -import { - fieldTypeToJSON, - ServiceType, - serviceTypeToJSON, -} from "../generated/dev/restate/ext"; -import { - RpcRequest, - RpcResponse, - ProtoMetadata as RpcServiceProtoMetadata, - protoMetadata as rpcServiceProtoMetadata, - KeyedEvent, -} from "../generated/proto/dynrpc"; -import { Context, KeyedContext, useContext, useKeyedContext } from "../context"; -import { verifyAssumptions } from "../utils/assumptions"; -import { RestateEndpoint, ServiceBundle, TerminalError } from "../public_api"; -import { KeyedRouter, UnKeyedRouter, isEventHandler } from "../types/router"; -import { jsonSafeAny } from "../utils/utils"; +import { RestateEndpoint, ServiceBundle } from "../public_api"; +import { KeyedRouter, UnKeyedRouter } from "../types/router"; import { rlog } from "../logger"; import http2, { Http2ServerRequest, Http2ServerResponse } from "http2"; import { Http2Handler } from "./http2_handler"; import { LambdaHandler } from "./lambda_handler"; +import { + Component, + ServiceComponent, + ServiceHandlerFunction, + VirtualObjectHandlerFunction, + VritualObjectComponent as VritualObjectComponent, +} from "../types/componenets"; + +import * as discovery from "../types/discovery"; export class EndpointImpl implements RestateEndpoint { - protected readonly methods: Record< - string, - HostedGrpcServiceMethod - > = {}; - readonly discovery: DeepPartial; - protected readonly dynrpcDescriptor: RpcServiceProtoMetadata; + private readonly componenets: Map = new Map(); - public constructor() { - this.discovery = { - files: { file: [] }, - services: [], - minProtocolVersion: 0, - maxProtocolVersion: 0, - }; - this.dynrpcDescriptor = copyProtoMetadata(rpcServiceProtoMetadata); + public componenetByName(componenetName: string): Component | undefined { + return this.componenets.get(componenetName); + } + + public addComponenet(componenet: Component) { + this.componenets.set(componenet.name(), componenet); } public bindRouter( path: string, router: UnKeyedRouter ): RestateEndpoint { - this.bindRpcService(path, router, false); + this.bindServiceComponent(path, router); return this; } @@ -84,7 +50,7 @@ export class EndpointImpl implements RestateEndpoint { path: string, router: KeyedRouter ): RestateEndpoint { - this.bindRpcService(path, router, true); + this.bindVirtualObjectComponent(path, router); return this; } @@ -117,483 +83,53 @@ export class EndpointImpl implements RestateEndpoint { return new Promise(() => {}); } - // Private methods to build the endpoint - - private addDescriptor(descriptor: ProtoMetadata) { - const desc = FileDescriptorProto.fromPartial(descriptor.fileDescriptor); - - // extract out service options and put into the fileDescriptor - for (const name in descriptor.options?.services) { - if ( - descriptor.options?.services[name]?.options?.service_type !== undefined - ) { - desc.service - .find((desc) => desc.name === name) - ?.options?.uninterpretedOption.push( - UninterpretedOption.fromPartial({ - name: [ - { namePart: "dev.restate.ext.service_type", isExtension: true }, - ], - identifierValue: serviceTypeToJSON( - descriptor.options?.services[name]?.options?.service_type - ), - }) - ); - } - } - - // extract out field options and put into the fileDescriptor - for (const messageName in descriptor.options?.messages) { - for (const fieldName in descriptor.options?.messages[messageName] - ?.fields) { - const fields = descriptor.options?.messages[messageName]?.fields || {}; - if (fields[fieldName]["field"] !== undefined) { - desc.messageType - .find((desc) => desc.name === messageName) - ?.field?.find((desc) => desc.name === fieldName) - ?.options?.uninterpretedOption.push( - UninterpretedOption.fromPartial({ - name: [ - { namePart: "dev.restate.ext.field", isExtension: true }, - ], - identifierValue: fieldTypeToJSON(fields[fieldName]["field"]), - }) - ); - } - } - } - - if ( - this.discovery.files?.file?.filter( - (haveDesc) => desc.name === haveDesc.name - ).length === 0 - ) { - this.discovery.files?.file.push(desc); - } - descriptor.dependencies?.forEach((dep) => { - this.addDescriptor(dep); - }); - } - - private rpcHandler( - keyed: boolean, - route: string, - handler: Function - ): { - descriptor: MethodDescriptorProto1; - method: GrpcServiceMethod; - } { - const descriptor = createRpcMethodDescriptor(route); + computeDiscovery(): discovery.Deployment { + const components = [...this.componenets.values()].map((c) => c.discovery()); - const localMethod = (instance: unknown, input: RpcRequest) => { - if (keyed) { - return dispatchKeyedRpcHandler( - useKeyedContext(instance), - input, - handler - ); - } else { - return dispatchUnkeyedRpcHandler(useContext(instance), input, handler); - } + const deployment: discovery.Deployment = { + protocolMode: discovery.ProtocolMode.BIDI_STREAM, + minProtocolVersion: 1, + maxProtocolVersion: 2, + components, }; - const decoder = RpcRequest.decode; - const encoder = (message: RpcResponse) => - RpcResponse.encode({ - response: jsonSafeAny("", message.response), - }).finish(); - - const method = new GrpcServiceMethod( - route, - route, - keyed, - localMethod, - decoder, - encoder - ); - - return { - descriptor: descriptor, - method: method as GrpcServiceMethod, - }; - } - - stringKeyedEventHandler( - keyed: boolean, - route: string, - handler: Function - ): { - descriptor: MethodDescriptorProto1; - method: GrpcServiceMethod; - } { - if (!keyed) { - // TODO: support unkeyed rpc event handler - throw new TerminalError("Unkeyed Event handlers are not yet supported."); - } - const descriptor = createStringKeyedMethodDescriptor(route); - const localMethod = (instance: unknown, input: KeyedEvent) => { - return dispatchKeyedEventHandler( - useKeyedContext(instance), - input, - handler - ); - }; - - const decoder = KeyedEvent.decode; - const encoder = (message: Empty) => Empty.encode(message).finish(); - - const method = new GrpcServiceMethod( - route, - route, - keyed, - localMethod, - decoder, - encoder - ); - - return { - descriptor, - method: method as GrpcServiceMethod, - }; + return deployment; } - private bindRpcService(name: string, router: RpcRouter, keyed: boolean) { - if (name === undefined || router === undefined || keyed === undefined) { - throw new Error("incomplete arguments: (name, router, keyed)"); - } - if (!(typeof name === "string") || name.length === 0) { - throw new Error("service name must be a non-empty string"); - } + private bindServiceComponent(name: string, router: RpcRouter) { if (name.indexOf("/") !== -1) { throw new Error("service name must not contain any slash '/'"); } - - const lastDot = name.indexOf("."); - const serviceName = lastDot === -1 ? name : name.substring(lastDot + 1); - const servicePackage = name.substring( - 0, - name.length - serviceName.length - 1 - ); - - const desc = this.dynrpcDescriptor; - const serviceGrpcSpec = keyed - ? pushKeyedService(desc, name) - : pushUnKeyedService(desc, name); + const componenet = new ServiceComponent(name); for (const [route, handler] of Object.entries(router)) { - let registration: { - descriptor: MethodDescriptorProto1; - method: GrpcServiceMethod; - }; - - if (isEventHandler(handler)) { - const theHandler = handler.handler; - registration = this.stringKeyedEventHandler(keyed, route, theHandler); - } else { - registration = this.rpcHandler(keyed, route, handler); - } - serviceGrpcSpec.method.push(registration.descriptor); - const url = `/invoke/${name}/${route}`; - this.methods[url] = new HostedGrpcServiceMethod( - {}, // we don't actually execute on any class instance - servicePackage, - serviceName, - registration.method - ) as HostedGrpcServiceMethod; - - rlog.info( - `Binding: ${url} -> ${JSON.stringify(registration.method, null, "\t")}` - ); - } - - // since we modified this descriptor, we need to remove it in case it was added before, - // so that the modified version is processed and added again - const filteredFiles = this.discovery.files?.file?.filter( - (haveDesc) => desc.fileDescriptor.name !== haveDesc.name - ); - if (this.discovery.files !== undefined && filteredFiles !== undefined) { - this.discovery.files.file = filteredFiles; - } - - this.addDescriptor(desc); - this.discovery.services?.push(name); - } - - methodByUrl( - url: string | undefined | null - ): HostedGrpcServiceMethod | undefined { - if (url == undefined || url === null) { - return undefined; + componenet.add({ + name: route, + /* eslint-disable @typescript-eslint/no-explicit-any */ + fn: handler as ServiceHandlerFunction, + }); } - return this.methods[url] as HostedGrpcServiceMethod; - } -} -/* eslint-disable @typescript-eslint/no-explicit-any */ -function indexProperties(instance: any): Map { - const names = new Map(); - while ( - instance !== null && - instance !== undefined && - instance !== Object.prototype - ) { - for (const property of Object.getOwnPropertyNames(instance)) { - names.set(property.toLowerCase(), property); - } - instance = Object.getPrototypeOf(instance); + this.addComponenet(componenet); } - return names; -} - -// Given: -// * an instance of a class that implements a gRPC TypeScript interface, -// as generated by our protoc plugin, this method -// * The ProtobufFileDescriptor as generated by the protobuf plugin -// * and the gRPC service name -// -// Return a GrpcService definition, as defined above. -// -// For example (see first: example.proto and example.ts): -// -// > parse(example.protoMetaData, "Greeter", new GreeterService()) -// -// produces ~ -// -// serviceName: 'Greeter', -// instance: GreeterService {}, -// methods: { -// multiword: { -// localName: 'multiWord', -// fn: [Function: multiWord], -// inputType: [Object], -// outputType: [Object] -// }, -// greet: { -// localName: 'greet', -// fn: [Function: greet], -// inputType: [Object], -// outputType: [Object] -// } -// } -//} -// -/* eslint-disable @typescript-eslint/no-explicit-any */ -export function parseService( - meta: ProtoMetadata, - serviceName: string, - instance: any -) { - const svcMethods: Array> = []; - const service_type = - meta.options?.services?.[serviceName].options?.["service_type"]; - const keyed = service_type !== ServiceType.UNKEYED; - - // index all the existing properties that `instance` has. - // we index them by the lower case represention. - const names = indexProperties(instance); - for (const serviceDescriptor of meta.fileDescriptor.service) { - if (serviceName !== serviceDescriptor.name) { - continue; + private bindVirtualObjectComponent(name: string, router: RpcRouter) { + if (name.indexOf("/") !== -1) { + throw new Error("service name must not contain any slash '/'"); } - for (const methodDescriptor of serviceDescriptor.method) { - const lowercaseName = methodDescriptor.name.toLowerCase(); - const localName = names.get(lowercaseName); - if (localName === undefined || localName === null) { - throw new Error(`unimplemented method ${methodDescriptor.name}`); - } - const fn = instance[localName]; - if (typeof fn !== "function") { - throw new Error( - `A property ${localName} exists, which coresponds to a gRPC service named ${methodDescriptor.name}, but that property is not a function.` - ); - } - const localMethod = async (instance: unknown, input: unknown) => { - return await fn.call(instance, input); - }; - let inputMessage = meta.references[methodDescriptor.inputType]; - // If the input message type is not defined by the proto files of the service but by a dependency (e.g. BoolValue, Empty, etc) - // then we need to look for the encoders and decoders in the dependencies. - if (inputMessage === undefined) { - meta.dependencies?.forEach((dep) => { - if (dep.references[methodDescriptor.inputType] !== undefined) { - inputMessage = dep.references[methodDescriptor.inputType]; - } - }); - } - let outputMessage = meta.references[methodDescriptor.outputType]; - // If the output message type is not defined by use but by the proto files of the service (e.g. BoolValue, Empty, etc) - // then we need to look for the encoders and decoders in the dependencies. - if (outputMessage === undefined) { - meta.dependencies?.forEach((dep) => { - if (dep.references[methodDescriptor.outputType] !== undefined) { - outputMessage = dep.references[methodDescriptor.outputType]; - } - }); - } + const component = new VritualObjectComponent(name); - const decoder = (buffer: Uint8Array) => inputMessage.decode(buffer); - const encoder = (message: unknown) => - outputMessage.encode(message).finish(); - svcMethods.push( - new GrpcServiceMethod( - methodDescriptor.name, - localName, - keyed, - localMethod, - decoder, - encoder - ) - ); + for (const [route, handler] of Object.entries(router)) { + component.add({ + name: route, + fn: handler as VirtualObjectHandlerFunction, + }); } - return new GrpcService( - serviceName, - meta.fileDescriptor.package, - instance, - svcMethods - ); + this.addComponenet(component); } - throw new Error(`Unable to find a service ${serviceName}.`); } export type RpcRouter = { [key: string]: Function; }; - -async function dispatchKeyedRpcHandler( - ctx: KeyedContext, - req: RpcRequest, - handler: Function -): Promise { - const { key, request } = verifyAssumptions(true, req); - if (typeof key !== "string" || key.length === 0) { - // we throw a terminal error here, because this cannot be patched by updating code: - // if the request is wrong (missing a key), the request can never make it - throw new TerminalError( - "Keyed handlers must recieve a non null or empty string key" - ); - } - const jsResult = (await handler(ctx, key, request)) as any; - return RpcResponse.create({ response: jsResult }); -} - -async function dispatchUnkeyedRpcHandler( - ctx: Context, - req: RpcRequest, - handler: Function -): Promise { - const { request } = verifyAssumptions(false, req); - const result = await handler(ctx, request); - return RpcResponse.create({ response: result }); -} - -async function dispatchKeyedEventHandler( - ctx: KeyedContext, - req: KeyedEvent, - handler: Function -): Promise { - const key = req.key; - if (key === null || key === undefined || key.length === 0) { - // we throw a terminal error here, because this cannot be patched by updating code: - // if the request is wrong (missing a key), the request can never make it - throw new TerminalError( - "Keyed handlers must receive a non null or empty string key" - ); - } - const jsEvent = new Event(key, req.payload, req.attributes); - await handler(ctx, jsEvent); - return Empty.create({}); -} - -function copyProtoMetadata( - original: RpcServiceProtoMetadata -): RpcServiceProtoMetadata { - // duplicate the file descriptor. shallow, because we only need to - // change one top-level field: service[] - const fileDescriptorCopy = { - ...original.fileDescriptor, - } as FileDescriptorProto1; - fileDescriptorCopy.service = []; - - let options = original.options; - if (options !== undefined) { - options = { - ...original.options, - }; - options.services = {}; - } - - return { - fileDescriptor: fileDescriptorCopy, - references: original.references, - dependencies: original.dependencies, - options: options, - }; -} - -function pushKeyedService( - desc: RpcServiceProtoMetadata, - newName: string -): ServiceDescriptorProto1 { - const service = { - ...rpcServiceProtoMetadata.fileDescriptor.service[0], - } as ServiceDescriptorProto1; - service.name = newName; - service.method = []; - desc.fileDescriptor.service.push(service); - - const serviceOptions = - rpcServiceProtoMetadata.options?.services?.["RpcEndpoint"]; - if (serviceOptions === undefined) { - throw new Error( - "Missing service options in original RpcEndpoint proto descriptor" - ); - } - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - desc.options!.services![newName] = serviceOptions; - - return service; -} - -function pushUnKeyedService( - desc: RpcServiceProtoMetadata, - newName: string -): ServiceDescriptorProto1 { - const service = { - ...rpcServiceProtoMetadata.fileDescriptor.service[1], - } as ServiceDescriptorProto1; - service.name = newName; - service.method = []; - desc.fileDescriptor.service.push(service); - - const serviceOptions = - rpcServiceProtoMetadata.options?.services?.["UnkeyedRpcEndpoint"]; - if (serviceOptions === undefined) { - throw new Error( - "Missing service options in original UnkeyedRpcEndpoint proto descriptor" - ); - } - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - desc.options!.services![newName] = serviceOptions; - - return service; -} - -function createRpcMethodDescriptor(methodName: string): MethodDescriptorProto1 { - const desc = { - ...rpcServiceProtoMetadata.fileDescriptor.service[0].method[0], - } as MethodDescriptorProto1; - desc.name = methodName; - return desc; -} - -function createStringKeyedMethodDescriptor( - methodName: string -): MethodDescriptorProto1 { - const desc = { - ...rpcServiceProtoMetadata.fileDescriptor.service[0].method[1], - } as MethodDescriptorProto1; - desc.name = methodName; - return desc; -} diff --git a/src/endpoint/http2_handler.ts b/src/endpoint/http2_handler.ts index 0a656fe2..4fd0370d 100644 --- a/src/endpoint/http2_handler.ts +++ b/src/endpoint/http2_handler.ts @@ -13,26 +13,21 @@ import stream from "stream"; import { pipeline, finished } from "stream/promises"; import http2, { Http2ServerRequest, Http2ServerResponse } from "http2"; import { parse as urlparse, Url } from "url"; -import { - ProtocolMode, - ServiceDiscoveryResponse, -} from "../generated/proto/discovery"; import { EndpointImpl } from "./endpoint_impl"; import { RestateHttp2Connection } from "../connection/http_connection"; -import { HostedGrpcServiceMethod } from "../types/grpc"; import { ensureError } from "../types/errors"; import { InvocationBuilder } from "../invocation"; import { StateMachine } from "../state_machine"; import { rlog } from "../logger"; +import { + ComponenetHandler, + VirtualObjectHandler, + parseUrlComponenets, +} from "../types/componenets"; +import { Deployment, ProtocolMode } from "../types/discovery"; export class Http2Handler { - private readonly discoveryResponse: ServiceDiscoveryResponse; - constructor(private readonly endpoint: EndpointImpl) { - this.discoveryResponse = ServiceDiscoveryResponse.fromPartial({ - ...this.endpoint.discovery, - protocolMode: ProtocolMode.BIDI_STREAM, - }); - } + constructor(private readonly endpoint: EndpointImpl) {} acceptConnection( request: Http2ServerRequest, @@ -52,49 +47,41 @@ export class Http2Handler { }); } - private async handleConnection( + private handleConnection( url: Url, stream: http2.ServerHttp2Stream ): Promise { - const method = this.endpoint.methodByUrl(url.path); - - if (method !== undefined) { - // valid connection, let's dispatch the invocation - stream.respond({ - "content-type": "application/restate", - ":status": 200, - }); - - const restateStream = RestateHttp2Connection.from(stream); - await handleInvocation(method, restateStream); - return; + if (url.path === "/restate/discover") { + return respondDiscovery(this.endpoint.computeDiscovery(), stream); } - - // no method under that name. might be a discovery request - if (url.path == "/discover") { - rlog.info( - "Answering discovery request. Announcing services: " + - JSON.stringify(this.discoveryResponse.services) - ); - await respondDiscovery(this.discoveryResponse, stream); - return; + const urlComponenets = parseUrlComponenets(url.path ?? undefined); + if (!urlComponenets) { + return respondNotFound(stream); } - - // no discovery, so unknown method: 404 - rlog.error(`No service and function found for URL ${url.path}`); - await respondNotFound(stream); + const componenet = this.endpoint.componenetByName( + urlComponenets.componenetName + ); + if (!componenet) { + return respondNotFound(stream); + } + const handler = componenet.handlerMatching(urlComponenets); + if (!handler) { + return respondNotFound(stream); + } + const restateStream = RestateHttp2Connection.from(stream); + return handleInvocation(handler, restateStream); } } async function respondDiscovery( - response: ServiceDiscoveryResponse, + response: Deployment, http2Stream: http2.ServerHttp2Stream ) { - const responseData = ServiceDiscoveryResponse.encode(response).finish(); + const responseData = JSON.stringify(response); http2Stream.respond({ ":status": 200, - "content-type": "application/proto", + "content-type": "application/restate", }); await pipeline(stream.Readable.from(responseData), http2Stream, { @@ -104,19 +91,19 @@ async function respondDiscovery( async function respondNotFound(stream: http2.ServerHttp2Stream) { stream.respond({ - "content-type": "application/restate", + "content-type": "application/json", ":status": 404, }); stream.end(); await finished(stream); } -async function handleInvocation( - func: HostedGrpcServiceMethod, +async function handleInvocation( + handler: ComponenetHandler, connection: RestateHttp2Connection ) { // step 1: collect all journal events - const journalBuilder = new InvocationBuilder(func); + const journalBuilder = new InvocationBuilder(handler); connection.pipeToConsumer(journalBuilder); try { await journalBuilder.completion(); @@ -127,11 +114,11 @@ async function handleInvocation( // step 2: create the state machine const invocation = journalBuilder.build(); - const stateMachine = new StateMachine( + const stateMachine = new StateMachine( connection, invocation, ProtocolMode.BIDI_STREAM, - func.method.keyedContext, + handler instanceof VirtualObjectHandler, invocation.inferLoggerContext() ); connection.pipeToConsumer(stateMachine); diff --git a/src/endpoint/lambda_handler.ts b/src/endpoint/lambda_handler.ts index 21e969a1..64462ef5 100644 --- a/src/endpoint/lambda_handler.ts +++ b/src/endpoint/lambda_handler.ts @@ -17,10 +17,6 @@ import { APIGatewayProxyResultV2, Context, } from "aws-lambda"; -import { - ProtocolMode, - ServiceDiscoveryResponse, -} from "../generated/proto/discovery"; import { EndpointImpl } from "./endpoint_impl"; import { LambdaConnection } from "../connection/lambda_connection"; import { InvocationBuilder } from "../invocation"; @@ -29,15 +25,14 @@ import { Message } from "../types/types"; import { StateMachine } from "../state_machine"; import { ensureError } from "../types/errors"; import { OUTPUT_STREAM_ENTRY_MESSAGE_TYPE } from "../types/protocol"; +import { ProtocolMode } from "../types/discovery"; +import { + VirtualObjectHandler, + parseUrlComponenets, +} from "../types/componenets"; export class LambdaHandler { - private readonly discoveryResponse: ServiceDiscoveryResponse; - constructor(private readonly endpoint: EndpointImpl) { - this.discoveryResponse = ServiceDiscoveryResponse.fromPartial({ - ...this.endpoint.discovery, - protocolMode: ProtocolMode.REQUEST_RESPONSE, - }); - } + constructor(private readonly endpoint: EndpointImpl) {} // -------------------------------------------------------------------------- @@ -58,22 +53,21 @@ export class LambdaHandler { } const pathSegments = path.split("/"); - // API Gateway can add a prefix to the path based on the name of the Lambda function and deployment stage - // (e.g. /default) - // So we only check the ending of the path on correctness. - // Logic: - // 1. Check whether there are at least three segments in the path and whether the third-last one is "invoke". - // If that is the case, treat it as an invocation. - // 2. See if the last one is "discover", answer with discovery. - // 3. Else report "invalid path". + if ( + pathSegments.length >= 2 && + pathSegments[pathSegments.length - 1] === "discover" && + pathSegments[pathSegments.length - 2] === "restate" + ) { + return this.handleDiscovery(); + } + + // TODO: FIXME verify how does this work on Lambda if ( pathSegments.length >= 3 && pathSegments[pathSegments.length - 3] === "invoke" ) { const url = "/" + pathSegments.slice(-3).join("/"); return await this.handleInvoke(url, event, context); - } else if (pathSegments[pathSegments.length - 1] === "discover") { - return this.handleDiscovery(); } else { const msg = `Invalid path: path doesn't end in /invoke/SvcName/MethodName and also not in /discover: ${path}`; rlog.trace(msg); @@ -87,7 +81,10 @@ export class LambdaHandler { context: Context ): Promise { try { - const method = this.endpoint.methodByUrl(url); + const parsedUrl = parseUrlComponenets(url); + const method = this.endpoint.componenetByName(parsedUrl?.componenetName!); + const handler = method?.handlerMatching(parsedUrl!); + if (event.body == null) { throw new Error("The incoming message body was null"); } @@ -106,7 +103,7 @@ export class LambdaHandler { // build the previous journal from the events let decodedEntries: Message[] | null = decodeLambdaBody(event.body); - const journalBuilder = new InvocationBuilder(method); + const journalBuilder = new InvocationBuilder(handler!); decodedEntries.forEach((e: Message) => journalBuilder.handleMessage(e)); const alreadyCompleted = decodedEntries.find( @@ -121,7 +118,7 @@ export class LambdaHandler { connection, invocation, ProtocolMode.REQUEST_RESPONSE, - method.method.keyedContext, + handler instanceof VirtualObjectHandler, invocation.inferLoggerContext({ AWSRequestId: context.awsRequestId, }) @@ -146,20 +143,13 @@ export class LambdaHandler { } private handleDiscovery(): APIGatewayProxyResult | APIGatewayProxyResultV2 { - // return discovery information - rlog.info( - "Answering discovery request. Announcing services: " + - JSON.stringify(this.discoveryResponse.services) - ); return { headers: { - "content-type": "application/proto", + "content-type": "application/restate", }, statusCode: 200, - isBase64Encoded: true, - body: encodeResponse( - ServiceDiscoveryResponse.encode(this.discoveryResponse).finish() - ), + isBase64Encoded: false, + body: JSON.stringify(this.endpoint.computeDiscovery()), }; } diff --git a/src/invocation.ts b/src/invocation.ts index 0175fcd8..220d21c3 100644 --- a/src/invocation.ts +++ b/src/invocation.ts @@ -12,7 +12,6 @@ /*eslint-disable @typescript-eslint/no-non-null-assertion*/ import { Message } from "./types/types"; -import { HostedGrpcServiceMethod } from "./types/grpc"; import { Failure, PollInputStreamEntryMessage, @@ -28,6 +27,7 @@ import { LocalStateStore } from "./local_state_store"; import { ensureError } from "./types/errors"; import { LoggerContext } from "./logger"; import { CompletablePromise } from "./utils/promises"; +import { ComponenetHandler } from "./types/componenets"; enum State { ExpectingStart = 0, @@ -40,7 +40,7 @@ type InvocationValue = | { kind: "value"; value: Buffer } | { kind: "failure"; failure: Failure }; -export class InvocationBuilder implements RestateStreamConsumer { +export class InvocationBuilder implements RestateStreamConsumer { private readonly complete = new CompletablePromise(); private state: State = State.ExpectingStart; @@ -52,8 +52,9 @@ export class InvocationBuilder implements RestateStreamConsumer { private invocationValue?: InvocationValue = undefined; private nbEntriesToReplay?: number = undefined; private localStateStore?: LocalStateStore; + private userKey?: string; - constructor(private readonly method: HostedGrpcServiceMethod) {} + constructor(private readonly component: ComponenetHandler) {} public handleMessage(m: Message): boolean { try { @@ -138,15 +139,16 @@ export class InvocationBuilder implements RestateStreamConsumer { return this.complete.promise; } - private handleStartMessage(m: StartMessage): InvocationBuilder { + private handleStartMessage(m: StartMessage): InvocationBuilder { this.nbEntriesToReplay = m.knownEntries; this.id = m.id; this.debugId = m.debugId; this.localStateStore = new LocalStateStore(m.partialState, m.stateMap); + this.userKey = m.key; return this; } - private addReplayEntry(m: Message): InvocationBuilder { + private addReplayEntry(m: Message): InvocationBuilder { // Will be retrieved when the user code reaches this point this.replayEntries.set(this.runtimeReplayIndex, m); this.incrementRuntimeReplayIndex(); @@ -161,33 +163,35 @@ export class InvocationBuilder implements RestateStreamConsumer { return this.state === State.Complete; } - public build(): Invocation { + public build(): Invocation { if (!this.isComplete()) { throw new Error( `Cannot build invocation. Not all data present: ${JSON.stringify(this)}` ); } return new Invocation( - this.method!, + this.component, this.id!, this.debugId!, this.nbEntriesToReplay!, this.replayEntries!, this.invocationValue!, - this.localStateStore! + this.localStateStore!, + this.userKey ); } } -export class Invocation { +export class Invocation { constructor( - public readonly method: HostedGrpcServiceMethod, + public readonly handler: ComponenetHandler, public readonly id: Buffer, public readonly debugId: string, public readonly nbEntriesToReplay: number, public readonly replayEntries: Map, public readonly invocationValue: InvocationValue, - public readonly localStateStore: LocalStateStore + public readonly localStateStore: LocalStateStore, + public readonly userKey?: string ) {} public inferLoggerContext(additionalContext?: { @@ -195,9 +199,9 @@ export class Invocation { }): LoggerContext { return new LoggerContext( this.debugId, - this.method.pkg, - this.method.service, - this.method.method.name, + "", + this.handler.name(), + this.handler.componenet().name(), additionalContext ); } diff --git a/src/journal.ts b/src/journal.ts index e1c57454..12a09499 100644 --- a/src/journal.ts +++ b/src/journal.ts @@ -50,7 +50,7 @@ import { CompletablePromise } from "./utils/promises"; const RESOLVED = Promise.resolve(undefined); -export class Journal { +export class Journal { private state = NewExecutionState.REPLAYING; private userCodeJournalIndex = 0; @@ -59,7 +59,7 @@ export class Journal { // 0 = root promise of the method invocation private pendingJournalEntries = new Map(); - constructor(readonly invocation: Invocation) { + constructor(readonly invocation: Invocation) { const inputMessage = invocation.replayEntries.get(0); if ( !inputMessage || diff --git a/src/state_machine.ts b/src/state_machine.ts index 8e8261e8..011fe2fc 100644 --- a/src/state_machine.ts +++ b/src/state_machine.ts @@ -12,7 +12,6 @@ import * as p from "./types/protocol"; import { ContextImpl } from "./context_impl"; import { Connection, RestateStreamConsumer } from "./connection/connection"; -import { ProtocolMode } from "./generated/proto/discovery"; import { Message } from "./types/types"; import { createStateMachineConsole, @@ -54,9 +53,10 @@ import { PromiseType, } from "./promise_combinator_tracker"; import { CombinatorEntryMessage } from "./generated/proto/javascript"; +import { ProtocolMode } from "./types/discovery"; -export class StateMachine implements RestateStreamConsumer { - private journal: Journal; +export class StateMachine implements RestateStreamConsumer { + private journal: Journal; private restateContext: ContextImpl; private readonly invocationComplete = new CompletablePromise(); @@ -84,7 +84,7 @@ export class StateMachine implements RestateStreamConsumer { constructor( private readonly connection: Connection, - private readonly invocation: Invocation, + private readonly invocation: Invocation, private readonly protocolMode: ProtocolMode, keyedContext: boolean, loggerContext: LoggerContext, @@ -95,10 +95,11 @@ export class StateMachine implements RestateStreamConsumer { this.restateContext = new ContextImpl( this.invocation.id, - this.invocation.method.service, + this.invocation.handler.componenet().name(), // The console exposed by RestateContext filters logs in replay, while the internal one is based on the ENV variables. createRestateConsole(loggerContext, () => !this.journal.isReplaying()), keyedContext, + invocation.userKey, this ); this.journal = new Journal(this.invocation); @@ -339,7 +340,7 @@ export class StateMachine implements RestateStreamConsumer { switch (this.invocation.invocationValue.kind) { case "value": - resultBytes = this.invocation.method.invoke( + resultBytes = this.invocation.handler.invoke( this.restateContext, this.invocation.invocationValue.value ); diff --git a/src/types/componenets.ts b/src/types/componenets.ts new file mode 100644 index 00000000..fc9ab5fc --- /dev/null +++ b/src/types/componenets.ts @@ -0,0 +1,221 @@ +/* + * Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH + * + * This file is part of the Restate SDK for Node.js/TypeScript, + * which is released under the MIT license. + * + * You can find a copy of the license in file LICENSE in the root + * directory of this repository or package, or at + * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE + */ + +/* eslint-disable @typescript-eslint/ban-types */ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import { Context, KeyedContext } from "../context"; +import * as d from "./discovery"; +import { ContextImpl } from "../context_impl"; + +// +// Interfaces +// + +export interface Component { + name(): string; + handlerMatching(url: UrlPathComponenets): ComponenetHandler | undefined; + discovery(): d.Component; +} + +export interface ComponenetHandler { + name(): string; + componenet(): Component; + invoke(context: ContextImpl, input: Uint8Array): Promise; +} + +// +// Service +// + +export type ServiceHandlerFunction = ( + ctx: Context, + param: I +) => Promise; + +export type ServiceHandlerOpts = { + name: string; + fn: ServiceHandlerFunction; +}; + +export class ServiceComponent implements Component { + private readonly handlers: Map = new Map(); + + constructor(private readonly componentName: string) {} + + name(): string { + return this.componentName; + } + + add(opts: ServiceHandlerOpts) { + const c = new ServiceHandler(opts, this); + this.handlers.set(opts.name, c); + } + + discovery(): d.Component { + const handlers: d.Handler[] = [...this.handlers.keys()].map((name) => { + return { + name, + }; + }); + + return { + fullyQualifiedComponentName: this.componentName, + componentType: d.ComponentType.SERVICE, + handlers, + }; + } + + handlerMatching(url: UrlPathComponenets): ComponenetHandler | undefined { + return this.handlers.get(url.componenetName); + } +} + +export class ServiceHandler implements ComponenetHandler { + private readonly handlerName: string; + private readonly parent: ServiceComponent; + private readonly fn: ServiceHandlerFunction; + + constructor(opts: ServiceHandlerOpts, parent: ServiceComponent) { + this.handlerName = opts.name; + this.parent = parent; + this.fn = opts.fn; + } + + invoke(context: ContextImpl, input: Uint8Array): Promise { + const param = JSON.parse(input.toString()); + return this.fn(context, param); + } + + name(): string { + return this.handlerName; + } + componenet(): Component { + return this.parent; + } +} + +// +// Virtual Object +// + +export type VirtualObjectHandlerFunction = ( + ctx: KeyedContext, + key: string, + param: I +) => Promise; + +export type VirtualObjectHandlerOpts = { + name: string; + fn: VirtualObjectHandlerFunction; +}; + +export class VritualObjectComponent implements Component { + private readonly opts: Map> = + new Map(); + + constructor(public readonly componenetName: string) {} + + name(): string { + return this.componenetName; + } + + add(opts: VirtualObjectHandlerOpts) { + this.opts.set(opts.name, opts as VirtualObjectHandlerOpts); + } + + discovery(): d.Component { + const handlers: d.Handler[] = [...this.opts.keys()].map((name) => { + return { + name, + }; + }); + + return { + fullyQualifiedComponentName: this.componenetName, + componentType: d.ComponentType.VIRTUAL_OBJECT, + handlers, + }; + } + + handlerMatching(url: UrlPathComponenets): ComponenetHandler | undefined { + if (!url.key) { + return undefined; + } + const opts = this.opts.get(url.handlerName); + if (!opts) { + return undefined; + } + return new VirtualObjectHandler(url.handlerName, url.key, this, opts); + } +} + +export class VirtualObjectHandler implements ComponenetHandler { + constructor( + private readonly componentName: string, + private readonly key: string, + private readonly parent: VritualObjectComponent, + private readonly opts: VirtualObjectHandlerOpts + ) {} + + name(): string { + return this.componentName; + } + componenet(): Component { + return this.parent; + } + + invoke(context: ContextImpl, input: Uint8Array): Promise { + const param = JSON.parse(input.toString()); + return this.opts.fn(context, context.key(), param); + } +} + +export type UrlPathComponenets = { + componenetName: string; + handlerName: string; + key?: string; +}; + +export function parseUrlComponenets( + urlPath?: string +): UrlPathComponenets | undefined { + let path = urlPath ?? "/"; + if (path.charAt(0) !== "/") { + path = +"/" + path; + } + const fragments = path.split("/", 4); + // + // case 1: /foo/bar => ["", "foo" , "bar"] + // case 2: /foo/bob/bar => ["", "foo", "bob", "bar"] + // + if (fragments.length === 3) { + /* eslint-disable-next-line @typescript-eslint/no-unused-vars */ + const [_, componenetName, handlerName] = fragments; + return { + componenetName, + handlerName, + }; + } + if (fragments.length === 4) { + /* eslint-disable-next-line @typescript-eslint/no-unused-vars */ + const [_, componenetName, key, handlerName] = fragments; + return { + componenetName, + handlerName, + key, + }; + } + // + // an exceptional case, the runtime had invoked us with a strange url + // + return undefined; +} diff --git a/src/types/discovery.ts b/src/types/discovery.ts new file mode 100644 index 00000000..22a90e51 --- /dev/null +++ b/src/types/discovery.ts @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH + * + * This file is part of the Restate SDK for Node.js/TypeScript, + * which is released under the MIT license. + * + * You can find a copy of the license in file LICENSE in the root + * directory of this repository or package, or at + * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE + */ + +export enum ProtocolMode { + BIDI_STREAM = "BIDI_STREAM", + REQUEST_RESPONSE = "REQUEST_RESPONSE", +} + +export enum ComponentType { + VIRTUAL_OBJECT = "VIRTUAL_OBJECT", + SERVICE = "SERVICE", +} + +export interface Handler { + name: string; + /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ + inputSchema?: any; + /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ + outputSchema?: any; +} + +export interface Component { + fullyQualifiedComponentName: string; + componentType: ComponentType; + handlers: Handler[]; +} + +export interface Deployment { + protocolMode: ProtocolMode; + minProtocolVersion: number; + maxProtocolVersion: number; + components: Component[]; +} diff --git a/src/types/grpc.ts b/src/types/grpc.ts deleted file mode 100644 index 33a0725c..00000000 --- a/src/types/grpc.ts +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH - * - * This file is part of the Restate SDK for Node.js/TypeScript, - * which is released under the MIT license. - * - * You can find a copy of the license in file LICENSE in the root - * directory of this repository or package, or at - * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE - */ - -import { Context } from "../context"; -import { FileDescriptorProto } from "ts-proto-descriptors"; - -export class GrpcServiceMethod { - constructor( - readonly name: string, // the gRPC name as defined in the .proto file - readonly localName: string, // the method name as defined in the class. - readonly keyedContext: boolean, // If the method expects a keyed context - readonly localFn: (instance: unknown, input: I) => Promise, // the actual function - readonly inputDecoder: (buf: Uint8Array) => I, // the protobuf decoder - readonly outputEncoder: (output: O) => Uint8Array // protobuf encoder - ) {} -} - -export class GrpcService { - constructor( - readonly name: string, - readonly packge: string, - readonly impl: object, - readonly methods: Array> - ) {} -} - -export class HostedGrpcServiceMethod { - constructor( - readonly instance: unknown, - readonly pkg: string, - readonly service: string, - readonly method: GrpcServiceMethod - ) {} - - // The end of an invoke is either a response (Uint8Array) or a SuspensionMessage - async invoke(context: Context, inBytes: Uint8Array): Promise { - const instanceWithContext = setContext(this.instance, context); - const input = this.method.inputDecoder(inBytes); - const result: O = await this.method.localFn(instanceWithContext, input); - return this.method.outputEncoder(result); - } -} - -function setContext(instance: T, context: Context): T { - // creates a *new*, per call object that shares all the properties that @instance has - // except '$$restate' which is a unique, per call pointer to a restate context. - // - // The following line create a new object, that its prototype is @instance. - // and that object has a $$restate property. - const wrapper = Object.create(instance as object, { - $$restate: { value: context }, - }); - return wrapper as T; -} - -// -// The following definitions are equivalent to the ones -// generated by the protoc ts plugin. -// we use them to traverse the FileDescriptor -// -/* eslint-disable @typescript-eslint/no-explicit-any */ -type ProtoMetaMessageOptions = { - options?: { [key: string]: any }; - fields?: { [key: string]: { [key: string]: any } }; - oneof?: { [key: string]: { [key: string]: any } }; - nested?: { [key: string]: ProtoMetaMessageOptions }; -}; - -export interface ProtoMetadata { - fileDescriptor: FileDescriptorProto; - references: { [key: string]: any }; - dependencies?: ProtoMetadata[]; - options?: { - options?: { [key: string]: any }; - services?: { - [key: string]: { - options?: { [key: string]: any }; - methods?: { [key: string]: { [key: string]: any } }; - }; - }; - messages?: { [key: string]: ProtoMetaMessageOptions }; - enums?: { - [key: string]: { - options?: { [key: string]: any }; - values?: { [key: string]: { [key: string]: any } }; - }; - }; - }; -} diff --git a/src/utils/assumptions.ts b/src/utils/assumptions.ts deleted file mode 100644 index f11bf2de..00000000 --- a/src/utils/assumptions.ts +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH - * - * This file is part of the Restate SDK for Node.js/TypeScript, - * which is released under the MIT license. - * - * You can find a copy of the license in file LICENSE in the root - * directory of this repository or package, or at - * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE - */ - -import { RpcRequest } from "../generated/proto/dynrpc"; -import { TerminalError } from "../types/errors"; - -const ASSUME_UNKEYED_SINCE_FIRST_PARAM_NOT_STRING = 1; -const ASSUME_UNKEYED_SINCE_ZERO_ARGS = 2; -const ASSUME_KEYED_SINCE_TWO_ARGS_STR_AND_ANY = 3; -const ASSUME_EITHER_KEYED_OR_UNKEYED_ONE_STR_ARG = 4; - -export const requestFromArgs = (args: unknown[]): RpcRequest => { - switch (args.length) { - case 0: { - return RpcRequest.create({ - senderAssumes: ASSUME_UNKEYED_SINCE_ZERO_ARGS, - }); - } - case 1: { - if (typeof args[0] === "string") { - return RpcRequest.create({ - key: args[0], - senderAssumes: ASSUME_EITHER_KEYED_OR_UNKEYED_ONE_STR_ARG, - }); - } else { - return RpcRequest.create({ - request: args[0], - senderAssumes: ASSUME_UNKEYED_SINCE_FIRST_PARAM_NOT_STRING, - }); - } - } - case 2: { - if (typeof args[0] !== "string") { - throw new TerminalError( - `Two argument handlers are only possible for keyed handlers. Where the first argument must be of type 'string'.` - ); - } - return RpcRequest.create({ - key: args[0], - request: args[1], - senderAssumes: ASSUME_KEYED_SINCE_TWO_ARGS_STR_AND_ANY, - }); - } - default: { - throw new TerminalError("wrong number of arguments " + args.length); - } - } -}; - -/* eslint-disable @typescript-eslint/ban-types, @typescript-eslint/no-explicit-any */ -export type JsType = - | string - | number - | boolean - | Object - | null - | Array - | undefined; -/* eslint-enable @typescript-eslint/ban-types, @typescript-eslint/no-explicit-any */ - -const requireThat = (condition: boolean, errorMessage: string) => { - if (!condition) { - throw new TerminalError(errorMessage); - } -}; - -export const verifyAssumptions = ( - isKeyed: boolean, - request: RpcRequest -): { key?: string; request?: JsType } => { - const assumption = request.senderAssumes ?? 0; - switch (assumption) { - case 0: { - // no assumption: this comes from an ingress. - const hasKeyProperty = - typeof request.key === "string" && request.key.length > 0; - if (isKeyed) { - requireThat( - hasKeyProperty, - "Trying to call a keyed handler with a missing or empty 'key' property." - ); - } else { - requireThat( - !hasKeyProperty, - "Trying to call a an unkeyed handler with a 'key' property. Did you mean using the 'request' property instead?" - ); - } - return { key: request.key, request: request.request }; - } - case ASSUME_UNKEYED_SINCE_FIRST_PARAM_NOT_STRING: { - requireThat( - !isKeyed, - "Trying to call a keyed handler with a missing key. This could happen if the first argument passed is not a 'string'." - ); - return { request: request.request }; - } - case ASSUME_UNKEYED_SINCE_ZERO_ARGS: { - requireThat( - !isKeyed, - "A keyed handler must at least be invoked with a single non empty string argument, that represents the key. 0 arguments given." - ); - return { request: request.request }; - } - case ASSUME_KEYED_SINCE_TWO_ARGS_STR_AND_ANY: { - requireThat( - isKeyed, - "An unkeyed handler must have at most 1 argument. two given." - ); - return { key: request.key, request: request.request }; - } - case ASSUME_EITHER_KEYED_OR_UNKEYED_ONE_STR_ARG: { - if (isKeyed) { - return { key: request.key }; - } - return { request: request.key }; - } - default: { - throw new TerminalError( - `Unknown assumption id ${assumption}. This indicates an incorrect (or involuntary) setting of the assumption property at the ingress request, or an SDK bug.` - ); - } - } -}; diff --git a/src/utils/serde.ts b/src/utils/serde.ts new file mode 100644 index 00000000..250b6c9e --- /dev/null +++ b/src/utils/serde.ts @@ -0,0 +1,14 @@ +import { Buffer } from "node:buffer"; + +/* eslint-disable @typescript-eslint/no-explicit-any */ +export function serializeJson(item: any): Uint8Array { + const str = JSON.stringify(item ?? null); + return Buffer.from(str); +} + +/* eslint-disable @typescript-eslint/no-explicit-any */ +export function deserializeJson(buf: Uint8Array): any { + const b = Buffer.from(buf); + const str = b.toString(); + return JSON.parse(str); +} diff --git a/src/utils/utils.ts b/src/utils/utils.ts index da60713b..c1a2b325 100644 --- a/src/utils/utils.ts +++ b/src/utils/utils.ts @@ -47,30 +47,6 @@ export function jsonDeserialize(json: string): T { ) as T; } -// When using google.protobuf.Value in RPC handler responses, we want to roughly match the behaviour of JSON.stringify -// for example in converting Date objects to a UTC string -export function jsonSafeAny(key: string, value: any): any { - if ( - value !== undefined && - value !== null && - typeof value.toJSON == "function" - ) { - return value.toJSON(key) as any; - } else if (globalThis.Array.isArray(value)) { - // in place replace - value.forEach((_, i) => (value[i] = jsonSafeAny(i.toString(), value[i]))); - return value; - } else if (typeof value === "object") { - Object.keys(value).forEach((key) => { - value[key] = jsonSafeAny(key, value[key]); - }); - return value; - } else { - // primitive that doesn't have a toJSON method, with no children - return value; - } -} - export function formatMessageAsJson(obj: any): string { const newObj = { ...(obj as Record) }; for (const [key, value] of Object.entries(newObj)) { diff --git a/src/workflows/workflow_wrapper_service.ts b/src/workflows/workflow_wrapper_service.ts index 3c2c6dce..041800b1 100644 --- a/src/workflows/workflow_wrapper_service.ts +++ b/src/workflows/workflow_wrapper_service.ts @@ -32,23 +32,25 @@ class SharedContextImpl implements wf.SharedWfContext { get(stateName: string): Promise { return this.ctx - .rpc(this.stateServiceApi) + .object(this.stateServiceApi, this.wfId) .getState(this.wfId, stateName) as Promise; } promise(name: string): wf.DurablePromise { // Create the awakeable to complete const awk = this.ctx.awakeable(); - this.ctx.send(this.stateServiceApi).subscribePromise(this.wfId, { - promiseName: name, - awkId: awk.id, - }); + this.ctx + .objectSend(this.stateServiceApi, this.wfId) + .subscribePromise(this.wfId, { + promiseName: name, + awkId: awk.id, + }); // Prepare implementation of DurablePromise const peek = async (): Promise => { const result = await this.ctx - .rpc(this.stateServiceApi) + .object(this.stateServiceApi, this.wfId) .peekPromise(this.wfId, { promiseName: name }); if (result === null) { @@ -63,17 +65,21 @@ class SharedContextImpl implements wf.SharedWfContext { const resolve = (value: T) => { const currentValue = value === undefined ? null : value; - this.ctx.send(this.stateServiceApi).completePromise(this.wfId, { - promiseName: name, - completion: { value: currentValue }, - }); + this.ctx + .objectSend(this.stateServiceApi, this.wfId) + .completePromise(this.wfId, { + promiseName: name, + completion: { value: currentValue }, + }); }; const reject = (errorMsg: string) => { - this.ctx.send(this.stateServiceApi).completePromise(this.wfId, { - promiseName: name, - completion: { error: errorMsg }, - }); + this.ctx + .objectSend(this.stateServiceApi, this.wfId) + .completePromise(this.wfId, { + promiseName: name, + completion: { error: errorMsg }, + }); }; return Object.defineProperties(awk.promise, { @@ -114,20 +120,26 @@ class ExclusiveContextImpl extends SharedContextImpl implements wf.WfContext { } this.ctx - .send(this.stateServiceApi) + .objectSend(this.stateServiceApi, this.wfId) .setState(this.wfId, { stateName, value }); } clear(stateName: string): void { - this.ctx.send(this.stateServiceApi).clearState(this.wfId, stateName); + this.ctx + .objectSend(this.stateServiceApi, this.wfId) + .clearState(this.wfId, stateName); } stateKeys(): Promise> { - return this.ctx.rpc(this.stateServiceApi).stateKeys(this.wfId); + return this.ctx + .object(this.stateServiceApi, this.wfId) + .stateKeys(this.wfId); } clearAll(): void { - this.ctx.send(this.stateServiceApi).clearAllState(this.wfId); + this.ctx + .objectSend(this.stateServiceApi, this.wfId) + .clearAllState(this.wfId); } sideEffect( @@ -151,17 +163,38 @@ class ExclusiveContextImpl extends SharedContextImpl implements wf.WfContext { return this.ctx.sleep(millis); } - rpc(opts: restate.ServiceApi): restate.Client { - return this.ctx.rpc(opts); + key(): string { + const kctx = this.ctx as restate.KeyedContext; + return kctx.key(); } - send(opts: restate.ServiceApi): restate.SendClient { - return this.ctx.send(opts); + + service(opts: restate.ServiceApi): restate.Client { + return this.ctx.service(opts); + } + object(opts: restate.ServiceApi, key: string): restate.Client { + return this.ctx.object(opts, key); + } + objectSend( + opts: restate.ServiceApi, + key: string + ): restate.SendClient { + return this.ctx.objectSend(opts, key); + } + serviceSend(opts: restate.ServiceApi): restate.SendClient { + return this.ctx.serviceSend(opts); + } + objectSendDelayed( + opts: restate.ServiceApi, + delay: number, + key: string + ): restate.SendClient { + return this.ctx.objectSendDelayed(opts, delay, key); } - sendDelayed( + serviceSendDelayed( opts: restate.ServiceApi, delay: number ): restate.SendClient { - return this.ctx.sendDelayed(opts, delay); + return this.ctx.serviceSendDelayed(opts, delay); } } @@ -182,10 +215,10 @@ export function createWrapperService( checkRequestAndWorkflowId(request); const started = await ctx - .rpc(stateServiceApi) + .object(stateServiceApi, request.workflowId) .startWorkflow(request.workflowId); if (started === wf.WorkflowStartResult.STARTED) { - ctx.send(wrapperServiceApi).run(request); + ctx.service(wrapperServiceApi).run(request); } return started; }, @@ -205,18 +238,22 @@ export function createWrapperService( const result = await workflow.run(wfCtx, request); const resultValue = result !== undefined ? result : {}; await ctx - .rpc(stateServiceApi) + .object(stateServiceApi, request.workflowId) .finishOrFailWorkflow(request.workflowId, { value: resultValue }); return result; } catch (err) { const msg = stringifyError(err); await ctx - .rpc(stateServiceApi) + .object(stateServiceApi, request.workflowId) .finishOrFailWorkflow(request.workflowId, { error: msg }); throw err; } finally { ctx - .sendDelayed(stateServiceApi, DEFAULT_RETENTION_PERIOD) + .objectSendDelayed( + stateServiceApi, + DEFAULT_RETENTION_PERIOD, + request.workflowId + ) .dispose(request.workflowId); } }, @@ -229,7 +266,7 @@ export function createWrapperService( const awakeable = ctx.awakeable(); await ctx - .rpc(stateServiceApi) + .object(stateServiceApi, request.workflowId) .subscribeResult(request.workflowId, awakeable.id); return awakeable.promise; }, @@ -239,7 +276,9 @@ export function createWrapperService( request: wf.WorkflowRequest ): Promise => { checkRequestAndWorkflowId(request); - return ctx.rpc(stateServiceApi).getStatus(request.workflowId); + return ctx + .object(stateServiceApi, request.workflowId) + .getStatus(request.workflowId); }, }; diff --git a/test/awakeable.test.ts b/test/awakeable.test.ts index 3a78df68..3e84acef 100644 --- a/test/awakeable.test.ts +++ b/test/awakeable.test.ts @@ -29,7 +29,7 @@ import { } from "./protoutils"; import { TestDriver, TestResponse, TestGreeter } from "./testdriver"; -import { ProtocolMode } from "../src/generated/proto/discovery"; +import { ProtocolMode } from "../src/types/discovery"; class AwakeableGreeter implements TestGreeter { async greet(ctx: restate.KeyedContext): Promise { @@ -46,7 +46,7 @@ class AwakeableGreeter implements TestGreeter { describe("AwakeableGreeter", () => { it("sends message to runtime", async () => { const result = await new TestDriver(new AwakeableGreeter(), [ - startMessage(), + startMessage({ key: "Till" }), inputMessage(greetRequest("Till")), ]).run(); @@ -56,7 +56,10 @@ describe("AwakeableGreeter", () => { it("sends message to runtime for request-response case", async () => { const result = await new TestDriver( new AwakeableGreeter(), - [startMessage(1), inputMessage(greetRequest("Till"))], + [ + startMessage({ knownEntries: 1, key: "Till" }), + inputMessage(greetRequest("Till")), + ], ProtocolMode.REQUEST_RESPONSE ).run(); @@ -65,7 +68,7 @@ describe("AwakeableGreeter", () => { it("handles completion with value", async () => { const result = await new TestDriver(new AwakeableGreeter(), [ - startMessage(), + startMessage({ key: "Till" }), inputMessage(greetRequest("Till")), completionMessage(1, JSON.stringify("Francesco")), ]).run(); @@ -79,7 +82,7 @@ describe("AwakeableGreeter", () => { it("handles completion with empty string value", async () => { const result = await new TestDriver(new AwakeableGreeter(), [ - startMessage(), + startMessage({ key: "Till" }), inputMessage(greetRequest("Till")), completionMessage(1, JSON.stringify("")), ]).run(); @@ -93,7 +96,7 @@ describe("AwakeableGreeter", () => { it("handles completion with empty object value", async () => { const result = await new TestDriver(new AwakeableGreeter(), [ - startMessage(), + startMessage({ key: "Till" }), inputMessage(greetRequest("Till")), completionMessage(1, JSON.stringify({})), ]).run(); @@ -109,7 +112,7 @@ describe("AwakeableGreeter", () => { it("handles completion with failure", async () => { const result = await new TestDriver(new AwakeableGreeter(), [ - startMessage(), + startMessage({ key: "Till" }), inputMessage(greetRequest("Till")), completionMessage( 1, @@ -127,7 +130,7 @@ describe("AwakeableGreeter", () => { it("handles replay with value", async () => { const result = await new TestDriver(new AwakeableGreeter(), [ - startMessage(), + startMessage({ key: "Till" }), inputMessage(greetRequest("Till")), awakeableMessage("Francesco"), ]).run(); @@ -140,7 +143,7 @@ describe("AwakeableGreeter", () => { it("handles replay with failure", async () => { const result = await new TestDriver(new AwakeableGreeter(), [ - startMessage(), + startMessage({ key: "Till" }), inputMessage(greetRequest("Till")), awakeableMessage(undefined, failure("Something went wrong")), ]).run(); @@ -152,7 +155,7 @@ describe("AwakeableGreeter", () => { it("fails on journal mismatch. Completed with CompleteAwakeable during replay.", async () => { const result = await new TestDriver(new AwakeableGreeter(), [ - startMessage(), + startMessage({ key: "Till" }), inputMessage(greetRequest("Till")), resolveAwakeableMessage("awakeable-1", "hello"), // should have been an awakeableMessage ]).run(); @@ -177,7 +180,7 @@ class AwakeableNull implements TestGreeter { describe("AwakeableNull", () => { it("handles completion with null value", async () => { const result = await new TestDriver(new AwakeableNull(), [ - startMessage(), + startMessage({}), inputMessage(greetRequest("Till")), completionMessage(1, JSON.stringify(null)), ]).run(); diff --git a/test/complete_awakeable.test.ts b/test/complete_awakeable.test.ts index 7fb19a1e..c0e82c9f 100644 --- a/test/complete_awakeable.test.ts +++ b/test/complete_awakeable.test.ts @@ -40,7 +40,7 @@ class ResolveAwakeableGreeter implements TestGreeter { describe("ResolveAwakeableGreeter", () => { it("sends message to runtime", async () => { const result = await new TestDriver(new ResolveAwakeableGreeter("hello"), [ - startMessage(), + startMessage({}), inputMessage(greetRequest("Till")), ]).run(); @@ -54,7 +54,7 @@ describe("ResolveAwakeableGreeter", () => { it("resolve with undefined value", async () => { const result = await new TestDriver( new ResolveAwakeableGreeter(undefined), - [startMessage(), inputMessage(greetRequest("Till"))] + [startMessage({}), inputMessage(greetRequest("Till"))] ).run(); expect(result).toStrictEqual([ @@ -66,7 +66,7 @@ describe("ResolveAwakeableGreeter", () => { it("sends message to runtime for empty string", async () => { const result = await new TestDriver(new ResolveAwakeableGreeter(""), [ - startMessage(), + startMessage({}), inputMessage(greetRequest("Till")), ]).run(); @@ -79,7 +79,7 @@ describe("ResolveAwakeableGreeter", () => { it("handles replay with value", async () => { const result = await new TestDriver(new ResolveAwakeableGreeter("hello"), [ - startMessage(), + startMessage({}), inputMessage(greetRequest("Till")), resolveAwakeableMessage(getAwakeableId(1), "hello"), ]).run(); @@ -92,7 +92,7 @@ describe("ResolveAwakeableGreeter", () => { it("handles replay with value empty string", async () => { const result = await new TestDriver(new ResolveAwakeableGreeter(""), [ - startMessage(), + startMessage({}), inputMessage(greetRequest("Till")), resolveAwakeableMessage(getAwakeableId(1), ""), ]).run(); @@ -105,7 +105,7 @@ describe("ResolveAwakeableGreeter", () => { it("fails on journal mismatch. Completed with invoke during replay.", async () => { const result = await new TestDriver(new ResolveAwakeableGreeter("hello"), [ - startMessage(), + startMessage({}), inputMessage(greetRequest("Till")), invokeMessage( "test.TestGreeter", @@ -121,7 +121,7 @@ describe("ResolveAwakeableGreeter", () => { it("fails on journal mismatch. Completed with wrong id.", async () => { const result = await new TestDriver(new ResolveAwakeableGreeter("hello"), [ - startMessage(2), + startMessage({ knownEntries: 2 }), inputMessage(greetRequest("Till")), resolveAwakeableMessage( "1234", // this should have been getAwakeableId(1) @@ -149,7 +149,7 @@ describe("RejectAwakeableGreeter", () => { it("sends message to runtime", async () => { const result = await new TestDriver( new RejectAwakeableGreeter("my bad error"), - [startMessage(), inputMessage(greetRequest("Till"))] + [startMessage({}), inputMessage(greetRequest("Till"))] ).run(); expect(result).toStrictEqual([ diff --git a/test/eager_state.test.ts b/test/eager_state.test.ts index e4fb71ed..8c0da257 100644 --- a/test/eager_state.test.ts +++ b/test/eager_state.test.ts @@ -33,7 +33,7 @@ import { startMessage, suspensionMessage, } from "./protoutils"; -import { ProtocolMode } from "../src/generated/proto/discovery"; +import { ProtocolMode } from "../src/types/discovery"; const input = inputMessage(greetRequest("Two")); const COMPLETE_STATE = false; @@ -50,7 +50,7 @@ describe("GetEmpty", () => { it("handles complete state without key present", async () => { const result = await new TestDriver( new GetEmpty(), - [startMessage(1, COMPLETE_STATE), input], + [startMessage({ knownEntries: 1, partialState: COMPLETE_STATE }), input], ProtocolMode.BIDI_STREAM ).run(); @@ -64,7 +64,7 @@ describe("GetEmpty", () => { it("handles partial state without key present ", async () => { const result = await new TestDriver( new GetEmpty(), - [startMessage(1), input], + [startMessage({ knownEntries: 1 }), input], ProtocolMode.BIDI_STREAM ).run(); @@ -77,7 +77,11 @@ describe("GetEmpty", () => { it("handles replay of partial state", async () => { const result = await new TestDriver( new GetEmpty(), - [startMessage(2), input, getStateMessage("STATE", undefined, true)], + [ + startMessage({ knownEntries: 2 }), + input, + getStateMessage("STATE", undefined, true), + ], ProtocolMode.BIDI_STREAM ).run(); @@ -100,7 +104,14 @@ describe("Get", () => { it("handles complete state with key present", async () => { const result = await new TestDriver( new Get(), - [startMessage(1, COMPLETE_STATE, [keyVal("STATE", "One")]), input], + [ + startMessage({ + knownEntries: 1, + partialState: COMPLETE_STATE, + state: [keyVal("STATE", "One")], + }), + input, + ], ProtocolMode.BIDI_STREAM ).run(); @@ -114,7 +125,10 @@ describe("Get", () => { it("handles partial state with key present ", async () => { const result = await new TestDriver( new Get(), - [startMessage(1, undefined, [keyVal("STATE", "One")]), input], + [ + startMessage({ knownEntries: 1, state: [keyVal("STATE", "One")] }), + input, + ], ProtocolMode.BIDI_STREAM ).run(); @@ -128,7 +142,7 @@ describe("Get", () => { it("handles partial state without key present", async () => { const result = await new TestDriver( new Get(), - [startMessage(2), input], + [startMessage({ knownEntries: 2 }), input], ProtocolMode.BIDI_STREAM ).run(); diff --git a/test/get_and_set_state.test.ts b/test/get_and_set_state.test.ts index 7c4356c3..7fab68c7 100644 --- a/test/get_and_set_state.test.ts +++ b/test/get_and_set_state.test.ts @@ -32,7 +32,7 @@ import { startMessage, suspensionMessage, } from "./protoutils"; -import { ProtocolMode } from "../src/generated/proto/discovery"; +import { ProtocolMode } from "../src/types/discovery"; class GetAndSetGreeter implements TestGreeter { async greet( diff --git a/test/get_state.test.ts b/test/get_state.test.ts index 09e7ddb8..1876cf03 100644 --- a/test/get_state.test.ts +++ b/test/get_state.test.ts @@ -26,8 +26,8 @@ import { startMessage, suspensionMessage, } from "./protoutils"; -import { ProtocolMode } from "../src/generated/proto/discovery"; import { TestDriver, TestGreeter, TestResponse } from "./testdriver"; +import { ProtocolMode } from "../src/types/discovery"; class GetStringStateGreeter implements TestGreeter { async greet(ctx: restate.KeyedContext): Promise { diff --git a/test/protoutils.ts b/test/protoutils.ts index 6a4a48a8..15634b60 100644 --- a/test/protoutils.ts +++ b/test/protoutils.ts @@ -67,13 +67,20 @@ import { jsonSerialize, formatMessageAsJson } from "../src/utils/utils"; import { rlog } from "../src/logger"; import { ErrorCodes, RestateErrorCodes } from "../src/types/errors"; import { SUPPORTED_PROTOCOL_VERSION } from "../src/io/decoder"; -import { RpcRequest, RpcResponse } from "../src/generated/proto/dynrpc"; -export function startMessage( - knownEntries?: number, - partialState?: boolean, - state?: Buffer[][] -): Message { +export type StartMessageOpts = { + knownEntries?: number; + partialState?: boolean; + state?: Buffer[][]; + key?: string; +}; + +export function startMessage({ + knownEntries, + partialState, + state, + key, +}: StartMessageOpts): Message { return new Message( START_MESSAGE_TYPE, StartMessage.create({ @@ -85,6 +92,7 @@ export function startMessage( knownEntries: knownEntries, // only used for the Lambda case. For bidi streaming, this will be imputed by the testdriver stateMap: toStateEntries(state || []), partialState: partialState !== false, + key: key ?? "Till", }), undefined, SUPPORTED_PROTOCOL_VERSION, @@ -338,7 +346,8 @@ export function invokeMessage( methodName: string, parameter: Uint8Array, value?: Uint8Array, - failure?: Failure + failure?: Failure, + key?: string ): Message { if (value != undefined) { return new Message( @@ -348,6 +357,7 @@ export function invokeMessage( methodName: methodName, parameter: Buffer.from(parameter), value: Buffer.from(value), + key, }) ); } else if (failure != undefined) { @@ -358,6 +368,7 @@ export function invokeMessage( methodName: methodName, parameter: Buffer.from(parameter), failure: failure, + key, }) ); } else { @@ -367,6 +378,7 @@ export function invokeMessage( serviceName: serviceName, methodName: methodName, parameter: Buffer.from(parameter), + key, }) ); } @@ -376,7 +388,8 @@ export function backgroundInvokeMessage( serviceName: string, methodName: string, parameter: Uint8Array, - invokeTime?: number + invokeTime?: number, + key?: string ): Message { return invokeTime ? new Message( @@ -386,6 +399,7 @@ export function backgroundInvokeMessage( methodName: methodName, parameter: Buffer.from(parameter), invokeTime: invokeTime, + key, }) ) : new Message( @@ -518,20 +532,13 @@ export function failureWithTerminal( } export function greetRequest(myName: string): Uint8Array { - const rpc = RpcRequest.create({ - key: myName, - request: { name: myName }, - }); - - return RpcRequest.encode(rpc).finish(); + const str = JSON.stringify(myName); + return Buffer.from(str); } export function greetResponse(myGreeting: string): Uint8Array { - const rpc = RpcResponse.create({ - response: { greeting: myGreeting }, - }); - - return RpcResponse.encode(rpc).finish(); + const str = JSON.stringify(myGreeting); + return Buffer.from(str); } export function checkError( diff --git a/test/sleep.test.ts b/test/sleep.test.ts index cf01df9a..4e763759 100644 --- a/test/sleep.test.ts +++ b/test/sleep.test.ts @@ -30,7 +30,7 @@ import { import { SLEEP_ENTRY_MESSAGE_TYPE } from "../src/types/protocol"; import { Empty } from "../src/generated/google/protobuf/empty"; import { TestDriver, TestGreeter, TestResponse } from "./testdriver"; -import { ProtocolMode } from "../src/generated/proto/discovery"; +import { ProtocolMode } from "../src/types/discovery"; const wakeupTime = 1835661783000; diff --git a/test/state_machine.test.ts b/test/state_machine.test.ts index 46cc2fbe..d2047b11 100644 --- a/test/state_machine.test.ts +++ b/test/state_machine.test.ts @@ -31,7 +31,7 @@ class Greeter implements TestGreeter { describe("Greeter", () => { it("sends message to runtime", async () => { const result = await new TestDriver(new Greeter(), [ - startMessage(1), + startMessage({ knownEntries: 1, key: "Pete" }), inputMessage(greetRequest("Pete")), ]).run(); @@ -43,7 +43,7 @@ describe("Greeter", () => { it("handles replay of output message", async () => { const result = await new TestDriver(new Greeter(), [ - startMessage(2), + startMessage({ knownEntries: 2, key: "Pete" }), inputMessage(greetRequest("Pete")), outputMessage(greetResponse("Hello")), ]).run(); @@ -53,7 +53,7 @@ describe("Greeter", () => { it("fails invocation if input is failed", async () => { const result = await new TestDriver(new Greeter(), [ - startMessage(1), + startMessage({ knownEntries: 1 }), inputMessage(undefined, failure("Canceled")), ]).run(); diff --git a/test/testdriver.ts b/test/testdriver.ts index 44fd89a2..88cd9713 100644 --- a/test/testdriver.ts +++ b/test/testdriver.ts @@ -18,14 +18,13 @@ import { import { Connection } from "../src/connection/connection"; import { formatMessageAsJson } from "../src/utils/utils"; import { Message } from "../src/types/types"; -import { HostedGrpcServiceMethod } from "../src/types/grpc"; -import { ProtocolMode } from "../src/generated/proto/discovery"; import { rlog } from "../src/logger"; import { StateMachine } from "../src/state_machine"; import { InvocationBuilder } from "../src/invocation"; import { EndpointImpl } from "../src/endpoint/endpoint_impl"; import { KeyedContext, ServiceApi } from "../src/context"; import { keyedRouter } from "../src/public_api"; +import { ProtocolMode } from "../src/types/discovery"; export type TestRequest = { name: string; @@ -49,12 +48,11 @@ export interface TestGreeter { greet(ctx: KeyedContext, message: TestRequest): Promise; } -export class TestDriver implements Connection { +export class TestDriver implements Connection { private readonly result: Message[] = []; private restateServer: TestRestateServer; - private method: HostedGrpcServiceMethod; - private stateMachine: StateMachine; + private stateMachine: StateMachine; private completionMessages: Message[]; constructor( @@ -72,23 +70,6 @@ export class TestDriver implements Connection { this.restateServer.bindKeyedRouter(GreeterApi.path, svc); - // this.restateServer.bindService({ - // descriptor: protoMetadata, - // service: "TestGreeter", - // instance: instance, - // }); - - const methodName = "/greeter/greet"; - - const hostedGrpcServiceMethod: HostedGrpcServiceMethod | undefined = - this.restateServer.methodByUrl("/invoke" + methodName); - - if (hostedGrpcServiceMethod) { - this.method = hostedGrpcServiceMethod; - } else { - throw new Error("Method not found: " + methodName); - } - if (entries.length < 2) { throw new Error( "Less than two runtime messages supplied for test. Need to have at least start message and input message." @@ -121,6 +102,7 @@ export class TestDriver implements Connection { knownEntries: endOfReplay - 1, stateMap: startEntry.stateMap, partialState: startEntry.partialState, + key: startEntry.key, }), msg.completed, msg.protocolVersion, @@ -154,7 +136,19 @@ export class TestDriver implements Connection { ); } - const invocationBuilder = new InvocationBuilder(this.method); + const method = this.restateServer + .componenetByName("greeter") + ?.handlerMatching({ + componenetName: "greeter", + handlerName: "greet", + key: startEntry.key, + }); + + if (!method) { + throw new Error("Something is wrong with the test setup"); + } + + const invocationBuilder = new InvocationBuilder(method); replayMessages.forEach((el) => invocationBuilder.handleMessage(el)); const invocation = invocationBuilder.build(); @@ -220,10 +214,4 @@ export class TestDriver implements Connection { * make it simpler for users to understand what methods are relevant for them, * and which ones are not. */ -class TestRestateServer extends EndpointImpl { - public methodByUrl( - url: string | null | undefined - ): HostedGrpcServiceMethod | undefined { - return super.methodByUrl(url); - } -} +class TestRestateServer extends EndpointImpl {} diff --git a/test/utils.test.ts b/test/utils.test.ts index e31f1b93..274f48e4 100644 --- a/test/utils.test.ts +++ b/test/utils.test.ts @@ -12,7 +12,6 @@ import { describe, expect } from "@jest/globals"; import { jsonDeserialize, - jsonSafeAny, jsonSerialize, formatMessageAsJson, } from "../src/utils/utils"; @@ -128,75 +127,3 @@ describe("rand", () => { expect(actual).toStrictEqual(expected); }); }); - -describe("jsonSafeAny", () => { - it("handles dates", () => { - expect(jsonSafeAny("", new Date(1701878170682))).toStrictEqual( - "2023-12-06T15:56:10.682Z" - ); - expect(jsonSafeAny("", { date: new Date(1701878170682) })).toStrictEqual({ - date: "2023-12-06T15:56:10.682Z", - }); - expect( - jsonSafeAny("", { - dates: [new Date(1701878170682), new Date(1701878170683)], - }) - ).toStrictEqual({ - dates: ["2023-12-06T15:56:10.682Z", "2023-12-06T15:56:10.683Z"], - }); - }); - it("handles urls", () => { - expect(jsonSafeAny("", new URL("https://restate.dev"))).toStrictEqual( - "https://restate.dev/" - ); - }); - it("handles patched BigInts", () => { - // by default should do nothing - expect(jsonSafeAny("", BigInt("9007199254740991"))).toStrictEqual( - BigInt("9007199254740991") - ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (BigInt.prototype as any).toJSON = function () { - return this.toString(); - }; - expect(jsonSafeAny("", BigInt("9007199254740991"))).toStrictEqual( - "9007199254740991" - ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - delete (BigInt.prototype as any).toJSON; - }); - it("handles custom types", () => { - const numberType = { - toJSON(): number { - return 1; - }, - }; - const stringType = { - toJSON(): string { - return "foo"; - }, - }; - expect(jsonSafeAny("", numberType)).toStrictEqual(1); - expect(jsonSafeAny("", stringType)).toStrictEqual("foo"); - }); - it("provides the correct key", () => { - const keys: string[] = []; - const typ = { - toJSON(key: string): string { - keys.push(key); - return ""; - }, - }; - expect(jsonSafeAny("", typ)).toStrictEqual(""); - expect(jsonSafeAny("", { key: typ })).toStrictEqual({ key: "" }); - expect(jsonSafeAny("", { key: [typ] })).toStrictEqual({ key: [""] }); - expect(jsonSafeAny("", { key: [0, typ] })).toStrictEqual({ key: [0, ""] }); - expect(jsonSafeAny("", { key: [0, { key2: typ }] })).toStrictEqual({ - key: [0, { key2: "" }], - }); - expect(jsonSafeAny("", { key: [0, { key2: [typ] }] })).toStrictEqual({ - key: [0, { key2: [""] }], - }); - expect(keys).toStrictEqual(["", "key", "0", "1", "key2", "0"]); - }); -});