From 7000e54de18e5f7ba2975646244106442f91d685 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Tue, 16 Jul 2024 22:54:23 +0100 Subject: [PATCH] Verif fixes, simplify void (#12) * Fixes for verif tests * Simplify void --- encoding/encoding.go | 168 +++++++++++++++++++------------------- encoding/encoding_test.go | 39 +++++++-- example/utils.go | 3 +- go.mod | 2 + handler.go | 31 ++++--- internal/state/call.go | 18 ++-- internal/state/state.go | 12 +-- reflect.go | 40 ++------- router.go | 8 +- 9 files changed, 169 insertions(+), 152 deletions(-) diff --git a/encoding/encoding.go b/encoding/encoding.go index 12286f1..f10b9c1 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -8,26 +8,101 @@ import ( "google.golang.org/protobuf/proto" ) +// Void is a placeholder to signify 'no value' where a type is otherwise needed +type Void struct{} + var ( + // BinaryCodec marshals []byte and unmarshals into *[]byte + // In handlers, it uses a content type of application/octet-stream BinaryCodec PayloadCodec = binaryCodec{} - VoidCodec PayloadCodec = voidCodec{} - ProtoCodec PayloadCodec = protoCodec{} - JSONCodec PayloadCodec = jsonCodec{} - _ PayloadCodec = PairCodec{} + // VoidCodec marshals anything into []byte(nil) and skips unmarshaling + // In handlers, it requires that there is no input content-type and does not set an output content-type + VoidCodec PayloadCodec = voidCodec{} + // ProtoCodec marshals proto.Message and unmarshals into proto.Message or pointers to types that implement proto.Message + // In handlers, it uses a content-type of application/proto + ProtoCodec PayloadCodec = protoCodec{} + // JSONCodec marshals any json.Marshallable type and unmarshals into any json.Unmarshallable type + // In handlers, it uses a content-type of application/json + JSONCodec PayloadCodec = jsonCodec{} + _ RestateMarshaler = Void{} + _ RestateUnmarshaler = Void{} + _ RestateUnmarshaler = &Void{} ) -type Void struct{} +func (v Void) RestateUnmarshal(codec Codec, data []byte) error { + return nil +} + +func (v Void) InputPayload(codec Codec) *InputPayload { + return &InputPayload{} +} + +func (v Void) RestateMarshal(codec Codec) ([]byte, error) { + return nil, nil +} + +func (v Void) OutputPayload(codec Codec) *OutputPayload { + return &OutputPayload{} +} + +type RestateUnmarshaler interface { + RestateUnmarshal(codec Codec, data []byte) error + InputPayload(codec Codec) *InputPayload +} + +func InputPayloadFor(codec PayloadCodec, i any) *InputPayload { + ru, ok := i.(RestateUnmarshaler) + if ok { + return ru.InputPayload(codec) + } + return codec.InputPayload() +} + +func OutputPayloadFor(codec PayloadCodec, o any) *OutputPayload { + ru, ok := o.(RestateMarshaler) + if ok { + return ru.OutputPayload(codec) + } + return codec.OutputPayload() +} + +func RestateMarshalerFor[O any]() (RestateMarshaler, bool) { + var o O + ru, ok := any(o).(RestateMarshaler) + return ru, ok +} + +// RestateMarshaler can be implemented by types that want to control their own marshaling +type RestateMarshaler interface { + RestateMarshal(codec Codec) ([]byte, error) + OutputPayload(codec Codec) *OutputPayload +} type Codec interface { Marshal(v any) ([]byte, error) Unmarshal(data []byte, v any) error } +func Marshal(codec Codec, v any) ([]byte, error) { + if marshaler, ok := v.(RestateMarshaler); ok { + return marshaler.RestateMarshal(codec) + } + return codec.Marshal(v) +} + +func Unmarshal(codec Codec, data []byte, v any) error { + if marshaler, ok := v.(RestateUnmarshaler); ok { + return marshaler.RestateUnmarshal(codec, data) + } + return codec.Unmarshal(data, v) +} + type PayloadCodec interface { - Codec InputPayload() *InputPayload OutputPayload() *OutputPayload + Codec } + type InputPayload struct { Required bool `json:"required"` ContentType *string `json:"contentType,omitempty"` @@ -58,87 +133,16 @@ func (j voidCodec) Marshal(output any) ([]byte, error) { return nil, nil } -type PairCodec struct { - Input PayloadCodec - Output PayloadCodec -} - -func (w PairCodec) InputPayload() *InputPayload { - return w.Input.InputPayload() -} - -func (w PairCodec) OutputPayload() *OutputPayload { - return w.Output.OutputPayload() -} - -func (w PairCodec) Unmarshal(data []byte, v any) error { - return w.Input.Unmarshal(data, v) -} - -func (w PairCodec) Marshal(v any) ([]byte, error) { - return w.Output.Marshal(v) -} - -func MergeCodec(base, overlay PayloadCodec) PayloadCodec { - switch { - case base == nil && overlay == nil: - return nil - case base == nil: - return overlay - case overlay == nil: - return base - } - - basePair, baseOk := base.(PairCodec) - overlayPair, overlayOk := overlay.(PairCodec) - - switch { - case baseOk && overlayOk: - return PairCodec{ - Input: MergeCodec(basePair.Input, overlayPair.Input), - Output: MergeCodec(basePair.Output, overlayPair.Output), - } - case baseOk: - return PairCodec{ - Input: MergeCodec(basePair.Input, overlay), - Output: MergeCodec(basePair.Output, overlay), - } - case overlayOk: - return PairCodec{ - Input: MergeCodec(base, overlayPair.Input), - Output: MergeCodec(base, overlayPair.Output), - } - default: - // just two non-pairs; keep base - return base - } -} - -func PartialVoidCodec[I any, O any]() PayloadCodec { - var input I - var output O - _, inputVoid := any(input).(Void) - _, outputVoid := any(output).(Void) - switch { - case inputVoid && outputVoid: - return VoidCodec - case inputVoid: - return PairCodec{Input: VoidCodec, Output: nil} - case outputVoid: - return PairCodec{Input: nil, Output: VoidCodec} - default: - return nil - } -} - type binaryCodec struct{} func (j binaryCodec) InputPayload() *InputPayload { - return &InputPayload{Required: true, ContentType: proto.String("application/octet-stream")} + // Required false because 0 bytes is a valid input + return &InputPayload{Required: false, ContentType: proto.String("application/octet-stream")} } func (j binaryCodec) OutputPayload() *OutputPayload { - return &OutputPayload{ContentType: proto.String("application/octet-stream")} + // SetContentTypeIfEmpty true because 0 bytes is a valid output + return &OutputPayload{ContentType: proto.String("application/octet-stream"), SetContentTypeIfEmpty: true} } func (j binaryCodec) Unmarshal(data []byte, input any) (err error) { @@ -171,7 +175,7 @@ func (j jsonCodec) OutputPayload() *OutputPayload { } func (j jsonCodec) Unmarshal(data []byte, input any) (err error) { - return json.Unmarshal(data, &input) + return json.Unmarshal(data, input) } func (j jsonCodec) Marshal(output any) ([]byte, error) { diff --git a/encoding/encoding_test.go b/encoding/encoding_test.go index d3bbbae..6a9dd75 100644 --- a/encoding/encoding_test.go +++ b/encoding/encoding_test.go @@ -34,40 +34,67 @@ func checkMessage(t *testing.T, msg *protocol.AwakeableEntryMessage) { func TestProto(t *testing.T) { p := ProtoCodec - _, err := p.Marshal(protocol.AwakeableEntryMessage{Name: "foobar"}) + _, err := Marshal(p, protocol.AwakeableEntryMessage{Name: "foobar"}) if err == nil { t.Fatalf("expected error when marshaling non-pointer proto Message") } - bytes, err := p.Marshal(&protocol.AwakeableEntryMessage{Name: "foobar"}) + bytes, err := Marshal(p, &protocol.AwakeableEntryMessage{Name: "foobar"}) if err != nil { t.Fatal(err) } { msg := &protocol.AwakeableEntryMessage{} - willSucceed(t, p.Unmarshal(bytes, msg)) + willSucceed(t, Unmarshal(p, bytes, msg)) checkMessage(t, msg) } { inner := &protocol.AwakeableEntryMessage{} msg := &inner - willSucceed(t, p.Unmarshal(bytes, msg)) + willSucceed(t, Unmarshal(p, bytes, msg)) checkMessage(t, *msg) } { msg := new(*protocol.AwakeableEntryMessage) - willSucceed(t, p.Unmarshal(bytes, msg)) + willSucceed(t, Unmarshal(p, bytes, msg)) checkMessage(t, *msg) } { var msg *protocol.AwakeableEntryMessage willPanic(t, func() { - p.Unmarshal(bytes, msg) + Unmarshal(p, bytes, msg) }) } +} + +func TestVoid(t *testing.T) { + codecs := map[string]Codec{ + "json": JSONCodec, + "proto": ProtoCodec, + "binary": BinaryCodec, + } + for name, codec := range codecs { + t.Run(name, func(t *testing.T) { + bytes, err := Marshal(codec, Void{}) + if err != nil { + t.Fatal(err) + } + + if bytes != nil { + t.Fatalf("expected bytes to be nil, found %v", bytes) + } + if err := Unmarshal(codec, []byte{1, 2, 3}, &Void{}); err != nil { + t.Fatal(err) + } + + if err := Unmarshal(codec, []byte{1, 2, 3}, Void{}); err != nil { + t.Fatal(err) + } + }) + } } diff --git a/example/utils.go b/example/utils.go index 6d05c7d..bef6af1 100644 --- a/example/utils.go +++ b/example/utils.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "math/big" @@ -24,7 +25,7 @@ var bigCounter = restate. } bytes, err := restate.GetAs[[]byte](ctx, "counter", restate.WithBinary) - if err != nil && err != restate.ErrKeyNotFound { + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return "", err } newCount := big.NewInt(0).Add(big.NewInt(0).SetBytes(bytes), delta) diff --git a/go.mod b/go.mod index 6372916..a613689 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/restatedev/sdk-go go 1.21.0 +toolchain go1.21.12 + require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 diff --git a/handler.go b/handler.go index 37a8b6d..e1583f1 100644 --- a/handler.go +++ b/handler.go @@ -11,6 +11,8 @@ import ( // Void is a placeholder to signify 'no value' where a type is otherwise needed. It can be used in several contexts: // 1. Input types for handlers - the request payload codec will default to a encoding.VoidCodec which will reject input at the ingress // 2. Output types for handlers - the response payload codec will default to a encoding.VoidCodec which will send no bytes and set no content-type +// 3. Input for a outgoing Request or Send - no bytes will be sent +// 4. The output type for an outgoing Request - the response body will be ignored. A pointer is also accepted. type Void = encoding.Void type ObjectHandler interface { @@ -62,9 +64,6 @@ func NewServiceHandler[I any, O any](fn ServiceHandlerFn[I, O], options ...Servi for _, opt := range options { opt.beforeServiceHandler(&opts) } - if opts.codec == nil { - opts.codec = encoding.PartialVoidCodec[I, O]() - } return &serviceHandler[I, O]{ fn: fn, options: opts, @@ -73,7 +72,7 @@ func NewServiceHandler[I any, O any](fn ServiceHandlerFn[I, O], options ...Servi func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { var input I - if err := h.options.codec.Unmarshal(bytes, &input); err != nil { + if err := encoding.Unmarshal(h.options.codec, bytes, &input); err != nil { return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } @@ -85,7 +84,7 @@ func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { return nil, err } - bytes, err = h.options.codec.Marshal(output) + bytes, err = encoding.Marshal(h.options.codec, output) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -94,11 +93,13 @@ func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { } func (h *serviceHandler[I, O]) InputPayload() *encoding.InputPayload { - return h.options.codec.InputPayload() + var i I + return encoding.InputPayloadFor(h.options.codec, i) } func (h *serviceHandler[I, O]) OutputPayload() *encoding.OutputPayload { - return h.options.codec.OutputPayload() + var o O + return encoding.OutputPayloadFor(h.options.codec, o) } func (h *serviceHandler[I, O]) HandlerType() *internal.ServiceHandlerType { @@ -134,9 +135,6 @@ func NewObjectHandler[I any, O any](fn ObjectHandlerFn[I, O], options ...ObjectH for _, opt := range options { opt.beforeObjectHandler(&opts) } - if opts.codec == nil { - opts.codec = encoding.PartialVoidCodec[I, O]() - } return &objectHandler[I, O]{ exclusiveFn: fn, options: opts, @@ -149,9 +147,6 @@ func NewObjectSharedHandler[I any, O any](fn ObjectSharedHandlerFn[I, O], option for _, opt := range options { opt.beforeObjectHandler(&opts) } - if opts.codec == nil { - opts.codec = encoding.PartialVoidCodec[I, O]() - } return &objectHandler[I, O]{ sharedFn: fn, options: opts, @@ -161,7 +156,7 @@ func NewObjectSharedHandler[I any, O any](fn ObjectSharedHandlerFn[I, O], option func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { var input I - if err := h.options.codec.Unmarshal(bytes, &input); err != nil { + if err := encoding.Unmarshal(h.options.codec, bytes, &input); err != nil { return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } @@ -183,7 +178,7 @@ func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, err return nil, err } - bytes, err = h.options.codec.Marshal(output) + bytes, err = encoding.Marshal(h.options.codec, output) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -192,11 +187,13 @@ func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, err } func (h *objectHandler[I, O]) InputPayload() *encoding.InputPayload { - return h.options.codec.InputPayload() + var i I + return encoding.InputPayloadFor(h.options.codec, i) } func (h *objectHandler[I, O]) OutputPayload() *encoding.OutputPayload { - return h.options.codec.OutputPayload() + var o O + return encoding.OutputPayloadFor(h.options.codec, o) } func (h *objectHandler[I, O]) getOptions() *objectHandlerOptions { diff --git a/internal/state/call.go b/internal/state/call.go index 8c9e5cf..033d5e9 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -6,6 +6,7 @@ import ( "time" restate "github.com/restatedev/sdk-go" + "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/generated/proto/protocol" "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" @@ -23,10 +24,15 @@ type serviceCall struct { // RequestFuture makes a call and returns a handle on the response func (c *serviceCall) RequestFuture(input any) (restate.ResponseFuture, error) { - bytes, err := c.options.Codec.Marshal(input) - if err != nil { - return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal RequestFuture input: %w", err)) + var bytes []byte + if (input != restate.Void{}) { + var err error + bytes, err = encoding.Marshal(c.options.Codec, input) + if err != nil { + return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal RequestFuture input: %w", err)) + } } + entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, bytes) return decodingResponseFuture{ @@ -46,8 +52,8 @@ func (d decodingResponseFuture) Response(output any) (err error) { return err } - if err := d.options.Codec.Unmarshal(bytes, output); err != nil { - return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Call response into O: %w", err)) + if err := encoding.Unmarshal(d.options.Codec, bytes, output); err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Call response into output: %w", err)) } return nil @@ -64,7 +70,7 @@ func (c *serviceCall) Request(input any, output any) error { // Send runs a call in the background after delay duration func (c *serviceCall) Send(input any, delay time.Duration) error { - bytes, err := c.options.Codec.Marshal(input) + bytes, err := encoding.Marshal(c.options.Codec, input) if err != nil { return errors.NewTerminalError(fmt.Errorf("failed to marshal Send input: %w", err)) } diff --git a/internal/state/state.go b/internal/state/state.go index 8d6f3cf..c722c4c 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -60,7 +60,7 @@ func (c *Context) Set(key string, value any, opts ...options.SetOption) error { o.Codec = encoding.JSONCodec } - bytes, err := o.Codec.Marshal(value) + bytes, err := encoding.Marshal(o.Codec, value) if err != nil { return errors.NewTerminalError(fmt.Errorf("failed to marshal Set value: %w", err)) } @@ -94,7 +94,7 @@ func (c *Context) Get(key string, output any, opts ...options.GetOption) error { return errors.ErrKeyNotFound } - if err := o.Codec.Unmarshal(bytes, output); err != nil { + if err := encoding.Unmarshal(o.Codec, bytes, output); err != nil { return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Get state into output: %w", err)) } @@ -163,7 +163,7 @@ func (c *Context) Run(fn func(ctx restate.RunContext) (any, error), output any, return nil, err } - bytes, err := o.Codec.Marshal(output) + bytes, err := encoding.Marshal(o.Codec, output) if err != nil { return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal Run output: %w", err)) } @@ -174,7 +174,7 @@ func (c *Context) Run(fn func(ctx restate.RunContext) (any, error), output any, return err } - if err := o.Codec.Unmarshal(bytes, output); err != nil { + if err := encoding.Unmarshal(o.Codec, bytes, output); err != nil { return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Run output: %w", err)) } @@ -211,7 +211,7 @@ func (d decodingAwakeable) Result(output any) (err error) { if err != nil { return err } - if err := d.codec.Unmarshal(bytes, output); err != nil { + if err := encoding.Unmarshal(d.codec, bytes, output); err != nil { return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Awakeable result into output: %w", err)) } return @@ -225,7 +225,7 @@ func (c *Context) ResolveAwakeable(id string, value any, opts ...options.Resolve if o.Codec == nil { o.Codec = encoding.JSONCodec } - bytes, err := o.Codec.Marshal(value) + bytes, err := encoding.Marshal(o.Codec, value) if err != nil { return errors.NewTerminalError(fmt.Errorf("failed to marshal ResolveAwakeable value: %w", err)) } diff --git a/reflect.go b/reflect.go index 24c5b4e..41f363f 100644 --- a/reflect.go +++ b/reflect.go @@ -78,20 +78,8 @@ func Object(object any, options ...ObjectRouterOption) *ObjectRouter { input := mtype.In(2) output := mtype.Out(0) - var codec encoding.PayloadCodec - switch { - case input == typeOfVoid && output == typeOfVoid: - codec = encoding.VoidCodec - case input == typeOfVoid: - codec = encoding.PairCodec{Input: encoding.VoidCodec, Output: nil} - case output == typeOfVoid: - codec = encoding.PairCodec{Input: nil, Output: encoding.VoidCodec} - default: - codec = nil - } - router.Handler(mname, &objectReflectHandler{ - objectHandlerOptions{codec}, + objectHandlerOptions{}, handlerType, reflectHandler{ fn: method.Func, @@ -156,20 +144,8 @@ func Service(service any, options ...ServiceRouterOption) *ServiceRouter { input := mtype.In(2) output := mtype.Out(0) - var codec encoding.PayloadCodec - switch { - case input == typeOfVoid && output == typeOfVoid: - codec = encoding.VoidCodec - case input == typeOfVoid: - codec = encoding.PairCodec{Input: encoding.VoidCodec, Output: nil} - case output == typeOfVoid: - codec = encoding.PairCodec{Input: nil, Output: encoding.VoidCodec} - default: - codec = nil - } - router.Handler(mname, &serviceReflectHandler{ - serviceHandlerOptions{codec: codec}, + serviceHandlerOptions{}, reflectHandler{ fn: method.Func, receiver: val, @@ -202,7 +178,7 @@ var _ ObjectHandler = (*objectReflectHandler)(nil) func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { input := reflect.New(h.input) - if err := h.options.codec.Unmarshal(bytes, input.Interface()); err != nil { + if err := encoding.Unmarshal(h.options.codec, bytes, input.Interface()); err != nil { return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } @@ -219,7 +195,7 @@ func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, er return nil, errI.(error) } - bytes, err := h.options.codec.Marshal(outI) + bytes, err := encoding.Marshal(h.options.codec, outI) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -232,11 +208,11 @@ func (h *objectReflectHandler) getOptions() *objectHandlerOptions { } func (h *objectReflectHandler) InputPayload() *encoding.InputPayload { - return h.options.codec.InputPayload() + return encoding.InputPayloadFor(h.options.codec, reflect.Zero(h.input).Interface()) } func (h *objectReflectHandler) OutputPayload() *encoding.OutputPayload { - return h.options.codec.OutputPayload() + return encoding.OutputPayloadFor(h.options.codec, reflect.Zero(h.output).Interface()) } func (h *objectReflectHandler) HandlerType() *internal.ServiceHandlerType { @@ -253,7 +229,7 @@ var _ ServiceHandler = (*serviceReflectHandler)(nil) func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) { input := reflect.New(h.input) - if err := h.options.codec.Unmarshal(bytes, input.Interface()); err != nil { + if err := encoding.Unmarshal(h.options.codec, bytes, input.Interface()); err != nil { return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } @@ -270,7 +246,7 @@ func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) return nil, errI.(error) } - bytes, err := h.options.codec.Marshal(outI) + bytes, err := encoding.Marshal(h.options.codec, outI) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } diff --git a/router.go b/router.go index 9cd52b4..5f93006 100644 --- a/router.go +++ b/router.go @@ -52,7 +52,9 @@ func (r *ServiceRouter) Name() string { // Handler registers a new handler by name func (r *ServiceRouter) Handler(name string, handler ServiceHandler) *ServiceRouter { - handler.getOptions().codec = encoding.MergeCodec(handler.getOptions().codec, r.options.defaultCodec) + if handler.getOptions().codec == nil { + handler.getOptions().codec = r.options.defaultCodec + } r.handlers[name] = handler return r } @@ -102,7 +104,9 @@ func (r *ObjectRouter) Name() string { } func (r *ObjectRouter) Handler(name string, handler ObjectHandler) *ObjectRouter { - handler.getOptions().codec = encoding.MergeCodec(handler.getOptions().codec, r.options.defaultCodec) + if handler.getOptions().codec == nil { + handler.getOptions().codec = r.options.defaultCodec + } r.handlers[name] = handler return r }