Skip to content

Commit

Permalink
Simplify flags
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman committed Jul 9, 2024
1 parent 751f04d commit b2a94e5
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 168 deletions.
11 changes: 4 additions & 7 deletions internal/state/awakeable.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ func awakeableID(invocationID []byte, entryIndex uint32) string {
func (c *Machine) awakeable() (restate.Awakeable[[]byte], error) {
awakeable, err := replayOrNew(
c,
wire.AwakeableEntryMessageType,
func(entry *wire.AwakeableEntryMessage) (restate.Awakeable[[]byte], error) {
return &completionAwakeable{ctx: c.ctx, entryIndex: c.entryIndex, invocationID: c.id, entry: entry}, nil
},
Expand All @@ -96,10 +95,9 @@ func (c *Machine) _awakeable() (restate.Awakeable[[]byte], error) {
func (c *Machine) resolveAwakeable(id string, value []byte) error {
_, err := replayOrNew(
c,
wire.CompleteAwakeableEntryMessageType,
func(entry *wire.CompleteAwakeableEntryMessage) (restate.Void, error) {
messageValue, ok := entry.CompleteAwakeableEntryMessage.Result.(*protocol.CompleteAwakeableEntryMessage_Value)
if entry.CompleteAwakeableEntryMessage.Id != id || !ok || !bytes.Equal(messageValue.Value, value) {
messageValue, ok := entry.Result.(*protocol.CompleteAwakeableEntryMessage_Value)
if entry.Id != id || !ok || !bytes.Equal(messageValue.Value, value) {
return restate.Void{}, errEntryMismatch
}
return restate.Void{}, nil
Expand Down Expand Up @@ -129,10 +127,9 @@ func (c *Machine) _resolveAwakeable(id string, value []byte) error {
func (c *Machine) rejectAwakeable(id string, reason error) error {
_, err := replayOrNew(
c,
wire.CompleteAwakeableEntryMessageType,
func(entry *wire.CompleteAwakeableEntryMessage) (restate.Void, error) {
messageFailure, ok := entry.CompleteAwakeableEntryMessage.Result.(*protocol.CompleteAwakeableEntryMessage_Failure)
if entry.CompleteAwakeableEntryMessage.Id != id || !ok || messageFailure.Failure.Code != uint32(restate.ErrorCode(reason)) || messageFailure.Failure.Message != reason.Error() {
messageFailure, ok := entry.Result.(*protocol.CompleteAwakeableEntryMessage_Failure)
if entry.Id != id || !ok || messageFailure.Failure.Code != uint32(restate.ErrorCode(reason)) || messageFailure.Failure.Message != reason.Error() {
return restate.Void{}, errEntryMismatch
}
return restate.Void{}, nil
Expand Down
22 changes: 10 additions & 12 deletions internal/state/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,22 @@ func (m *Machine) doCall(service, key, method string, params []byte) ([]byte, er

return replayOrNew(
m,
wire.CallEntryMessageType,
func(entry *wire.CallEntryMessage) ([]byte, error) {
if entry.CallEntryMessage.ServiceName != service ||
entry.CallEntryMessage.Key != key ||
entry.CallEntryMessage.HandlerName != method ||
!bytes.Equal(entry.CallEntryMessage.Parameter, params) {
if entry.ServiceName != service ||
entry.Key != key ||
entry.HandlerName != method ||
!bytes.Equal(entry.Parameter, params) {
return nil, errEntryMismatch
}

switch result := entry.CallEntryMessage.Result.(type) {
switch result := entry.Result.(type) {
case *protocol.CallEntryMessage_Failure:
return nil, ErrorFromFailure(result.Failure)
case *protocol.CallEntryMessage_Value:
return result.Value, nil
}

return nil, restate.TerminalError(fmt.Errorf("sync call entry had invalid result: %v", entry.CallEntryMessage.Result), restate.ErrProtocolViolation)
return nil, restate.TerminalError(fmt.Errorf("sync call entry had invalid result: %v", entry.Result), restate.ErrProtocolViolation)
}, func() ([]byte, error) {
return m._doCall(service, key, method, params)
})
Expand Down Expand Up @@ -134,12 +133,11 @@ func (c *Machine) sendCall(service, key, method string, body any, delay time.Dur

_, err = replayOrNew(
c,
wire.OneWayCallEntryMessageType,
func(entry *wire.OneWayCallEntryMessage) (restate.Void, error) {
if entry.OneWayCallEntryMessage.ServiceName != service ||
entry.OneWayCallEntryMessage.Key != key ||
entry.OneWayCallEntryMessage.HandlerName != method ||
!bytes.Equal(entry.OneWayCallEntryMessage.Parameter, params) {
if entry.ServiceName != service ||
entry.Key != key ||
entry.HandlerName != method ||
!bytes.Equal(entry.Parameter, params) {
return restate.Void{}, errEntryMismatch
}

Expand Down
24 changes: 10 additions & 14 deletions internal/state/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ func (m *Machine) ackable(entryIndex uint32) wire.AckableMessage {
}

func (m *Machine) Write(message wire.Message) error {
var flag wire.Flag
if message, ok := message.(wire.CompleteableMessage); ok && !message.Completed() {
m.mutex.Lock()
m.pendingCompletions[m.entryIndex] = message
Expand All @@ -37,9 +36,8 @@ func (m *Machine) Write(message wire.Message) error {
m.mutex.Lock()
m.pendingAcks[m.entryIndex] = message
m.mutex.Unlock()
flag |= wire.FlagRequiresAck
}
return m.protocol.Write(message, flag)
return m.protocol.Write(message)
}

func (m *Machine) handleCompletionsAcks() {
Expand All @@ -48,27 +46,25 @@ func (m *Machine) handleCompletionsAcks() {
if err != nil {
return
}
switch msg.Type() {
case wire.CompletionMessageType:
msg := msg.(*wire.CompletionMessage)
switch msg := msg.(type) {
case *wire.CompletionMessage:
completable := m.completable(msg.EntryIndex)
if completable == nil {
m.log.Error().Uint32("index", msg.CompletionMessage.EntryIndex).Msg("failed to find pending completion at index")
m.log.Error().Uint32("index", msg.EntryIndex).Msg("failed to find pending completion at index")
continue
}
completable.Complete(&msg.CompletionMessage)
m.log.Debug().Uint32("index", msg.CompletionMessage.EntryIndex).Msg("processed completion")
case wire.EntryAckMessageType:
msg := msg.(*wire.EntryAckMessage)
ackable := m.ackable(msg.EntryAckMessage.EntryIndex)
m.log.Debug().Uint32("index", msg.EntryIndex).Msg("processed completion")
case *wire.EntryAckMessage:
ackable := m.ackable(msg.EntryIndex)
if ackable == nil {
m.log.Error().Uint32("index", msg.EntryAckMessage.EntryIndex).Msg("failed to find pending ack at index")
m.log.Error().Uint32("index", msg.EntryIndex).Msg("failed to find pending ack at index")
continue
}
ackable.Ack()
m.log.Debug().Uint32("index", msg.EntryAckMessage.EntryIndex).Msg("processed ack")
m.log.Debug().Uint32("index", msg.EntryIndex).Msg("processed ack")
default:
m.log.Error().Stringer("type", msg.Type()).Msg("unexpected non-completion non-ack message during invocation")
m.log.Error().Type("type", msg).Msg("unexpected non-completion non-ack message during invocation")
continue
}
}
Expand Down
41 changes: 20 additions & 21 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,17 @@ func (m *Machine) Start(inner context.Context, trace string) error {
return err
}

if msg.Type() != wire.StartMessageType {
start, ok := msg.(*wire.StartMessage)
if !ok {
// invalid negotiation
return wire.ErrUnexpectedMessage
}

start := msg.(*wire.StartMessage)

m.ctx = inner
m.id = start.StartMessage.Id
m.key = start.StartMessage.Key
m.id = start.Id
m.key = start.Key

m.log = m.log.With().Str("id", start.StartMessage.DebugId).Str("method", trace).Logger()
m.log = m.log.With().Str("id", start.DebugId).Str("method", trace).Logger()

ctx := newContext(inner, m)

Expand Down Expand Up @@ -248,7 +247,7 @@ func (m *Machine) invoke(ctx *Context, input []byte, outputSeen bool) error {
SuspensionMessage: protocol.SuspensionMessage{
EntryIndexes: []uint32{typ.resumeEntry},
},
}, 0)
})

if err != nil {
m.log.Error().Err(err).Msg("error sending failure message")
Expand All @@ -263,14 +262,14 @@ func (m *Machine) invoke(ctx *Context, input []byte, outputSeen bool) error {
Message: fmt.Sprint(typ),
Description: string(debug.Stack()),
},
}, 0)
})

if err != nil {
m.log.Error().Err(err).Msg("error sending failure message")
}
}

if err := m.protocol.Write(&wire.EndMessage{}, 0); err != nil {
if err := m.protocol.Write(&wire.EndMessage{}); err != nil {
m.log.Error().Err(err).Msg("error sending end message")
}
}()
Expand All @@ -282,11 +281,11 @@ func (m *Machine) invoke(ctx *Context, input []byte, outputSeen bool) error {

output := m.output(m.handler.Call(ctx, input))

return m.protocol.Write(output, 0)
return m.protocol.Write(output)
}

func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
for _, entry := range start.StartMessage.StateMap {
for _, entry := range start.StateMap {
m.current[string(entry.Key)] = entry.Value
}

Expand All @@ -296,34 +295,34 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
return err
}

if msg.Type() != wire.InputEntryMessageType {
if _, ok := msg.(*wire.InputEntryMessage); !ok {
return wire.ErrUnexpectedMessage
}

m.log.Trace().Uint32("known entries", start.StartMessage.KnownEntries).Msg("known entires")
m.entries = make([]wire.Message, 0, start.StartMessage.KnownEntries-1)
m.log.Trace().Uint32("known entries", start.KnownEntries).Msg("known entires")
m.entries = make([]wire.Message, 0, start.KnownEntries-1)

outputSeen := false

// we don't track the poll input entry
for i := uint32(1); i < start.StartMessage.KnownEntries; i++ {
for i := uint32(1); i < start.KnownEntries; i++ {
msg, err := m.protocol.Read()
if err != nil {
return fmt.Errorf("failed to read entry: %w", err)
}

m.log.Trace().Uint16("type", uint16(msg.Type())).Msg("replay log entry")
m.log.Trace().Type("type", msg).Msg("replay log entry")
m.entries = append(m.entries, msg)

if msg.Type() == wire.OutputEntryMessageType {
if _, ok := msg.(*wire.OutputEntryMessage); !ok {
outputSeen = true
}
}

go m.handleCompletionsAcks()

inputMsg := msg.(*wire.InputEntryMessage)
value := inputMsg.InputEntryMessage.GetValue()
value := inputMsg.GetValue()
return m.invoke(ctx, value, outputSeen)

}
Expand Down Expand Up @@ -355,7 +354,6 @@ func (c *Machine) currentEntry() (wire.Message, bool) {
// by sending the proper runtime messages
func replayOrNew[M wire.Message, O any](
m *Machine,
typ wire.Type,
replay func(msg M) (O, error),
new func() (O, error),
) (output O, err error) {
Expand All @@ -368,10 +366,11 @@ func replayOrNew[M wire.Message, O any](
// if entry exists, we need to replay it
// by calling the replay function
if ok {
if entry.Type() != typ {
if entry, ok := entry.(M); !ok {
return output, errEntryMismatch
} else {
return replay(entry)
}
return replay(entry.(M))
}

// other wise call the new function
Expand Down
27 changes: 10 additions & 17 deletions internal/state/sys.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ var (
func (m *Machine) set(key string, value []byte) error {
_, err := replayOrNew(
m,
wire.SetStateEntryMessageType,
func(entry *wire.SetStateEntryMessage) (void restate.Void, err error) {
if string(entry.SetStateEntryMessage.Key) != key || !bytes.Equal(entry.SetStateEntryMessage.Value, value) {
if string(entry.Key) != key || !bytes.Equal(entry.Value, value) {
return void, errEntryMismatch
}

Expand Down Expand Up @@ -51,9 +50,8 @@ func (m *Machine) _set(key string, value []byte) error {
func (m *Machine) clear(key string) error {
_, err := replayOrNew(
m,
wire.ClearStateEntryMessageType,
func(entry *wire.ClearStateEntryMessage) (void restate.Void, err error) {
if string(entry.ClearStateEntryMessage.Key) != key {
if string(entry.Key) != key {
return void, errEntryMismatch
}

Expand Down Expand Up @@ -85,7 +83,6 @@ func (m *Machine) _clear(key string) error {
func (m *Machine) clearAll() error {
_, err := replayOrNew(
m,
wire.ClearAllStateEntryMessageType,
func(entry *wire.ClearAllStateEntryMessage) (void restate.Void, err error) {
return
}, func() (restate.Void, error) {
Expand All @@ -112,13 +109,12 @@ func (m *Machine) _clearAll() error {
func (m *Machine) get(key string) ([]byte, error) {
return replayOrNew(
m,
wire.GetStateEntryMessageType,
func(entry *wire.GetStateEntryMessage) ([]byte, error) {
if string(entry.GetStateEntryMessage.Key) != key {
if string(entry.Key) != key {
return nil, errEntryMismatch
}

switch result := entry.GetStateEntryMessage.Result.(type) {
switch result := entry.Result.(type) {
case *protocol.GetStateEntryMessage_Empty:
return nil, nil
case *protocol.GetStateEntryMessage_Failure:
Expand All @@ -127,7 +123,7 @@ func (m *Machine) get(key string) ([]byte, error) {
return result.Value, nil
}

return nil, restate.TerminalError(fmt.Errorf("get state entry had invalid result: %v", entry.GetStateEntryMessage.Result), restate.ErrProtocolViolation)
return nil, restate.TerminalError(fmt.Errorf("get state entry had invalid result: %v", entry.Result), restate.ErrProtocolViolation)
}, func() ([]byte, error) {
return m._get(key)
})
Expand Down Expand Up @@ -196,9 +192,8 @@ func (m *Machine) _get(key string) ([]byte, error) {
func (m *Machine) keys() ([]string, error) {
return replayOrNew(
m,
wire.GetStateKeysEntryMessageType,
func(entry *wire.GetStateKeysEntryMessage) ([]string, error) {
switch result := entry.GetStateKeysEntryMessage.Result.(type) {
switch result := entry.Result.(type) {
case *protocol.GetStateKeysEntryMessage_Failure:
return nil, fmt.Errorf("[%d] %s", result.Failure.Code, result.Failure.Message)
case *protocol.GetStateKeysEntryMessage_Value:
Expand All @@ -209,7 +204,7 @@ func (m *Machine) keys() ([]string, error) {
return keys, nil
}

return nil, restate.TerminalError(fmt.Errorf("found get state keys entry with invalid completion: %v", entry.GetStateKeysEntryMessage.Result), 571)
return nil, restate.TerminalError(fmt.Errorf("found get state keys entry with invalid completion: %v", entry.Result), 571)
},
m._keys,
)
Expand Down Expand Up @@ -277,7 +272,6 @@ func (m *Machine) _keys() ([]string, error) {
func (m *Machine) sleep(until time.Time) error {
_, err := replayOrNew(
m,
wire.SleepEntryMessageType,
func(entry *wire.SleepEntryMessage) (void restate.Void, err error) {
// we shouldn't verify the time because this would be different every time
return
Expand Down Expand Up @@ -318,9 +312,8 @@ func (m *Machine) _sleep(until time.Time) error {
func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) {
return replayOrNew(
m,
wire.RunEntryMessageType,
func(entry *wire.RunEntryMessage) ([]byte, error) {
switch result := entry.RunEntryMessage.Result.(type) {
switch result := entry.Result.(type) {
case *protocol.RunEntryMessage_Failure:
return nil, ErrorFromFailure(result.Failure)
case *protocol.RunEntryMessage_Value:
Expand All @@ -330,7 +323,7 @@ func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) {
return nil, nil
}

return nil, restate.TerminalError(fmt.Errorf("side effect entry had invalid result: %v", entry.RunEntryMessage.Result), restate.ErrProtocolViolation)
return nil, restate.TerminalError(fmt.Errorf("side effect entry had invalid result: %v", entry.Result), restate.ErrProtocolViolation)
},
func() ([]byte, error) {
return m._sideEffect(fn)
Expand Down Expand Up @@ -367,7 +360,7 @@ func (m *Machine) _sideEffect(fn func() ([]byte, error)) ([]byte, error) {
RelatedEntryType: &ty,
},
}
if err := m.protocol.Write(&msg, 0); err != nil {
if err := m.protocol.Write(&msg); err != nil {
return nil, err
}
}
Expand Down
Loading

0 comments on commit b2a94e5

Please sign in to comment.