Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State machine #1

Merged
merged 5 commits into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Go
on: [push]

jobs:
build:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: "1.22.x"
- name: Install dependencies
run: go get .
- name: Vet
run: go vet -v ./...
- name: Build
run: go build -v ./...
- name: Test with the Go CLI
run: go test -v ./...
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
go.sum
34 changes: 34 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 h1:zZg03nifrj6ayWNaDO8tNj57tqrOIKDwiBaLkhtK7Kk=
github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0/go.mod h1:bblJa8QcHntareAJYfLJUzLj42sUFBKCBeTDK5LyUrw=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
5 changes: 3 additions & 2 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
"google.golang.org/protobuf/types/known/structpb"
)

// Empty struct used as a quick shortcut when you don't care about
// the input or output type
// Void is a placeholder used usually for functions that their signature require that
// you accept an input or return an output but the function implementation does not
// require them
type Void struct{}

func (v Void) MarshalJSON() ([]byte, error) {
Expand Down
64 changes: 52 additions & 12 deletions internal/state/call.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package state

import (
"bytes"
"encoding/json"
"fmt"
"time"
Expand Down Expand Up @@ -69,18 +70,39 @@ func (c *Machine) makeRequest(key string, body any) ([]byte, error) {
}

func (c *Machine) doCall(service, method, key string, body any) ([]byte, error) {
c.mutex.Lock()
defer c.mutex.Unlock()

input, err := c.makeRequest(key, body)
params, err := c.makeRequest(key, body)
if err != nil {
return nil, err
}

err = c.protocol.Write(&protocol.InvokeEntryMessage{
return replayOrNew(
c,
wire.InvokeEntryMessageType,
func(entry *wire.InvokeEntryMessage) ([]byte, error) {
if entry.Payload.ServiceName != service ||
entry.Payload.MethodName != method ||
!bytes.Equal(entry.Payload.Parameter, params) {
return nil, errEntryMismatch
}

switch result := entry.Payload.Result.(type) {
case *protocol.InvokeEntryMessage_Failure:
return nil, fmt.Errorf("[%d] %s", result.Failure.Code, result.Failure.Message)
case *protocol.InvokeEntryMessage_Value:
return result.Value, nil
}

return nil, errUnreachable
}, func() ([]byte, error) {
return c._doCall(service, method, params)
})
}

func (c *Machine) _doCall(service, method string, params []byte) ([]byte, error) {
err := c.protocol.Write(&protocol.InvokeEntryMessage{
ServiceName: service,
MethodName: method,
Parameter: input,
Parameter: params,
})

if err != nil {
Expand Down Expand Up @@ -121,23 +143,41 @@ func (c *Machine) doCall(service, method, key string, body any) ([]byte, error)
}

func (c *Machine) sendCall(service, method, key string, body any, delay time.Duration) error {
c.mutex.Lock()
defer c.mutex.Unlock()

input, err := c.makeRequest(key, body)
params, err := c.makeRequest(key, body)
if err != nil {
return err
}

_, err = replayOrNew(
c,
wire.BackgroundInvokeEntryMessageType,
func(entry *wire.BackgroundInvokeEntryMessage) (restate.Void, error) {
if entry.Payload.ServiceName != service ||
entry.Payload.MethodName != method ||
!bytes.Equal(entry.Payload.Parameter, params) {
return restate.Void{}, errEntryMismatch
}

return restate.Void{}, nil
},
func() (restate.Void, error) {
return restate.Void{}, c._sendCall(service, method, params, delay)
},
)

return err
}

func (c *Machine) _sendCall(service, method string, params []byte, delay time.Duration) error {
var invokeTime uint64
if delay != 0 {
invokeTime = uint64(time.Now().Add(delay).UnixMilli())
}

err = c.protocol.Write(&protocol.BackgroundInvokeEntryMessage{
err := c.protocol.Write(&protocol.BackgroundInvokeEntryMessage{
ServiceName: service,
MethodName: method,
Parameter: input,
Parameter: params,
InvokeTime: invokeTime,
})

Expand Down
110 changes: 85 additions & 25 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ const (
var (
ErrUnexpectedMessage = fmt.Errorf("unexpected message")
ErrInvalidVersion = fmt.Errorf("invalid version number")

errUnreachable = fmt.Errorf("unreachable")
)

var (
Expand Down Expand Up @@ -61,6 +59,10 @@ func (c *Context) Get(key string) ([]byte, error) {
return c.machine.get(key)
}

func (c *Context) Keys() ([]string, error) {
return c.machine.keys()
}

func (c *Context) Sleep(until time.Time) error {
return c.machine.sleep(until)
}
Expand Down Expand Up @@ -98,7 +100,8 @@ type Machine struct {
partial bool
current map[string][]byte

entries []wire.Message
entries []wire.Message
entryIndex int
}

func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine {
Expand All @@ -109,11 +112,33 @@ func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine {
}
}

func (m *Machine) output(r *dynrpc.RpcResponse, err error) proto.Message {
// TODO: if err is terminal return outputStreamEntryMessage but if error is
// not terminal, return ErrorMessage instead.
//var output protocol.OutputStreamEntryMessage
// Start starts the state machine
func (m *Machine) Start(inner context.Context) error {
// reader starts a rea
msg, err := m.protocol.Read()
if err != nil {
return err
}

if msg.Type() != wire.StartMessageType {
// invalid negotiation
return ErrUnexpectedMessage
}

start := msg.(*wire.StartMessage)

if start.Version != Version {
return ErrInvalidVersion
}

ctx := newContext(inner, m)

log.Trace().Str("id", base64.URLEncoding.EncodeToString(m.id)).Msg("start invocation")
return m.process(ctx, start)
}

// handle handler response and build proper response message
func (m *Machine) output(r *dynrpc.RpcResponse, err error) proto.Message {
if err != nil && restate.IsTerminalError(err) {
// terminal errors.
return &protocol.OutputStreamEntryMessage{
Expand Down Expand Up @@ -195,14 +220,18 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
return ErrUnexpectedMessage
}

log.Debug().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires")
log.Trace().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires")
m.entries = make([]wire.Message, 0, start.Payload.KnownEntries-1)

// we don't track the poll input entry
for i := uint32(1); i < start.Payload.KnownEntries; i++ {
msg, err := m.protocol.Read()
if err != nil {
return fmt.Errorf("failed to read entry: %w", err)
}

log.Debug().Uint16("type", uint16(msg.Type())).Msg("received entry")
log.Trace().Uint16("type", uint16(msg.Type())).Msg("replay log entry")
m.entries = append(m.entries, msg)
}

inputMsg := msg.(*wire.PollInputEntry)
Expand All @@ -216,26 +245,57 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {

}

func (m *Machine) Start(inner context.Context) error {
// reader starts a rea
msg, err := m.protocol.Read()
if err != nil {
return err
func (c *Machine) currentEntry() (wire.Message, bool) {
if c.entryIndex < len(c.entries) {
return c.entries[c.entryIndex], true
}

if msg.Type() != wire.StartMessageType {
// invalid negotiation
return ErrUnexpectedMessage
}
return nil, false
}

start := msg.(*wire.StartMessage)
// replayOrNew is a utility function to easily either
// replay a log entry, or create a new one if one
// does not exist
//
// this should be an instance method on Machine but unfortunately
// go does not support generics on instance methods
//
// the idea is when called, it will check if there is a log
// entry at current index, then compare the entry message type
// if not matching, that's obviously an error with the code version
// (code has changed and now doesn't match the play log)
//
// if type is okay, the function will then call a `replay“ callback.
// the replay callback just need to extract the result from the entry
//
// otherwise this function will call a `new` callback to create a new entry in the log
// by sending the proper runtime messages
func replayOrNew[M wire.Message, O any](
m *Machine,
typ wire.Type,
replay func(msg M) (O, error),
new func() (O, error),
) (output O, err error) {

m.mutex.Lock()
defer m.mutex.Unlock()

if start.Version != Version {
return ErrInvalidVersion
}
defer func() {
m.entryIndex += 1
}()

ctx := newContext(inner, m)
// check if there is an entry as this index
entry, ok := m.currentEntry()

log.Trace().Str("id", base64.URLEncoding.EncodeToString(m.id)).Msg("start invocation")
return m.process(ctx, start)
// if entry exists, we need to replay it
// by calling the replay function
if ok {
if entry.Type() != typ {
return output, errEntryMismatch
}
return replay(entry.(M))
}

// other wise call the new function
return new()
}
Loading
Loading