From 7f7514ef8e5c38fd923f6c8b7030e05548fe8d0d Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Tue, 9 Jul 2024 18:32:45 +0100 Subject: [PATCH] Avoid channels in awakeables --- internal/state/awakeable.go | 41 ++++++++++++------------------------- router.go | 28 +++++++++---------------- 2 files changed, 23 insertions(+), 46 deletions(-) diff --git a/internal/state/awakeable.go b/internal/state/awakeable.go index 4db714c..19b6297 100644 --- a/internal/state/awakeable.go +++ b/internal/state/awakeable.go @@ -5,6 +5,7 @@ import ( "context" "encoding/base64" "encoding/binary" + "fmt" restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/generated/proto/protocol" @@ -32,35 +33,19 @@ type completionAwakeable struct { } func (c *completionAwakeable) Id() string { return awakeableID(c.invocationID, c.entryIndex) } -func (c *completionAwakeable) Chan() <-chan restate.Result[[]byte] { - ch := make(chan restate.Result[[]byte], 1) - if c.entry.Completed() { - // fast path - ch <- resultFromAwakeable(c.entry) - return ch - } - // slow path - go func() { - if err := c.entry.Await(c.ctx); err != nil { - ch <- restate.Result[[]byte]{Err: err} - } else { - ch <- resultFromAwakeable(c.entry) +func (c *completionAwakeable) Result() ([]byte, error) { + if err := c.entry.Await(c.ctx); err != nil { + return nil, err + } else { + switch result := c.entry.Result.(type) { + case *protocol.AwakeableEntryMessage_Value: + return result.Value, nil + case *protocol.AwakeableEntryMessage_Failure: + return nil, ErrorFromFailure(result.Failure) + default: + return nil, fmt.Errorf("unexpected result in completed awakeable entry: %v", c.entry.Result) } - }() - return ch -} - -type completedAwakeable[T any] struct { - invocationID []byte - entryIndex uint32 - result restate.Result[T] -} - -func (c *completedAwakeable[T]) Id() string { return awakeableID(c.invocationID, c.entryIndex) } -func (c *completedAwakeable[T]) Chan() <-chan restate.Result[T] { - ch := make(chan restate.Result[T], 1) - ch <- c.result - return ch + } } func awakeableID(invocationID []byte, entryIndex uint32) string { diff --git a/router.go b/router.go index e322de9..9dd6cca 100644 --- a/router.go +++ b/router.go @@ -254,7 +254,7 @@ func SideEffectAs[T any](ctx Context, fn func() (T, error)) (output T, err error type Awakeable[T any] interface { Id() string - Chan() <-chan Result[T] + Result() (T, error) } type Result[T any] struct { @@ -267,23 +267,15 @@ type decodingAwakeable[T any] struct { } func (d decodingAwakeable[T]) Id() string { return d.inner.Id() } -func (d decodingAwakeable[T]) Chan() <-chan Result[T] { - inner := d.inner.Chan() - out := make(chan Result[T], 1) - go func() { - result := <-inner - if result.Err != nil { - out <- Result[T]{Err: result.Err} - } else { - var value T - if err := json.Unmarshal(result.Value, &value); err != nil { - out <- Result[T]{Err: TerminalError(err)} - } else { - out <- Result[T]{Value: value} - } - } - }() - return out +func (d decodingAwakeable[T]) Result() (out T, err error) { + bytes, err := d.inner.Result() + if err != nil { + return out, err + } + if err := json.Unmarshal(bytes, &out); err != nil { + return out, err + } + return } func AwakeableAs[T any](ctx Context) (Awakeable[T], error) {