diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..6c73ed2 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,21 @@ +name: Go +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: "1.22.x" + - name: Install dependencies + run: go get . + - name: Vet + run: go vet -v ./... + - name: Build + run: go build -v ./... + - name: Test with the Go CLI + run: go test -v ./... diff --git a/.gitignore b/.gitignore index 08cb523..e69de29 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +0,0 @@ -go.sum diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..2787314 --- /dev/null +++ b/go.sum @@ -0,0 +1,34 @@ +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 h1:zZg03nifrj6ayWNaDO8tNj57tqrOIKDwiBaLkhtK7Kk= +github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0/go.mod h1:bblJa8QcHntareAJYfLJUzLj42sUFBKCBeTDK5LyUrw= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= +github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= +google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handler.go b/handler.go index 96c0e88..2d55b9d 100644 --- a/handler.go +++ b/handler.go @@ -9,8 +9,9 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) -// Empty struct used as a quick shortcut when you don't care about -// the input or output type +// Void is a placeholder used usually for functions that their signature require that +// you accept an input or return an output but the function implementation does not +// require them type Void struct{} func (v Void) MarshalJSON() ([]byte, error) { diff --git a/internal/state/call.go b/internal/state/call.go index 0922135..8e9eacb 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -1,6 +1,7 @@ package state import ( + "bytes" "encoding/json" "fmt" "time" @@ -69,18 +70,39 @@ func (c *Machine) makeRequest(key string, body any) ([]byte, error) { } func (c *Machine) doCall(service, method, key string, body any) ([]byte, error) { - c.mutex.Lock() - defer c.mutex.Unlock() - - input, err := c.makeRequest(key, body) + params, err := c.makeRequest(key, body) if err != nil { return nil, err } - err = c.protocol.Write(&protocol.InvokeEntryMessage{ + return replayOrNew( + c, + wire.InvokeEntryMessageType, + func(entry *wire.InvokeEntryMessage) ([]byte, error) { + if entry.Payload.ServiceName != service || + entry.Payload.MethodName != method || + !bytes.Equal(entry.Payload.Parameter, params) { + return nil, errEntryMismatch + } + + switch result := entry.Payload.Result.(type) { + case *protocol.InvokeEntryMessage_Failure: + return nil, fmt.Errorf("[%d] %s", result.Failure.Code, result.Failure.Message) + case *protocol.InvokeEntryMessage_Value: + return result.Value, nil + } + + return nil, errUnreachable + }, func() ([]byte, error) { + return c._doCall(service, method, params) + }) +} + +func (c *Machine) _doCall(service, method string, params []byte) ([]byte, error) { + err := c.protocol.Write(&protocol.InvokeEntryMessage{ ServiceName: service, MethodName: method, - Parameter: input, + Parameter: params, }) if err != nil { @@ -121,23 +143,41 @@ func (c *Machine) doCall(service, method, key string, body any) ([]byte, error) } func (c *Machine) sendCall(service, method, key string, body any, delay time.Duration) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - input, err := c.makeRequest(key, body) + params, err := c.makeRequest(key, body) if err != nil { return err } + _, err = replayOrNew( + c, + wire.BackgroundInvokeEntryMessageType, + func(entry *wire.BackgroundInvokeEntryMessage) (restate.Void, error) { + if entry.Payload.ServiceName != service || + entry.Payload.MethodName != method || + !bytes.Equal(entry.Payload.Parameter, params) { + return restate.Void{}, errEntryMismatch + } + + return restate.Void{}, nil + }, + func() (restate.Void, error) { + return restate.Void{}, c._sendCall(service, method, params, delay) + }, + ) + + return err +} + +func (c *Machine) _sendCall(service, method string, params []byte, delay time.Duration) error { var invokeTime uint64 if delay != 0 { invokeTime = uint64(time.Now().Add(delay).UnixMilli()) } - err = c.protocol.Write(&protocol.BackgroundInvokeEntryMessage{ + err := c.protocol.Write(&protocol.BackgroundInvokeEntryMessage{ ServiceName: service, MethodName: method, - Parameter: input, + Parameter: params, InvokeTime: invokeTime, }) diff --git a/internal/state/state.go b/internal/state/state.go index 8a9e24c..08abd51 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -25,8 +25,6 @@ const ( var ( ErrUnexpectedMessage = fmt.Errorf("unexpected message") ErrInvalidVersion = fmt.Errorf("invalid version number") - - errUnreachable = fmt.Errorf("unreachable") ) var ( @@ -61,6 +59,10 @@ func (c *Context) Get(key string) ([]byte, error) { return c.machine.get(key) } +func (c *Context) Keys() ([]string, error) { + return c.machine.keys() +} + func (c *Context) Sleep(until time.Time) error { return c.machine.sleep(until) } @@ -98,7 +100,8 @@ type Machine struct { partial bool current map[string][]byte - entries []wire.Message + entries []wire.Message + entryIndex int } func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine { @@ -109,11 +112,33 @@ func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine { } } -func (m *Machine) output(r *dynrpc.RpcResponse, err error) proto.Message { - // TODO: if err is terminal return outputStreamEntryMessage but if error is - // not terminal, return ErrorMessage instead. - //var output protocol.OutputStreamEntryMessage +// Start starts the state machine +func (m *Machine) Start(inner context.Context) error { + // reader starts a rea + msg, err := m.protocol.Read() + if err != nil { + return err + } + + if msg.Type() != wire.StartMessageType { + // invalid negotiation + return ErrUnexpectedMessage + } + + start := msg.(*wire.StartMessage) + if start.Version != Version { + return ErrInvalidVersion + } + + ctx := newContext(inner, m) + + log.Trace().Str("id", base64.URLEncoding.EncodeToString(m.id)).Msg("start invocation") + return m.process(ctx, start) +} + +// handle handler response and build proper response message +func (m *Machine) output(r *dynrpc.RpcResponse, err error) proto.Message { if err != nil && restate.IsTerminalError(err) { // terminal errors. return &protocol.OutputStreamEntryMessage{ @@ -195,14 +220,18 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { return ErrUnexpectedMessage } - log.Debug().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires") + log.Trace().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires") + m.entries = make([]wire.Message, 0, start.Payload.KnownEntries-1) + + // we don't track the poll input entry for i := uint32(1); i < start.Payload.KnownEntries; i++ { msg, err := m.protocol.Read() if err != nil { return fmt.Errorf("failed to read entry: %w", err) } - log.Debug().Uint16("type", uint16(msg.Type())).Msg("received entry") + log.Trace().Uint16("type", uint16(msg.Type())).Msg("replay log entry") + m.entries = append(m.entries, msg) } inputMsg := msg.(*wire.PollInputEntry) @@ -216,26 +245,57 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { } -func (m *Machine) Start(inner context.Context) error { - // reader starts a rea - msg, err := m.protocol.Read() - if err != nil { - return err +func (c *Machine) currentEntry() (wire.Message, bool) { + if c.entryIndex < len(c.entries) { + return c.entries[c.entryIndex], true } - if msg.Type() != wire.StartMessageType { - // invalid negotiation - return ErrUnexpectedMessage - } + return nil, false +} - start := msg.(*wire.StartMessage) +// replayOrNew is a utility function to easily either +// replay a log entry, or create a new one if one +// does not exist +// +// this should be an instance method on Machine but unfortunately +// go does not support generics on instance methods +// +// the idea is when called, it will check if there is a log +// entry at current index, then compare the entry message type +// if not matching, that's obviously an error with the code version +// (code has changed and now doesn't match the play log) +// +// if type is okay, the function will then call a `replay“ callback. +// the replay callback just need to extract the result from the entry +// +// otherwise this function will call a `new` callback to create a new entry in the log +// by sending the proper runtime messages +func replayOrNew[M wire.Message, O any]( + m *Machine, + typ wire.Type, + replay func(msg M) (O, error), + new func() (O, error), +) (output O, err error) { + + m.mutex.Lock() + defer m.mutex.Unlock() - if start.Version != Version { - return ErrInvalidVersion - } + defer func() { + m.entryIndex += 1 + }() - ctx := newContext(inner, m) + // check if there is an entry as this index + entry, ok := m.currentEntry() - log.Trace().Str("id", base64.URLEncoding.EncodeToString(m.id)).Msg("start invocation") - return m.process(ctx, start) + // if entry exists, we need to replay it + // by calling the replay function + if ok { + if entry.Type() != typ { + return output, errEntryMismatch + } + return replay(entry.(M)) + } + + // other wise call the new function + return new() } diff --git a/internal/state/sys.go b/internal/state/sys.go index 897bb42..e64543c 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -1,17 +1,40 @@ package state import ( + "bytes" "fmt" "time" + "github.com/muhamadazmy/restate-sdk-go" "github.com/muhamadazmy/restate-sdk-go/generated/proto/protocol" "github.com/muhamadazmy/restate-sdk-go/internal/wire" + "github.com/rs/zerolog/log" + "google.golang.org/protobuf/proto" +) + +var ( + errEntryMismatch = restate.WithErrorCode(fmt.Errorf("log entry mismatch"), 32) + errUnreachable = fmt.Errorf("unreachable") ) func (c *Machine) set(key string, value []byte) error { - c.mutex.Lock() - defer c.mutex.Unlock() + _, err := replayOrNew( + c, + wire.SetStateEntryMessageType, + func(entry *wire.SetStateEntryMessage) (void restate.Void, err error) { + if string(entry.Payload.Key) != key || !bytes.Equal(entry.Payload.Value, value) { + return void, errEntryMismatch + } + + return + }, func() (void restate.Void, err error) { + return void, c._set(key, value) + }) + return err +} + +func (c *Machine) _set(key string, value []byte) error { c.current[key] = value return c.protocol.Write( @@ -22,9 +45,24 @@ func (c *Machine) set(key string, value []byte) error { } func (c *Machine) clear(key string) error { - c.mutex.Lock() - defer c.mutex.Unlock() + _, err := replayOrNew( + c, + wire.ClearStateEntryMessageType, + func(entry *wire.ClearStateEntryMessage) (void restate.Void, err error) { + if string(entry.Payload.Key) != key { + return void, errEntryMismatch + } + + return void, nil + }, func() (restate.Void, error) { + return restate.Void{}, c._clear(key) + }, + ) + + return err +} +func (c *Machine) _clear(key string) error { return c.protocol.Write( &protocol.ClearStateEntryMessage{ Key: []byte(key), @@ -32,24 +70,56 @@ func (c *Machine) clear(key string) error { ) } -// clearAll drops all associated keys func (c *Machine) clearAll() error { - c.mutex.Lock() - defer c.mutex.Unlock() + _, err := replayOrNew( + c, + wire.ClearAllStateEntryMessageType, + func(entry *wire.ClearAllStateEntryMessage) (void restate.Void, err error) { + return + }, func() (restate.Void, error) { + return restate.Void{}, c._clearAll() + }, + ) + + return err +} +// clearAll drops all associated keys +func (c *Machine) _clearAll() error { return c.protocol.Write( &protocol.ClearAllStateEntryMessage{}, ) } func (c *Machine) get(key string) ([]byte, error) { + return replayOrNew( + c, + wire.GetStateEntryMessageType, + func(entry *wire.GetStateEntryMessage) ([]byte, error) { + if string(entry.Payload.Key) != key { + return nil, errEntryMismatch + } + + switch result := entry.Payload.Result.(type) { + case *protocol.GetStateEntryMessage_Empty: + return nil, nil + case *protocol.GetStateEntryMessage_Failure: + return nil, fmt.Errorf("[%d] %s", result.Failure.Code, result.Failure.Message) + case *protocol.GetStateEntryMessage_Value: + return result.Value, nil + } + + return nil, fmt.Errorf("unreachable") + }, func() ([]byte, error) { + return c._get(key) + }) +} + +func (c *Machine) _get(key string) ([]byte, error) { msg := &protocol.GetStateEntryMessage{ Key: []byte(key), } - c.mutex.Lock() - defer c.mutex.Unlock() - value, ok := c.current[key] if ok { @@ -110,7 +180,86 @@ func (c *Machine) get(key string) ([]byte, error) { return nil, fmt.Errorf("unreachable") } +func (c *Machine) keys() ([]string, error) { + return replayOrNew( + c, + wire.GetStateKeysEntryMessageType, + func(entry *wire.GetStateKeysEntryMessage) ([]string, error) { + switch result := entry.Payload.Result.(type) { + case *protocol.GetStateKeysEntryMessage_Failure: + return nil, fmt.Errorf("[%d] %s", result.Failure.Code, result.Failure.Message) + case *protocol.GetStateKeysEntryMessage_Value: + keys := make([]string, 0, len(result.Value.Keys)) + for _, key := range result.Value.Keys { + keys = append(keys, string(key)) + } + return keys, nil + } + + return nil, errUnreachable + }, + c._keys, + ) +} + +func (c *Machine) _keys() ([]string, error) { + if err := c.protocol.Write(&protocol.GetStateKeysEntryMessage{}); err != nil { + return nil, err + } + + msg, err := c.protocol.Read() + if err != nil { + return nil, err + } + + if msg.Type() != wire.CompletionMessageType { + log.Error().Stringer("type", msg.Type()).Msg("receiving message of type") + return nil, ErrUnexpectedMessage + } + + response := msg.(*wire.CompletionMessage) + + switch value := response.Payload.Result.(type) { + case *protocol.CompletionMessage_Empty: + return nil, nil + case *protocol.CompletionMessage_Failure: + // the get state entry message is not failable so this should + // never happen + return nil, fmt.Errorf("[%d] %s", value.Failure.Code, value.Failure.Message) + case *protocol.CompletionMessage_Value: + var keys protocol.GetStateKeysEntryMessage_StateKeys + + if err := proto.Unmarshal(value.Value, &keys); err != nil { + return nil, err + } + + values := make([]string, 0, len(keys.Keys)) + for _, key := range keys.Keys { + values = append(values, string(key)) + } + + return values, nil + } + + return nil, nil +} + func (c *Machine) sleep(until time.Time) error { + _, err := replayOrNew( + c, + wire.SleepEntryMessageType, + func(entry *wire.SleepEntryMessage) (void restate.Void, err error) { + // we shouldn't verify the time because this would be different every time + return + }, func() (restate.Void, error) { + return restate.Void{}, c._sleep(until) + }, + ) + + return err +} + +func (c *Machine) _sleep(until time.Time) error { if err := c.protocol.Write(&protocol.SleepEntryMessage{ WakeUpTime: uint64(until.UnixMilli()), }); err != nil { diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 584a164..408ceed 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -35,8 +35,9 @@ const ( // State GetStateEntryMessageType Type = 0x0800 SetStateEntryMessageType Type = 0x0800 + 1 - ClearStateEntryMessageType = 0x0800 + 2 + ClearStateEntryMessageType Type = 0x0800 + 2 ClearAllStateEntryMessageType Type = 0x0800 + 3 + GetStateKeysEntryMessageType Type = 0x0800 + 4 //SysCalls SleepEntryMessageType Type = 0x0C00 @@ -187,6 +188,8 @@ func (s *Protocol) Write(message proto.Message, flags ...Flag) error { typ = InvokeEntryMessageType case *protocol.BackgroundInvokeEntryMessage: typ = BackgroundInvokeEntryMessageType + case *protocol.GetStateKeysEntryMessage: + typ = GetStateKeysEntryMessageType default: return fmt.Errorf("can not send message of unknown message type") } @@ -253,11 +256,60 @@ var ( return msg, proto.Unmarshal(bytes, &msg.Payload) }, + SetStateEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &SetStateEntryMessage{ + Header: header, + } + + return msg, proto.Unmarshal(bytes, &msg.Payload) + }, + ClearStateEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &ClearStateEntryMessage{ + Header: header, + } + + return msg, proto.Unmarshal(bytes, &msg.Payload) + }, + ClearAllStateEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &ClearAllStateEntryMessage{ + Header: header, + } + + return msg, proto.Unmarshal(bytes, &msg.Payload) + }, + GetStateKeysEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &GetStateKeysEntryMessage{ + Header: header, + } + + return msg, proto.Unmarshal(bytes, &msg.Payload) + }, CompletionMessageType: func(header Header, bytes []byte) (Message, error) { msg := &CompletionMessage{ Header: header, } + return msg, proto.Unmarshal(bytes, &msg.Payload) + }, + SleepEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &SleepEntryMessage{ + Header: header, + } + + return msg, proto.Unmarshal(bytes, &msg.Payload) + }, + InvokeEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &InvokeEntryMessage{ + Header: header, + } + + return msg, proto.Unmarshal(bytes, &msg.Payload) + }, + BackgroundInvokeEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &BackgroundInvokeEntryMessage{ + Header: header, + } + return msg, proto.Unmarshal(bytes, &msg.Payload) }, } @@ -284,7 +336,42 @@ type GetStateEntryMessage struct { Payload protocol.GetStateEntryMessage } +type SetStateEntryMessage struct { + Header + Payload protocol.SetStateEntryMessage +} + +type ClearStateEntryMessage struct { + Header + Payload protocol.ClearStateEntryMessage +} + +type ClearAllStateEntryMessage struct { + Header + Payload protocol.ClearAllStateEntryMessage +} + +type GetStateKeysEntryMessage struct { + Header + Payload protocol.GetStateKeysEntryMessage +} + type CompletionMessage struct { Header Payload protocol.CompletionMessage } + +type SleepEntryMessage struct { + Header + Payload protocol.SleepEntryMessage +} + +type InvokeEntryMessage struct { + Header + Payload protocol.InvokeEntryMessage +} + +type BackgroundInvokeEntryMessage struct { + Header + Payload protocol.BackgroundInvokeEntryMessage +} diff --git a/router.go b/router.go index 3d34908..25f4df7 100644 --- a/router.go +++ b/router.go @@ -30,6 +30,8 @@ type Context interface { // ClearAll drops all stored state associated with key ClearAll() error + Keys() ([]string, error) + Sleep(until time.Time) error Service(service string) Service diff --git a/server/restate.go b/server/restate.go index 819312c..405beaa 100644 --- a/server/restate.go +++ b/server/restate.go @@ -70,7 +70,7 @@ func (r *Restate) discover() (resource *discovery.ServiceDiscoveryResponse, err } func (r *Restate) discoverHandler(writer http.ResponseWriter, _ *http.Request) { - log.Debug().Msg("discover called") + log.Trace().Msg("discover called") writer.Header().Add("Content-Type", "application/proto") response, err := r.discover() @@ -97,7 +97,7 @@ func (r *Restate) discoverHandler(writer http.ResponseWriter, _ *http.Request) { // takes care of function call func (r *Restate) callHandler(service, fn string, writer http.ResponseWriter, request *http.Request) { - log.Debug().Str("service", service).Str("handler", fn).Msg("got a call to service function") + log.Trace().Str("service", service).Str("handler", fn).Msg("new invocation") writer.Header().Add("content-type", "application/restate") @@ -128,7 +128,7 @@ func (r *Restate) callHandler(service, fn string, writer http.ResponseWriter, re log.Error().Err(err).Msg("failed to handle invocation") } - log.Info().Msg("invocation ended") + log.Trace().Msg("invocation ended") } func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) { diff --git a/test/main.go b/test/main.go index f93a1da..140faab 100644 --- a/test/main.go +++ b/test/main.go @@ -16,18 +16,35 @@ type J = map[string]interface{} type Tickets struct{} -func (t *Tickets) Reserve(ctx restate.Context, id string, _ restate.Void) (restate.Void, error) { - if err := ctx.Set("reserved", []byte{1}); err != nil { - return restate.Void{}, err +func (t *Tickets) Reserve(ctx restate.Context, id string, _ restate.Void) (string, error) { + + fmt.Println(ctx.Keys()) + + count, err := ctx.Get("reserved") + if err != nil { + return "", err } - if err := ctx.Service("Tickets").Method("UnReserve").Send(id, nil, 0); err != nil { - return restate.Void{}, err + if len(count) == 0 { + count = make([]byte, 1) + } + count[0] += 1 + if err := ctx.Set("reserved", count); err != nil { + return "", err } - // i wanna return a non terminal error - //return restate.Void{}, fmt.Errorf("not terminal error") - return restate.Void{}, nil + if err := ctx.Set("another key", []byte{}); err != nil { + return "", err + } + + if err := ctx.Service("Tickets").Method("UnReserve").Send(id, nil, 30*time.Second); err != nil { + return "", fmt.Errorf("failed to schedule 'unreserve': %w", err) + } + + //return "", fmt.Errorf("something went wrong") + // // i wanna return a non terminal error + // //return restate.Void{}, fmt.Errorf("not terminal error") + return fmt.Sprint(count[0]), nil } func (t *Tickets) UnReserve(ctx restate.Context, id string, _ restate.Void) (restate.Void, error) { @@ -46,8 +63,6 @@ func Echo(ctx restate.Context, name string) (string, error) { return "", err } - panic("failed intentionally after a call") - return fmt.Sprintf("echo: %s", string(response)), nil }