Skip to content

Commit

Permalink
Verif fixes, simplify void (#12)
Browse files Browse the repository at this point in the history
* Fixes for verif tests

* Simplify void
  • Loading branch information
jackkleeman authored Jul 16, 2024
1 parent e6f1b63 commit 7000e54
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 152 deletions.
168 changes: 86 additions & 82 deletions encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,101 @@ import (
"google.golang.org/protobuf/proto"
)

// Void is a placeholder to signify 'no value' where a type is otherwise needed
type Void struct{}

var (
// BinaryCodec marshals []byte and unmarshals into *[]byte
// In handlers, it uses a content type of application/octet-stream
BinaryCodec PayloadCodec = binaryCodec{}
VoidCodec PayloadCodec = voidCodec{}
ProtoCodec PayloadCodec = protoCodec{}
JSONCodec PayloadCodec = jsonCodec{}
_ PayloadCodec = PairCodec{}
// VoidCodec marshals anything into []byte(nil) and skips unmarshaling
// In handlers, it requires that there is no input content-type and does not set an output content-type
VoidCodec PayloadCodec = voidCodec{}
// ProtoCodec marshals proto.Message and unmarshals into proto.Message or pointers to types that implement proto.Message
// In handlers, it uses a content-type of application/proto
ProtoCodec PayloadCodec = protoCodec{}
// JSONCodec marshals any json.Marshallable type and unmarshals into any json.Unmarshallable type
// In handlers, it uses a content-type of application/json
JSONCodec PayloadCodec = jsonCodec{}
_ RestateMarshaler = Void{}
_ RestateUnmarshaler = Void{}
_ RestateUnmarshaler = &Void{}
)

type Void struct{}
func (v Void) RestateUnmarshal(codec Codec, data []byte) error {
return nil
}

func (v Void) InputPayload(codec Codec) *InputPayload {
return &InputPayload{}
}

func (v Void) RestateMarshal(codec Codec) ([]byte, error) {
return nil, nil
}

func (v Void) OutputPayload(codec Codec) *OutputPayload {
return &OutputPayload{}
}

type RestateUnmarshaler interface {
RestateUnmarshal(codec Codec, data []byte) error
InputPayload(codec Codec) *InputPayload
}

func InputPayloadFor(codec PayloadCodec, i any) *InputPayload {
ru, ok := i.(RestateUnmarshaler)
if ok {
return ru.InputPayload(codec)
}
return codec.InputPayload()
}

func OutputPayloadFor(codec PayloadCodec, o any) *OutputPayload {
ru, ok := o.(RestateMarshaler)
if ok {
return ru.OutputPayload(codec)
}
return codec.OutputPayload()
}

func RestateMarshalerFor[O any]() (RestateMarshaler, bool) {
var o O
ru, ok := any(o).(RestateMarshaler)
return ru, ok
}

// RestateMarshaler can be implemented by types that want to control their own marshaling
type RestateMarshaler interface {
RestateMarshal(codec Codec) ([]byte, error)
OutputPayload(codec Codec) *OutputPayload
}

type Codec interface {
Marshal(v any) ([]byte, error)
Unmarshal(data []byte, v any) error
}

func Marshal(codec Codec, v any) ([]byte, error) {
if marshaler, ok := v.(RestateMarshaler); ok {
return marshaler.RestateMarshal(codec)
}
return codec.Marshal(v)
}

func Unmarshal(codec Codec, data []byte, v any) error {
if marshaler, ok := v.(RestateUnmarshaler); ok {
return marshaler.RestateUnmarshal(codec, data)
}
return codec.Unmarshal(data, v)
}

type PayloadCodec interface {
Codec
InputPayload() *InputPayload
OutputPayload() *OutputPayload
Codec
}

type InputPayload struct {
Required bool `json:"required"`
ContentType *string `json:"contentType,omitempty"`
Expand Down Expand Up @@ -58,87 +133,16 @@ func (j voidCodec) Marshal(output any) ([]byte, error) {
return nil, nil
}

type PairCodec struct {
Input PayloadCodec
Output PayloadCodec
}

func (w PairCodec) InputPayload() *InputPayload {
return w.Input.InputPayload()
}

func (w PairCodec) OutputPayload() *OutputPayload {
return w.Output.OutputPayload()
}

func (w PairCodec) Unmarshal(data []byte, v any) error {
return w.Input.Unmarshal(data, v)
}

func (w PairCodec) Marshal(v any) ([]byte, error) {
return w.Output.Marshal(v)
}

func MergeCodec(base, overlay PayloadCodec) PayloadCodec {
switch {
case base == nil && overlay == nil:
return nil
case base == nil:
return overlay
case overlay == nil:
return base
}

basePair, baseOk := base.(PairCodec)
overlayPair, overlayOk := overlay.(PairCodec)

switch {
case baseOk && overlayOk:
return PairCodec{
Input: MergeCodec(basePair.Input, overlayPair.Input),
Output: MergeCodec(basePair.Output, overlayPair.Output),
}
case baseOk:
return PairCodec{
Input: MergeCodec(basePair.Input, overlay),
Output: MergeCodec(basePair.Output, overlay),
}
case overlayOk:
return PairCodec{
Input: MergeCodec(base, overlayPair.Input),
Output: MergeCodec(base, overlayPair.Output),
}
default:
// just two non-pairs; keep base
return base
}
}

func PartialVoidCodec[I any, O any]() PayloadCodec {
var input I
var output O
_, inputVoid := any(input).(Void)
_, outputVoid := any(output).(Void)
switch {
case inputVoid && outputVoid:
return VoidCodec
case inputVoid:
return PairCodec{Input: VoidCodec, Output: nil}
case outputVoid:
return PairCodec{Input: nil, Output: VoidCodec}
default:
return nil
}
}

type binaryCodec struct{}

func (j binaryCodec) InputPayload() *InputPayload {
return &InputPayload{Required: true, ContentType: proto.String("application/octet-stream")}
// Required false because 0 bytes is a valid input
return &InputPayload{Required: false, ContentType: proto.String("application/octet-stream")}
}

func (j binaryCodec) OutputPayload() *OutputPayload {
return &OutputPayload{ContentType: proto.String("application/octet-stream")}
// SetContentTypeIfEmpty true because 0 bytes is a valid output
return &OutputPayload{ContentType: proto.String("application/octet-stream"), SetContentTypeIfEmpty: true}
}

func (j binaryCodec) Unmarshal(data []byte, input any) (err error) {
Expand Down Expand Up @@ -171,7 +175,7 @@ func (j jsonCodec) OutputPayload() *OutputPayload {
}

func (j jsonCodec) Unmarshal(data []byte, input any) (err error) {
return json.Unmarshal(data, &input)
return json.Unmarshal(data, input)
}

func (j jsonCodec) Marshal(output any) ([]byte, error) {
Expand Down
39 changes: 33 additions & 6 deletions encoding/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,40 +34,67 @@ func checkMessage(t *testing.T, msg *protocol.AwakeableEntryMessage) {
func TestProto(t *testing.T) {
p := ProtoCodec

_, err := p.Marshal(protocol.AwakeableEntryMessage{Name: "foobar"})
_, err := Marshal(p, protocol.AwakeableEntryMessage{Name: "foobar"})
if err == nil {
t.Fatalf("expected error when marshaling non-pointer proto Message")
}

bytes, err := p.Marshal(&protocol.AwakeableEntryMessage{Name: "foobar"})
bytes, err := Marshal(p, &protocol.AwakeableEntryMessage{Name: "foobar"})
if err != nil {
t.Fatal(err)
}

{
msg := &protocol.AwakeableEntryMessage{}
willSucceed(t, p.Unmarshal(bytes, msg))
willSucceed(t, Unmarshal(p, bytes, msg))
checkMessage(t, msg)
}

{
inner := &protocol.AwakeableEntryMessage{}
msg := &inner
willSucceed(t, p.Unmarshal(bytes, msg))
willSucceed(t, Unmarshal(p, bytes, msg))
checkMessage(t, *msg)
}

{
msg := new(*protocol.AwakeableEntryMessage)
willSucceed(t, p.Unmarshal(bytes, msg))
willSucceed(t, Unmarshal(p, bytes, msg))
checkMessage(t, *msg)
}

{
var msg *protocol.AwakeableEntryMessage
willPanic(t, func() {
p.Unmarshal(bytes, msg)
Unmarshal(p, bytes, msg)
})
}
}

func TestVoid(t *testing.T) {
codecs := map[string]Codec{
"json": JSONCodec,
"proto": ProtoCodec,
"binary": BinaryCodec,
}
for name, codec := range codecs {
t.Run(name, func(t *testing.T) {
bytes, err := Marshal(codec, Void{})
if err != nil {
t.Fatal(err)
}

if bytes != nil {
t.Fatalf("expected bytes to be nil, found %v", bytes)
}

if err := Unmarshal(codec, []byte{1, 2, 3}, &Void{}); err != nil {
t.Fatal(err)
}

if err := Unmarshal(codec, []byte{1, 2, 3}, Void{}); err != nil {
t.Fatal(err)
}
})
}
}
3 changes: 2 additions & 1 deletion example/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"errors"
"fmt"
"math/big"

Expand All @@ -24,7 +25,7 @@ var bigCounter = restate.
}

bytes, err := restate.GetAs[[]byte](ctx, "counter", restate.WithBinary)
if err != nil && err != restate.ErrKeyNotFound {
if err != nil && !errors.Is(err, restate.ErrKeyNotFound) {
return "", err
}
newCount := big.NewInt(0).Add(big.NewInt(0).SetBytes(bytes), delta)
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module github.com/restatedev/sdk-go

go 1.21.0

toolchain go1.21.12

require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
Expand Down
Loading

0 comments on commit 7000e54

Please sign in to comment.