Skip to content

Commit

Permalink
completel example
Browse files Browse the repository at this point in the history
also fix issue with status get
  • Loading branch information
muhamadazmy committed Mar 17, 2024
1 parent a621732 commit fe862ad
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 77 deletions.
11 changes: 8 additions & 3 deletions example/ticket_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"

"github.com/muhamadazmy/restate-sdk-go"
"github.com/rs/zerolog/log"
)

type TicketStatus int
Expand All @@ -14,7 +15,8 @@ const (
TicketSold TicketStatus = 2
)

func reserve(ctx restate.Context, _ string, _ restate.Void) (bool, error) {
func reserve(ctx restate.Context, ticketId string, _ restate.Void) (bool, error) {
log.Info().Str("ticket", ticketId).Msg("reserving ticket")
status, err := restate.GetAs[TicketStatus](ctx, "status")
if err != nil && !errors.Is(err, restate.ErrKeyNotFound) {
return false, err
Expand All @@ -27,7 +29,8 @@ func reserve(ctx restate.Context, _ string, _ restate.Void) (bool, error) {
return false, nil
}

func unreserve(ctx restate.Context, _ string, _ restate.Void) (void restate.Void, err error) {
func unreserve(ctx restate.Context, ticketId string, _ restate.Void) (void restate.Void, err error) {
log.Info().Str("ticket", ticketId).Msg("un-reserving ticket")
status, err := restate.GetAs[TicketStatus](ctx, "status")
if err != nil && !errors.Is(err, restate.ErrKeyNotFound) {
return void, err
Expand All @@ -40,7 +43,9 @@ func unreserve(ctx restate.Context, _ string, _ restate.Void) (void restate.Void
return void, nil
}

func markAsSold(ctx restate.Context, _ string, _ restate.Void) (void restate.Void, err error) {
func markAsSold(ctx restate.Context, ticketId string, _ restate.Void) (void restate.Void, err error) {
log.Info().Str("ticket", ticketId).Msg("mark ticket as sold")

status, err := restate.GetAs[TicketStatus](ctx, "status")
if err != nil && !errors.Is(err, restate.ErrKeyNotFound) {
return void, err
Expand Down
2 changes: 1 addition & 1 deletion example/user_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func addTicket(ctx restate.Context, userId, ticketId string) (bool, error) {
// add ticket to list of tickets
tickets, err := restate.GetAs[[]string](ctx, "tickets")

if err != nil && errors.Is(err, restate.ErrKeyNotFound) {
if err != nil && !errors.Is(err, restate.ErrKeyNotFound) {
return false, err
}

Expand Down
29 changes: 18 additions & 11 deletions internal/state/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (c *serviceCall) Send(key string, body any, delay time.Duration) error {
return c.machine.sendCall(c.service, c.method, key, body, delay)
}

func (c *Machine) makeRequest(key string, body any) ([]byte, error) {
func (m *Machine) makeRequest(key string, body any) ([]byte, error) {

input, err := json.Marshal(body)
if err != nil {
Expand All @@ -69,14 +69,16 @@ func (c *Machine) makeRequest(key string, body any) ([]byte, error) {
return proto.Marshal(params)
}

func (c *Machine) doCall(service, method, key string, input, output any) error {
params, err := c.makeRequest(key, input)
func (m *Machine) doCall(service, method, key string, input, output any) error {
m.log.Debug().Str("service", service).Str("method", method).Msg("in do call")

params, err := m.makeRequest(key, input)
if err != nil {
return err
}

bytes, err := replayOrNew(
c,
m,
wire.InvokeEntryMessageType,
func(entry *wire.InvokeEntryMessage) ([]byte, error) {
if entry.Payload.ServiceName != service ||
Expand All @@ -89,12 +91,17 @@ func (c *Machine) doCall(service, method, key string, input, output any) error {
case *protocol.InvokeEntryMessage_Failure:
return nil, fmt.Errorf("[%d] %s", result.Failure.Code, result.Failure.Message)
case *protocol.InvokeEntryMessage_Value:
return result.Value, nil
var rpcResponse dynrpc.RpcResponse
if err := proto.Unmarshal(result.Value, &rpcResponse); err != nil {
return nil, fmt.Errorf("failed to decode rpc response: %w", err)
}

return rpcResponse.Response.MarshalJSON()
}

return nil, errUnreachable
}, func() ([]byte, error) {
return c._doCall(service, method, params)
return m._doCall(service, method, params)
})

if err != nil {
Expand All @@ -106,14 +113,14 @@ func (c *Machine) doCall(service, method, key string, input, output any) error {
}

if err := json.Unmarshal(bytes, output); err != nil {
return restate.TerminalError(fmt.Errorf("failed to decode response: %w", err))
return restate.TerminalError(fmt.Errorf("failed to decode response (%s): %w", string(bytes), err))
}

return nil
}

func (c *Machine) _doCall(service, method string, params []byte) ([]byte, error) {
err := c.protocol.Write(&protocol.InvokeEntryMessage{
func (m *Machine) _doCall(service, method string, params []byte) ([]byte, error) {
err := m.protocol.Write(&protocol.InvokeEntryMessage{
ServiceName: service,
MethodName: method,
Parameter: params,
Expand All @@ -123,7 +130,7 @@ func (c *Machine) _doCall(service, method string, params []byte) ([]byte, error)
return nil, fmt.Errorf("failed to send request message: %w", err)
}

response, err := c.protocol.Read()
response, err := m.protocol.Read()
if err != nil {
return nil, fmt.Errorf("failed to read response message: %w", err)
}
Expand All @@ -150,7 +157,7 @@ func (c *Machine) _doCall(service, method string, params []byte) ([]byte, error)

var rpcResponse dynrpc.RpcResponse
if err := proto.Unmarshal(output, &rpcResponse); err != nil {
return nil, fmt.Errorf("failed to decode rpc response: %w", err)
return nil, fmt.Errorf("failed to decode rpc response(%s,%s): %w", service, method, err)
}

return rpcResponse.Response.MarshalJSON()
Expand Down
27 changes: 18 additions & 9 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package state

import (
"context"
"encoding/base64"
"fmt"
"io"
"runtime/debug"
Expand All @@ -15,6 +14,7 @@ import (
"github.com/muhamadazmy/restate-sdk-go/generated/proto/protocol"
"github.com/muhamadazmy/restate-sdk-go/internal/wire"

"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/proto"
)
Expand Down Expand Up @@ -109,13 +109,15 @@ type Machine struct {
mutex sync.Mutex

// state
id []byte
id string

partial bool
current map[string][]byte

entries []wire.Message
entryIndex int

log zerolog.Logger
}

func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine {
Expand All @@ -127,7 +129,7 @@ func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine {
}

// Start starts the state machine
func (m *Machine) Start(inner context.Context) error {
func (m *Machine) Start(inner context.Context, trace string) error {
// reader starts a rea
msg, err := m.protocol.Read()
if err != nil {
Expand All @@ -141,18 +143,27 @@ func (m *Machine) Start(inner context.Context) error {

start := msg.(*wire.StartMessage)

m.id = start.Payload.DebugId
if start.Version != Version {
return ErrInvalidVersion
}

m.log = log.With().Str("id", start.Payload.DebugId).Str("method", trace).Logger()

ctx := newContext(inner, m)

log.Trace().Str("id", base64.URLEncoding.EncodeToString(m.id)).Msg("start invocation")
m.log.Debug().Msg("start invocation")
defer m.log.Debug().Msg("invocation ended")

return m.process(ctx, start)
}

// handle handler response and build proper response message
func (m *Machine) output(r *dynrpc.RpcResponse, err error) proto.Message {
if err != nil {
m.log.Error().Err(err).Msg("failure")
}

if err != nil && restate.IsTerminalError(err) {
// terminal errors.
return &protocol.OutputStreamEntryMessage{
Expand Down Expand Up @@ -207,7 +218,7 @@ func (m *Machine) invoke(ctx *Context, input *dynrpc.RpcRequest) error {
})

if wErr != nil {
log.Error().Err(wErr).Msg("error sending failure message")
m.log.Error().Err(wErr).Msg("error sending failure message")
}
}
}()
Expand All @@ -218,8 +229,6 @@ func (m *Machine) invoke(ctx *Context, input *dynrpc.RpcRequest) error {
}

func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
m.id = start.Payload.Id

for _, entry := range start.Payload.StateMap {
m.current[string(entry.Key)] = entry.Value
}
Expand All @@ -234,7 +243,7 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
return ErrUnexpectedMessage
}

log.Trace().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires")
m.log.Trace().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires")
m.entries = make([]wire.Message, 0, start.Payload.KnownEntries-1)

// we don't track the poll input entry
Expand All @@ -244,7 +253,7 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
return fmt.Errorf("failed to read entry: %w", err)
}

log.Trace().Uint16("type", uint16(msg.Type())).Msg("replay log entry")
m.log.Trace().Uint16("type", uint16(msg.Type())).Msg("replay log entry")
m.entries = append(m.entries, msg)
}

Expand Down
Loading

0 comments on commit fe862ad

Please sign in to comment.