diff --git a/example/ticket_service.go b/example/ticket_service.go index 3b99a29..495a83d 100644 --- a/example/ticket_service.go +++ b/example/ticket_service.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/muhamadazmy/restate-sdk-go" + "github.com/rs/zerolog/log" ) type TicketStatus int @@ -14,7 +15,8 @@ const ( TicketSold TicketStatus = 2 ) -func reserve(ctx restate.Context, _ string, _ restate.Void) (bool, error) { +func reserve(ctx restate.Context, ticketId string, _ restate.Void) (bool, error) { + log.Info().Str("ticket", ticketId).Msg("reserving ticket") status, err := restate.GetAs[TicketStatus](ctx, "status") if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return false, err @@ -27,7 +29,8 @@ func reserve(ctx restate.Context, _ string, _ restate.Void) (bool, error) { return false, nil } -func unreserve(ctx restate.Context, _ string, _ restate.Void) (void restate.Void, err error) { +func unreserve(ctx restate.Context, ticketId string, _ restate.Void) (void restate.Void, err error) { + log.Info().Str("ticket", ticketId).Msg("un-reserving ticket") status, err := restate.GetAs[TicketStatus](ctx, "status") if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return void, err @@ -40,7 +43,9 @@ func unreserve(ctx restate.Context, _ string, _ restate.Void) (void restate.Void return void, nil } -func markAsSold(ctx restate.Context, _ string, _ restate.Void) (void restate.Void, err error) { +func markAsSold(ctx restate.Context, ticketId string, _ restate.Void) (void restate.Void, err error) { + log.Info().Str("ticket", ticketId).Msg("mark ticket as sold") + status, err := restate.GetAs[TicketStatus](ctx, "status") if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return void, err diff --git a/example/user_session.go b/example/user_session.go index 8368f31..b0bad8d 100644 --- a/example/user_session.go +++ b/example/user_session.go @@ -22,7 +22,7 @@ func addTicket(ctx restate.Context, userId, ticketId string) (bool, error) { // add ticket to list of tickets tickets, err := restate.GetAs[[]string](ctx, "tickets") - if err != nil && errors.Is(err, restate.ErrKeyNotFound) { + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return false, err } diff --git a/internal/state/call.go b/internal/state/call.go index c452980..47a79db 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -50,7 +50,7 @@ func (c *serviceCall) Send(key string, body any, delay time.Duration) error { return c.machine.sendCall(c.service, c.method, key, body, delay) } -func (c *Machine) makeRequest(key string, body any) ([]byte, error) { +func (m *Machine) makeRequest(key string, body any) ([]byte, error) { input, err := json.Marshal(body) if err != nil { @@ -69,14 +69,16 @@ func (c *Machine) makeRequest(key string, body any) ([]byte, error) { return proto.Marshal(params) } -func (c *Machine) doCall(service, method, key string, input, output any) error { - params, err := c.makeRequest(key, input) +func (m *Machine) doCall(service, method, key string, input, output any) error { + m.log.Debug().Str("service", service).Str("method", method).Msg("in do call") + + params, err := m.makeRequest(key, input) if err != nil { return err } bytes, err := replayOrNew( - c, + m, wire.InvokeEntryMessageType, func(entry *wire.InvokeEntryMessage) ([]byte, error) { if entry.Payload.ServiceName != service || @@ -89,12 +91,17 @@ func (c *Machine) doCall(service, method, key string, input, output any) error { case *protocol.InvokeEntryMessage_Failure: return nil, fmt.Errorf("[%d] %s", result.Failure.Code, result.Failure.Message) case *protocol.InvokeEntryMessage_Value: - return result.Value, nil + var rpcResponse dynrpc.RpcResponse + if err := proto.Unmarshal(result.Value, &rpcResponse); err != nil { + return nil, fmt.Errorf("failed to decode rpc response: %w", err) + } + + return rpcResponse.Response.MarshalJSON() } return nil, errUnreachable }, func() ([]byte, error) { - return c._doCall(service, method, params) + return m._doCall(service, method, params) }) if err != nil { @@ -106,14 +113,14 @@ func (c *Machine) doCall(service, method, key string, input, output any) error { } if err := json.Unmarshal(bytes, output); err != nil { - return restate.TerminalError(fmt.Errorf("failed to decode response: %w", err)) + return restate.TerminalError(fmt.Errorf("failed to decode response (%s): %w", string(bytes), err)) } return nil } -func (c *Machine) _doCall(service, method string, params []byte) ([]byte, error) { - err := c.protocol.Write(&protocol.InvokeEntryMessage{ +func (m *Machine) _doCall(service, method string, params []byte) ([]byte, error) { + err := m.protocol.Write(&protocol.InvokeEntryMessage{ ServiceName: service, MethodName: method, Parameter: params, @@ -123,7 +130,7 @@ func (c *Machine) _doCall(service, method string, params []byte) ([]byte, error) return nil, fmt.Errorf("failed to send request message: %w", err) } - response, err := c.protocol.Read() + response, err := m.protocol.Read() if err != nil { return nil, fmt.Errorf("failed to read response message: %w", err) } @@ -150,7 +157,7 @@ func (c *Machine) _doCall(service, method string, params []byte) ([]byte, error) var rpcResponse dynrpc.RpcResponse if err := proto.Unmarshal(output, &rpcResponse); err != nil { - return nil, fmt.Errorf("failed to decode rpc response: %w", err) + return nil, fmt.Errorf("failed to decode rpc response(%s,%s): %w", service, method, err) } return rpcResponse.Response.MarshalJSON() diff --git a/internal/state/state.go b/internal/state/state.go index fcff304..b6c3afb 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -2,7 +2,6 @@ package state import ( "context" - "encoding/base64" "fmt" "io" "runtime/debug" @@ -15,6 +14,7 @@ import ( "github.com/muhamadazmy/restate-sdk-go/generated/proto/protocol" "github.com/muhamadazmy/restate-sdk-go/internal/wire" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "google.golang.org/protobuf/proto" ) @@ -109,13 +109,15 @@ type Machine struct { mutex sync.Mutex // state - id []byte + id string partial bool current map[string][]byte entries []wire.Message entryIndex int + + log zerolog.Logger } func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine { @@ -127,7 +129,7 @@ func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine { } // Start starts the state machine -func (m *Machine) Start(inner context.Context) error { +func (m *Machine) Start(inner context.Context, trace string) error { // reader starts a rea msg, err := m.protocol.Read() if err != nil { @@ -141,18 +143,27 @@ func (m *Machine) Start(inner context.Context) error { start := msg.(*wire.StartMessage) + m.id = start.Payload.DebugId if start.Version != Version { return ErrInvalidVersion } + m.log = log.With().Str("id", start.Payload.DebugId).Str("method", trace).Logger() + ctx := newContext(inner, m) - log.Trace().Str("id", base64.URLEncoding.EncodeToString(m.id)).Msg("start invocation") + m.log.Debug().Msg("start invocation") + defer m.log.Debug().Msg("invocation ended") + return m.process(ctx, start) } // handle handler response and build proper response message func (m *Machine) output(r *dynrpc.RpcResponse, err error) proto.Message { + if err != nil { + m.log.Error().Err(err).Msg("failure") + } + if err != nil && restate.IsTerminalError(err) { // terminal errors. return &protocol.OutputStreamEntryMessage{ @@ -207,7 +218,7 @@ func (m *Machine) invoke(ctx *Context, input *dynrpc.RpcRequest) error { }) if wErr != nil { - log.Error().Err(wErr).Msg("error sending failure message") + m.log.Error().Err(wErr).Msg("error sending failure message") } } }() @@ -218,8 +229,6 @@ func (m *Machine) invoke(ctx *Context, input *dynrpc.RpcRequest) error { } func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { - m.id = start.Payload.Id - for _, entry := range start.Payload.StateMap { m.current[string(entry.Key)] = entry.Value } @@ -234,7 +243,7 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { return ErrUnexpectedMessage } - log.Trace().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires") + m.log.Trace().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires") m.entries = make([]wire.Message, 0, start.Payload.KnownEntries-1) // we don't track the poll input entry @@ -244,7 +253,7 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { return fmt.Errorf("failed to read entry: %w", err) } - log.Trace().Uint16("type", uint16(msg.Type())).Msg("replay log entry") + m.log.Trace().Uint16("type", uint16(msg.Type())).Msg("replay log entry") m.entries = append(m.entries, msg) } diff --git a/internal/state/sys.go b/internal/state/sys.go index ce1831c..c6f9387 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -10,7 +10,6 @@ import ( "github.com/muhamadazmy/restate-sdk-go/generated/proto/javascript" "github.com/muhamadazmy/restate-sdk-go/generated/proto/protocol" "github.com/muhamadazmy/restate-sdk-go/internal/wire" - "github.com/rs/zerolog/log" "google.golang.org/protobuf/proto" ) @@ -19,9 +18,9 @@ var ( errUnreachable = fmt.Errorf("unreachable") ) -func (c *Machine) set(key string, value []byte) error { +func (m *Machine) set(key string, value []byte) error { _, err := replayOrNew( - c, + m, wire.SetStateEntryMessageType, func(entry *wire.SetStateEntryMessage) (void restate.Void, err error) { if string(entry.Payload.Key) != key || !bytes.Equal(entry.Payload.Value, value) { @@ -30,25 +29,25 @@ func (c *Machine) set(key string, value []byte) error { return }, func() (void restate.Void, err error) { - return void, c._set(key, value) + return void, m._set(key, value) }) return err } -func (c *Machine) _set(key string, value []byte) error { - c.current[key] = value +func (m *Machine) _set(key string, value []byte) error { + m.current[key] = value - return c.protocol.Write( + return m.protocol.Write( &protocol.SetStateEntryMessage{ Key: []byte(key), Value: value, }) } -func (c *Machine) clear(key string) error { +func (m *Machine) clear(key string) error { _, err := replayOrNew( - c, + m, wire.ClearStateEntryMessageType, func(entry *wire.ClearStateEntryMessage) (void restate.Void, err error) { if string(entry.Payload.Key) != key { @@ -57,45 +56,62 @@ func (c *Machine) clear(key string) error { return void, nil }, func() (restate.Void, error) { - return restate.Void{}, c._clear(key) + return restate.Void{}, m._clear(key) }, ) + if err != nil { + return err + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + delete(m.current, key) + return err } -func (c *Machine) _clear(key string) error { - return c.protocol.Write( +func (m *Machine) _clear(key string) error { + return m.protocol.Write( &protocol.ClearStateEntryMessage{ Key: []byte(key), }, ) } -func (c *Machine) clearAll() error { +func (m *Machine) clearAll() error { _, err := replayOrNew( - c, + m, wire.ClearAllStateEntryMessageType, func(entry *wire.ClearAllStateEntryMessage) (void restate.Void, err error) { return }, func() (restate.Void, error) { - return restate.Void{}, c._clearAll() + return restate.Void{}, m._clearAll() }, ) + if err != nil { + return err + } - return err + m.mutex.Lock() + defer m.mutex.Unlock() + m.current = map[string][]byte{} + m.partial = false + + return nil } // clearAll drops all associated keys -func (c *Machine) _clearAll() error { - return c.protocol.Write( +func (m *Machine) _clearAll() error { + return m.protocol.Write( &protocol.ClearAllStateEntryMessage{}, ) } -func (c *Machine) get(key string) ([]byte, error) { +func (m *Machine) get(key string) ([]byte, error) { return replayOrNew( - c, + m, wire.GetStateEntryMessageType, func(entry *wire.GetStateEntryMessage) ([]byte, error) { if string(entry.Payload.Key) != key { @@ -113,16 +129,16 @@ func (c *Machine) get(key string) ([]byte, error) { return nil, fmt.Errorf("unreachable") }, func() ([]byte, error) { - return c._get(key) + return m._get(key) }) } -func (c *Machine) _get(key string) ([]byte, error) { +func (m *Machine) _get(key string) ([]byte, error) { msg := &protocol.GetStateEntryMessage{ Key: []byte(key), } - value, ok := c.current[key] + value, ok := m.current[key] if ok { // value in map, we still send the current @@ -131,32 +147,43 @@ func (c *Machine) _get(key string) ([]byte, error) { Value: value, } - if err := c.protocol.Write(msg); err != nil { + if err := m.protocol.Write(msg); err != nil { return nil, err } + // read and discard response + _, err := m.protocol.Read() + if err != nil { + return value, err + } return value, nil } // key is not in map! there are 2 cases. - if !c.partial { + if !m.partial { // current is complete. we need to return nil to the user // but also send an empty get state entry message msg.Result = &protocol.GetStateEntryMessage_Empty{} - if err := c.protocol.Write(msg); err != nil { + if err := m.protocol.Write(msg); err != nil { return nil, err } + // read and discard response + _, err := m.protocol.Read() + if err != nil { + return value, err + } + return nil, nil } - if err := c.protocol.Write(msg); err != nil { + if err := m.protocol.Write(msg); err != nil { return nil, err } // wait for completion - response, err := c.protocol.Read() + response, err := m.protocol.Read() if err != nil { return nil, err } @@ -175,16 +202,16 @@ func (c *Machine) _get(key string) ([]byte, error) { // never happen return nil, fmt.Errorf("[%d] %s", value.Failure.Code, value.Failure.Message) case *protocol.CompletionMessage_Value: - c.current[key] = value.Value + m.current[key] = value.Value return value.Value, nil } return nil, fmt.Errorf("unreachable") } -func (c *Machine) keys() ([]string, error) { +func (m *Machine) keys() ([]string, error) { return replayOrNew( - c, + m, wire.GetStateKeysEntryMessageType, func(entry *wire.GetStateKeysEntryMessage) ([]string, error) { switch result := entry.Payload.Result.(type) { @@ -200,22 +227,22 @@ func (c *Machine) keys() ([]string, error) { return nil, errUnreachable }, - c._keys, + m._keys, ) } -func (c *Machine) _keys() ([]string, error) { - if err := c.protocol.Write(&protocol.GetStateKeysEntryMessage{}); err != nil { +func (m *Machine) _keys() ([]string, error) { + if err := m.protocol.Write(&protocol.GetStateKeysEntryMessage{}); err != nil { return nil, err } - msg, err := c.protocol.Read() + msg, err := m.protocol.Read() if err != nil { return nil, err } if msg.Type() != wire.CompletionMessageType { - log.Error().Stringer("type", msg.Type()).Msg("receiving message of type") + m.log.Error().Stringer("type", msg.Type()).Msg("receiving message of type") return nil, ErrUnexpectedMessage } @@ -246,29 +273,29 @@ func (c *Machine) _keys() ([]string, error) { return nil, nil } -func (c *Machine) sleep(until time.Time) error { +func (m *Machine) sleep(until time.Time) error { _, err := replayOrNew( - c, + m, wire.SleepEntryMessageType, func(entry *wire.SleepEntryMessage) (void restate.Void, err error) { // we shouldn't verify the time because this would be different every time return }, func() (restate.Void, error) { - return restate.Void{}, c._sleep(until) + return restate.Void{}, m._sleep(until) }, ) return err } -func (c *Machine) _sleep(until time.Time) error { - if err := c.protocol.Write(&protocol.SleepEntryMessage{ +func (m *Machine) _sleep(until time.Time) error { + if err := m.protocol.Write(&protocol.SleepEntryMessage{ WakeUpTime: uint64(until.UnixMilli()), }); err != nil { return err } - response, err := c.protocol.Read() + response, err := m.protocol.Read() if err != nil { return err } @@ -280,9 +307,9 @@ func (c *Machine) _sleep(until time.Time) error { return nil } -func (c *Machine) sideEffect(fn func() ([]byte, error), bo backoff.BackOff) ([]byte, error) { +func (m *Machine) sideEffect(fn func() ([]byte, error), bo backoff.BackOff) ([]byte, error) { return replayOrNew( - c, + m, wire.SideEffectEntryMessageType, func(entry *wire.SideEffectEntryMessage) ([]byte, error) { switch result := entry.Payload.Result.(type) { @@ -299,12 +326,12 @@ func (c *Machine) sideEffect(fn func() ([]byte, error), bo backoff.BackOff) ([]b return nil, errUnreachable }, func() ([]byte, error) { - return c._sideEffect(fn, bo) + return m._sideEffect(fn, bo) }, ) } -func (c *Machine) _sideEffect(fn func() ([]byte, error), bo backoff.BackOff) ([]byte, error) { +func (m *Machine) _sideEffect(fn func() ([]byte, error), bo backoff.BackOff) ([]byte, error) { var bytes []byte err := backoff.Retry(func() error { var err error @@ -336,7 +363,7 @@ func (c *Machine) _sideEffect(fn func() ([]byte, error), bo backoff.BackOff) ([] } } - if err := c.protocol.Write(&msg); err != nil { + if err := m.protocol.Write(&msg); err != nil { return nil, err } diff --git a/router.go b/router.go index da71610..f78b544 100644 --- a/router.go +++ b/router.go @@ -136,7 +136,6 @@ func (r *KeyedRouter) Handlers() map[string]Handler { // returned // it does encoding/decoding of bytes automatically using msgpack func GetAs[T any](ctx Context, key string) (output T, err error) { - bytes, err := ctx.Get(key) if err != nil { return output, err diff --git a/server/restate.go b/server/restate.go index 405beaa..2263fea 100644 --- a/server/restate.go +++ b/server/restate.go @@ -97,7 +97,7 @@ func (r *Restate) discoverHandler(writer http.ResponseWriter, _ *http.Request) { // takes care of function call func (r *Restate) callHandler(service, fn string, writer http.ResponseWriter, request *http.Request) { - log.Trace().Str("service", service).Str("handler", fn).Msg("new invocation") + log.Debug().Str("service", service).Str("handler", fn).Msg("new request") writer.Header().Add("content-type", "application/restate") @@ -124,15 +124,13 @@ func (r *Restate) callHandler(service, fn string, writer http.ResponseWriter, re machine := state.NewMachine(handler, conn) - if err := machine.Start(request.Context()); err != nil { + if err := machine.Start(request.Context(), fmt.Sprintf("%s/%s", service, fn)); err != nil { log.Error().Err(err).Msg("failed to handle invocation") } - - log.Trace().Msg("invocation ended") } func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) { - log.Info().Str("proto", request.Proto).Str("method", request.Method).Str("path", request.RequestURI).Msg("got request") + log.Trace().Str("proto", request.Proto).Str("method", request.Method).Str("path", request.RequestURI).Msg("got request") if request.Method != http.MethodPost { writer.WriteHeader(http.StatusMethodNotAllowed)