Skip to content

Commit

Permalink
Awakeables should use completions
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman committed Jul 8, 2024
1 parent 80898c3 commit 876a856
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 38 deletions.
82 changes: 53 additions & 29 deletions internal/state/awakeable.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package state

import (
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"fmt"
Expand All @@ -13,8 +14,44 @@ import (

const AWAKEABLE_IDENTIFIER_PREFIX = "prom_1"

type awakeable[T any] interface {
restate.Awakeable[T]
func resultFromCompletion(completion *protocol.CompletionMessage) restate.Result[[]byte] {
switch result := completion.Result.(type) {
case *protocol.CompletionMessage_Empty:
return restate.Result[[]byte]{Value: nil}
case *protocol.CompletionMessage_Value:
return restate.Result[[]byte]{Value: result.Value}
case *protocol.CompletionMessage_Failure:
return restate.Result[[]byte]{Err: restate.TerminalError(ErrorFromFailure(result.Failure))}
default:
panic("unreachable")
}
}

type completionAwakeable struct {
ctx context.Context
invocationID []byte
entryIndex uint32
completionFut *CompletionFuture
}

func (c *completionAwakeable) Id() string { return awakeableID(c.invocationID, c.entryIndex) }
func (c *completionAwakeable) Chan() <-chan restate.Result[[]byte] {
ch := make(chan restate.Result[[]byte], 1)
if completion, ok := c.completionFut.Done(); ok {
// fast path
ch <- resultFromCompletion(completion)
return ch
}
// slow path
go func() {
completion, err := c.completionFut.Await(c.ctx)
if err != nil {
ch <- restate.Result[[]byte]{Err: err}
} else {
ch <- resultFromCompletion(completion)
}
}()
return ch
}

type completedAwakeable[T any] struct {
Expand All @@ -30,20 +67,6 @@ func (c *completedAwakeable[T]) Chan() <-chan restate.Result[T] {
return ch
}

type suspendingAwakeable[T any] struct {
invocationID []byte
entryIndex uint32
}

func (c *suspendingAwakeable[T]) Id() string { return awakeableID(c.invocationID, c.entryIndex) }

// this is a temporary hack; always suspend when this channel is read
// currently needed because we don't have a way to process the completion while the invocation is in progress
// and so can only deal with it on replay
func (c *suspendingAwakeable[T]) Chan() <-chan restate.Result[T] {
panic(&suspend{resumeEntry: c.entryIndex})
}

func awakeableID(invocationID []byte, entryIndex uint32) string {
bytes := make([]byte, 0, len(invocationID)+4)
bytes = append(bytes, invocationID...)
Expand All @@ -55,37 +78,38 @@ func (c *Machine) awakeable() (restate.Awakeable[[]byte], error) {
awakeable, err := replayOrNew(
c,
wire.AwakeableEntryMessageType,
func(entry *wire.AwakeableEntryMessage) (awakeable[[]byte], error) {
func(entry *wire.AwakeableEntryMessage) (restate.Awakeable[[]byte], error) {
if entry.Payload.Result == nil {
return &suspendingAwakeable[[]byte]{entryIndex: c.entryIndex, invocationID: c.id}, nil
completionFut := c.pendingCompletions[c.entryIndex]
if completionFut == nil {
return nil, restate.TerminalError(fmt.Errorf("replaying awakeable at index %d is not completed but no pending completion for it", c.entryIndex), restate.ErrProtocolViolation)
}
// replaying an uncompleted awakeable, there must be a pending entry for it
return &completionAwakeable{ctx: c.ctx, entryIndex: c.entryIndex, invocationID: c.id, completionFut: completionFut}, nil
}
switch result := entry.Payload.Result.(type) {
case *protocol.AwakeableEntryMessage_Value:
return &completedAwakeable[[]byte]{entryIndex: c.entryIndex, invocationID: c.id, result: restate.Result[[]byte]{Value: result.Value}}, nil
case *protocol.AwakeableEntryMessage_Failure:
return &completedAwakeable[[]byte]{entryIndex: c.entryIndex, invocationID: c.id, result: restate.Result[[]byte]{Err: restate.TerminalError(fmt.Errorf(result.Failure.Message), restate.Code(result.Failure.Code))}}, nil
return &completedAwakeable[[]byte]{entryIndex: c.entryIndex, invocationID: c.id, result: restate.Result[[]byte]{Err: nil}}, nil
default:
return nil, restate.TerminalError(fmt.Errorf("awakeable entry had invalid result: %v", entry.Payload.Result), restate.ErrProtocolViolation)
}
},
func() (awakeable[[]byte], error) {
if err := c._awakeable(); err != nil {
return nil, err
}
return &suspendingAwakeable[[]byte]{entryIndex: c.entryIndex, invocationID: c.id}, nil
},
c._awakeable,
)
if err != nil {
return nil, err
}
return awakeable, nil
}

func (c *Machine) _awakeable() error {
if err := c.OneWayWrite(&protocol.AwakeableEntryMessage{}); err != nil {
return err
func (c *Machine) _awakeable() (restate.Awakeable[[]byte], error) {
completionFut, err := c.WriteWithCompletion(&protocol.AwakeableEntryMessage{})
if err != nil {
return nil, err
}
return nil
return &completionAwakeable{ctx: c.ctx, entryIndex: c.entryIndex, invocationID: c.id, completionFut: completionFut}, nil
}

func (c *Machine) resolveAwakeable(id string, value []byte) error {
Expand Down
10 changes: 7 additions & 3 deletions internal/state/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (m *Machine) doCall(service, key, method string, params []byte) ([]byte, er

switch result := entry.Payload.Result.(type) {
case *protocol.CallEntryMessage_Failure:
return nil, restate.WithErrorCode(fmt.Errorf(result.Failure.Message), restate.Code(result.Failure.Code))
return nil, ErrorFromFailure(result.Failure)
case *protocol.CallEntryMessage_Value:
return result.Value, nil
}
Expand All @@ -109,7 +109,7 @@ func (m *Machine) _doCall(service, key, method string, params []byte) ([]byte, e
return nil, fmt.Errorf("failed to send request message: %w", err)
}

completion, err := completionFut.Done(m.ctx)
completion, err := completionFut.Await(m.ctx)
if err != nil {
return nil, err
}
Expand All @@ -118,7 +118,7 @@ func (m *Machine) _doCall(service, key, method string, params []byte) ([]byte, e
case *protocol.CompletionMessage_Empty:
return nil, nil
case *protocol.CompletionMessage_Failure:
return nil, restate.WithErrorCode(fmt.Errorf(result.Failure.Message), restate.Code(result.Failure.Code))
return nil, ErrorFromFailure(result.Failure)
case *protocol.CompletionMessage_Value:
return result.Value, nil
}
Expand Down Expand Up @@ -175,3 +175,7 @@ func (c *Machine) _sendCall(service, key, method string, params []byte, delay ti

return nil
}

func ErrorFromFailure(failure *protocol.Failure) error {
return restate.TerminalError(fmt.Errorf(failure.Message), restate.Code(failure.Code))
}
27 changes: 26 additions & 1 deletion internal/state/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ func (c *CompletionFuture) complete(msg *protocol.CompletionMessage) error {
return nil
}

func (c *CompletionFuture) Done(ctx context.Context) (*protocol.CompletionMessage, error) {
func (c *CompletionFuture) Await(ctx context.Context) (*protocol.CompletionMessage, error) {
c.Lock()
defer c.Unlock()
if c.done {
return c.CompletionMessage, nil
}
Expand All @@ -78,6 +80,29 @@ func (c *CompletionFuture) Done(ctx context.Context) (*protocol.CompletionMessag
}
}

func (c *CompletionFuture) Done() (*protocol.CompletionMessage, bool) {
c.Lock()
defer c.Unlock()
if c.done {
return c.CompletionMessage, true
} else {
return nil, false
}
}

func (m *Machine) checkReplayCompletion(index uint32, msg wire.Message) {
switch msg.Type() {
case wire.SetStateEntryMessageType, wire.ClearStateEntryMessageType,
wire.ClearAllStateEntryMessageType, wire.CompleteAwakeableEntryMessageType,
wire.OneWayCallEntryMessageType:
// don't need completion
default:
if !msg.Flags().Completed() {
m.pendingCompletions[index] = newCompletion()
}
}
}

func (m *Machine) Write(message proto.Message, flag wire.Flag) (*AckFuture, *CompletionFuture, error) {
index := m.entryIndex
ack := m.pendingAcks[index]
Expand Down
2 changes: 2 additions & 0 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
return fmt.Errorf("failed to read entry: %w", err)
}

m.checkReplayCompletion(i, msg)

m.log.Trace().Uint16("type", uint16(msg.Type())).Msg("replay log entry")
m.entries = append(m.entries, msg)
}
Expand Down
9 changes: 4 additions & 5 deletions internal/state/sys.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package state

import (
"bytes"
"errors"
"fmt"
"sort"
"time"
Expand Down Expand Up @@ -171,7 +170,7 @@ func (m *Machine) _get(key string) ([]byte, error) {
return nil, err
}

completion, err := completionFut.Done(m.ctx)
completion, err := completionFut.Await(m.ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -249,7 +248,7 @@ func (m *Machine) _keys() ([]string, error) {
return nil, err
}

completion, err := completionFut.Done(m.ctx)
completion, err := completionFut.Await(m.ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -313,7 +312,7 @@ func (m *Machine) _sleep(until time.Time) error {
panic(&suspend{m.entryIndex})
}

if _, err := completionFut.Done(m.ctx); err != nil {
if _, err := completionFut.Await(m.ctx); err != nil {
return err
}

Expand All @@ -327,7 +326,7 @@ func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) {
func(entry *wire.RunEntryMessage) ([]byte, error) {
switch result := entry.Payload.Result.(type) {
case *protocol.RunEntryMessage_Failure:
return nil, restate.TerminalError(errors.New(result.Failure.Message), restate.Code(result.Failure.Code))
return nil, ErrorFromFailure(result.Failure)
case *protocol.RunEntryMessage_Value:
return result.Value, nil
case nil:
Expand Down

0 comments on commit 876a856

Please sign in to comment.