diff --git a/encoding/encoding.go b/encoding/encoding.go new file mode 100644 index 0000000..c7dd5eb --- /dev/null +++ b/encoding/encoding.go @@ -0,0 +1,70 @@ +package encoding + +import ( + "encoding/json" + + "google.golang.org/protobuf/proto" +) + +type InputPayload struct { + Required bool `json:"required"` + ContentType *string `json:"contentType,omitempty"` + JsonSchema interface{} `json:"jsonSchema,omitempty"` +} + +type OutputPayload struct { + ContentType *string `json:"contentType,omitempty"` + SetContentTypeIfEmpty bool `json:"setContentTypeIfEmpty"` + JsonSchema interface{} `json:"jsonSchema,omitempty"` +} + +type JSONDecoder[I any] struct{} + +func (j JSONDecoder[I]) InputPayload() *InputPayload { + return &InputPayload{Required: true, ContentType: proto.String("application/json")} +} + +func (j JSONDecoder[I]) Decode(data []byte) (input I, err error) { + err = json.Unmarshal(data, &input) + return +} + +type JSONEncoder[O any] struct{} + +func (j JSONEncoder[O]) OutputPayload() *OutputPayload { + return &OutputPayload{ContentType: proto.String("application/json")} +} + +func (j JSONEncoder[O]) Encode(output O) ([]byte, error) { + return json.Marshal(output) +} + +type MessagePointer[I any] interface { + proto.Message + *I +} + +type ProtoDecoder[I any, IP MessagePointer[I]] struct{} + +func (p ProtoDecoder[I, IP]) InputPayload() *InputPayload { + return &InputPayload{Required: true, ContentType: proto.String("application/proto")} +} + +func (p ProtoDecoder[I, IP]) Decode(data []byte) (input IP, err error) { + // Unmarshal expects a non-nil pointer to a proto.Message implementing struct + // hence we must have a type parameter for the struct itself (I) and here we allocate + // a non-nil pointer of type IP + input = IP(new(I)) + err = proto.Unmarshal(data, input) + return +} + +type ProtoEncoder[O proto.Message] struct{} + +func (p ProtoEncoder[O]) OutputPayload() *OutputPayload { + return &OutputPayload{ContentType: proto.String("application/proto")} +} + +func (p ProtoEncoder[O]) Encode(output O) ([]byte, error) { + return proto.Marshal(output) +} diff --git a/go.mod b/go.mod index 2ea12cc..e331b1d 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/mr-tron/base58 v1.2.0 github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 github.com/stretchr/testify v1.9.0 - github.com/vmihailenco/msgpack/v5 v5.4.1 golang.org/x/net v0.21.0 google.golang.org/protobuf v1.32.0 ) @@ -16,7 +15,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 5e2ee5a..8e4da2f 100644 --- a/go.sum +++ b/go.sum @@ -14,10 +14,6 @@ github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 h1:zZg03nifrj6ayWNa github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0/go.mod h1:bblJa8QcHntareAJYfLJUzLj42sUFBKCBeTDK5LyUrw= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= -github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= -github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= -github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= diff --git a/handler.go b/handler.go index d27c4a3..af66b28 100644 --- a/handler.go +++ b/handler.go @@ -3,6 +3,8 @@ package restate import ( "encoding/json" "fmt" + + "github.com/restatedev/sdk-go/encoding" ) // Void is a placeholder used usually for functions that their signature require that @@ -10,45 +12,78 @@ import ( // require them type Void struct{} -func (v Void) MarshalJSON() ([]byte, error) { - return []byte("null"), nil +type VoidDecoder struct{} + +func (v VoidDecoder) InputPayload() *encoding.InputPayload { + return &encoding.InputPayload{} +} + +func (v VoidDecoder) Decode(data []byte) (input Void, err error) { + if len(data) > 0 { + err = fmt.Errorf("restate.Void decoder expects no request data") + } + return +} + +type VoidEncoder struct{} + +func (v VoidEncoder) OutputPayload() *encoding.OutputPayload { + return &encoding.OutputPayload{} } -func (v *Void) UnmarshalJSON(_ []byte) error { - return nil +func (v VoidEncoder) Encode(output Void) ([]byte, error) { + return nil, nil } type serviceHandler[I any, O any] struct { - fn ServiceHandlerFn[I, O] + fn ServiceHandlerFn[I, O] + decoder Decoder[I] + encoder Encoder[O] } -// NewServiceHandler create a new handler for a service -func NewServiceHandler[I any, O any](fn ServiceHandlerFn[I, O]) *serviceHandler[I, O] { +// NewJSONServiceHandler create a new handler for a service using JSON encoding +func NewJSONServiceHandler[I any, O any](fn ServiceHandlerFn[I, O]) *serviceHandler[I, O] { return &serviceHandler[I, O]{ - fn: fn, + fn: fn, + decoder: encoding.JSONDecoder[I]{}, + encoder: encoding.JSONEncoder[O]{}, } } -func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { - input := new(I) +// NewProtoServiceHandler create a new handler for a service using protobuf encoding +// Input and output type must both be pointers that satisfy proto.Message +func NewProtoServiceHandler[I any, O any, IP encoding.MessagePointer[I], OP encoding.MessagePointer[O]](fn ServiceHandlerFn[IP, OP]) *serviceHandler[IP, OP] { + return &serviceHandler[IP, OP]{ + fn: fn, + decoder: encoding.ProtoDecoder[I, IP]{}, + encoder: encoding.ProtoEncoder[OP]{}, + } +} - if len(bytes) > 0 { - // use the zero value if there is no input data at all - if err := json.Unmarshal(bytes, input); err != nil { - return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) - } +// NewServiceHandlerWithEncoders create a new handler for a service using a custom encoder/decoder implementation +func NewServiceHandlerWithEncoders[I any, O any](fn ServiceHandlerFn[I, O], decoder Decoder[I], encoder Encoder[O]) *serviceHandler[I, O] { + return &serviceHandler[I, O]{ + fn: fn, + decoder: decoder, + encoder: encoder, + } +} + +func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { + input, err := h.decoder.Decode(bytes) + if err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err)) } - // we are sure about the fn signature so it's safe to do this output, err := h.fn( ctx, - *input, + input, ) if err != nil { return nil, err } - bytes, err = json.Marshal(output) + bytes, err = h.encoder.Encode(output) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -56,6 +91,14 @@ func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { return bytes, nil } +func (h *serviceHandler[I, O]) InputPayload() *encoding.InputPayload { + return h.decoder.InputPayload() +} + +func (h *serviceHandler[I, O]) OutputPayload() *encoding.OutputPayload { + return h.encoder.OutputPayload() +} + func (h *serviceHandler[I, O]) sealed() {} type objectHandler[I any, O any] struct { diff --git a/internal/discovery.go b/internal/discovery.go index 738d789..7734196 100644 --- a/internal/discovery.go +++ b/internal/discovery.go @@ -1,5 +1,7 @@ package internal +import "github.com/restatedev/sdk-go/encoding" + type ProtocolMode string const ( @@ -23,24 +25,12 @@ const ( ServiceHandlerType_SHARED ServiceHandlerType = "SHARED" ) -type InputPayload struct { - Required bool `json:"required"` - ContentType string `json:"contentType"` - JsonSchema interface{} `json:"jsonSchema,omitempty"` -} - -type OutputPayload struct { - ContentType string `json:"contentType"` - SetContentTypeIfEmpty bool `json:"setContentTypeIfEmpty"` - JsonSchema interface{} `json:"jsonSchema,omitempty"` -} - type Handler struct { Name string `json:"name,omitempty"` // If unspecified, defaults to EXCLUSIVE for Virtual Object. This should be unset for Services. - Ty *ServiceHandlerType `json:"ty,omitempty"` - Input *InputPayload `json:"input,omitempty"` - Output *OutputPayload `json:"output,omitempty"` + Ty *ServiceHandlerType `json:"ty,omitempty"` + Input *encoding.InputPayload `json:"input,omitempty"` + Output *encoding.OutputPayload `json:"output,omitempty"` } type Service struct { diff --git a/reflect.go b/reflect.go index 93e1e70..08a77eb 100644 --- a/reflect.go +++ b/reflect.go @@ -4,24 +4,36 @@ import ( "encoding/json" "fmt" "reflect" + + "github.com/restatedev/sdk-go/encoding" + "google.golang.org/protobuf/proto" ) type serviceNamer interface { - Name() string + ServiceName() string } var ( typeOfContext = reflect.TypeFor[Context]() typeOfObjectContext = reflect.TypeFor[ObjectContext]() + typeOfVoid = reflect.TypeFor[Void]() typeOfError = reflect.TypeFor[error]() ) +// Object converts a struct with methods into a Virtual Object where each correctly-typed +// and exported method of the struct will become a handler on the Object. The Object name defaults +// to the name of the struct, but this can be overidden by providing a `ServiceName() string` method. +// The handler name is the name of the method. Handler methods should be of the type `ObjectHandlerFn[I, O]`. +// Input types I will be deserialised from JSON except when they are restate.Void, +// in which case no input bytes or content type may be sent. Output types O will be serialised +// to JSON except when they are restate.Void, in which case no data will be sent and no content type +// set. func Object(object any) *ObjectRouter { typ := reflect.TypeOf(object) val := reflect.ValueOf(object) var name string if sn, ok := object.(serviceNamer); ok { - name = sn.Name() + name = sn.ServiceName() } else { name = reflect.Indirect(val).Type().Name() } @@ -55,22 +67,32 @@ func Object(object any) *ObjectRouter { } router.Handler(mname, &objectReflectHandler{ - fn: method.Func, - receiver: val, - input: mtype.In(2), - output: mtype.Out(0), + reflectHandler{ + fn: method.Func, + receiver: val, + input: mtype.In(2), + output: mtype.Out(0), + }, }) } return router } +// Service converts a struct with methods into a Restate Service where each correctly-typed +// and exported method of the struct will become a handler on the Service. The Service name defaults +// to the name of the struct, but this can be overidden by providing a `ServiceName() string` method. +// The handler name is the name of the method. Handler methods should be of the type `ServiceHandlerFn[I, O]`. +// Input types I will be deserialised from JSON except when they are restate.Void, +// in which case no input bytes or content type may be sent. Output types O will be serialised +// to JSON except when they are restate.Void, in which case no data will be sent and no content type +// set. func Service(service any) *ServiceRouter { typ := reflect.TypeOf(service) val := reflect.ValueOf(service) var name string if sn, ok := service.(serviceNamer); ok { - name = sn.Name() + name = sn.ServiceName() } else { name = reflect.Indirect(val).Type().Name() } @@ -106,30 +128,60 @@ func Service(service any) *ServiceRouter { } router.Handler(mname, &serviceReflectHandler{ - fn: method.Func, - receiver: val, - input: mtype.In(2), - output: mtype.Out(0), + reflectHandler{ + fn: method.Func, + receiver: val, + input: mtype.In(2), + output: mtype.Out(0), + }, }) } return router } -type objectReflectHandler struct { +type reflectHandler struct { fn reflect.Value receiver reflect.Value input reflect.Type output reflect.Type } +func (h *reflectHandler) InputPayload() *encoding.InputPayload { + if h.input == typeOfVoid { + return &encoding.InputPayload{} + } else { + return &encoding.InputPayload{ + Required: true, + ContentType: proto.String("application/json"), + } + } +} + +func (h *reflectHandler) OutputPayload() *encoding.OutputPayload { + if h.output == typeOfVoid { + return &encoding.OutputPayload{} + } else { + return &encoding.OutputPayload{ + ContentType: proto.String("application/json"), + } + } +} + +func (h *reflectHandler) sealed() {} + +type objectReflectHandler struct { + reflectHandler +} + +var _ Handler = (*objectReflectHandler)(nil) + func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { input := reflect.New(h.input) - if len(bytes) > 0 { - // use the zero value if there is no input data at all + if h.input != typeOfVoid { if err := json.Unmarshal(bytes, input.Interface()); err != nil { - return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err)) } } @@ -146,6 +198,10 @@ func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, er return nil, errI.(error) } + if h.output == typeOfVoid { + return nil, nil + } + bytes, err := json.Marshal(outI) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) @@ -154,23 +210,17 @@ func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, er return bytes, nil } -func (h *objectReflectHandler) sealed() {} - type serviceReflectHandler struct { - fn reflect.Value - receiver reflect.Value - input reflect.Type - output reflect.Type + reflectHandler } +var _ Handler = (*serviceReflectHandler)(nil) + func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) { input := reflect.New(h.input) - if len(bytes) > 0 { - // use the zero value if there is no input data at all - if err := json.Unmarshal(bytes, input.Interface()); err != nil { - return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) - } + if err := json.Unmarshal(bytes, input.Interface()); err != nil { + return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) } // we are sure about the fn signature so it's safe to do this @@ -193,5 +243,3 @@ func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) return bytes, nil } - -func (h *serviceReflectHandler) sealed() {} diff --git a/router.go b/router.go index abfc3fc..7f07fd7 100644 --- a/router.go +++ b/router.go @@ -7,10 +7,10 @@ import ( "log/slog" "time" + "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/internal" "github.com/restatedev/sdk-go/internal/futures" "github.com/restatedev/sdk-go/internal/rand" - "github.com/vmihailenco/msgpack/v5" ) var ( @@ -122,6 +122,8 @@ type ServiceHandler interface { type Handler interface { sealed() + InputPayload() *encoding.InputPayload + OutputPayload() *encoding.OutputPayload } type ServiceType string @@ -162,6 +164,16 @@ type ServiceHandlerFn[I any, O any] func(ctx Context, input I) (output O, err er // ObjectHandlerFn signature for object (keyed) handler function type ObjectHandlerFn[I any, O any] func(ctx ObjectContext, input I) (output O, err error) +type Decoder[I any] interface { + InputPayload() *encoding.InputPayload + Decode(data []byte) (input I, err error) +} + +type Encoder[O any] interface { + OutputPayload() *encoding.OutputPayload + Encode(output O) ([]byte, error) +} + // ServiceRouter implements Router type ServiceRouter struct { name string @@ -231,7 +243,7 @@ func (r *ObjectRouter) Type() internal.ServiceType { // GetAs helper function to get a key as specific type. Note that // if there is no associated value with key, an error ErrKeyNotFound is // returned -// it does encoding/decoding of bytes automatically using msgpack +// it does encoding/decoding of bytes automatically using json func GetAs[T any](ctx ObjectContext, key string) (output T, err error) { bytes, err := ctx.Get(key) if err != nil { @@ -243,15 +255,15 @@ func GetAs[T any](ctx ObjectContext, key string) (output T, err error) { return output, ErrKeyNotFound } - err = msgpack.Unmarshal(bytes, &output) + err = json.Unmarshal(bytes, &output) return } // SetAs helper function to set a key value with a generic type T. -// it does encoding/decoding of bytes automatically using msgpack +// it does encoding/decoding of bytes automatically using json func SetAs[T any](ctx ObjectContext, key string, value T) error { - bytes, err := msgpack.Marshal(value) + bytes, err := json.Marshal(value) if err != nil { return err } @@ -261,7 +273,7 @@ func SetAs[T any](ctx ObjectContext, key string, value T) error { } // RunAs helper function runs a run function with specific concrete type as a result -// it does encoding/decoding of bytes automatically using msgpack +// it does encoding/decoding of bytes automatically using json func RunAs[T any](ctx Context, fn func(RunContext) (T, error)) (output T, err error) { bytes, err := ctx.Run(func(ctx RunContext) ([]byte, error) { out, err := fn(ctx) @@ -269,7 +281,7 @@ func RunAs[T any](ctx Context, fn func(RunContext) (T, error)) (output T, err er return nil, err } - bytes, err := msgpack.Marshal(out) + bytes, err := json.Marshal(out) return bytes, TerminalError(err) }) @@ -277,7 +289,7 @@ func RunAs[T any](ctx Context, fn func(RunContext) (T, error)) (output T, err er return output, err } - err = msgpack.Unmarshal(bytes, &output) + err = json.Unmarshal(bytes, &output) return output, TerminalError(err) } diff --git a/server/restate.go b/server/restate.go index 5294d7a..366717f 100644 --- a/server/restate.go +++ b/server/restate.go @@ -104,17 +104,11 @@ func (r *Restate) discover() (resource *internal.Endpoint, err error) { Handlers: make([]internal.Handler, 0, len(router.Handlers())), } - for name := range router.Handlers() { + for name, handler := range router.Handlers() { service.Handlers = append(service.Handlers, internal.Handler{ - Name: name, - Input: &internal.InputPayload{ - Required: false, - ContentType: "application/json", // TODO configurable handler encoding - }, - Output: &internal.OutputPayload{ - SetContentTypeIfEmpty: false, - ContentType: "application/json", - }, + Name: name, + Input: handler.InputPayload(), + Output: handler.OutputPayload(), }) } resource.Services = append(resource.Services, service)