diff --git a/compiler/compile.go b/compiler/compile.go index dbc5898..403d0b0 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -255,6 +255,7 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color ctx := ast.NewIdent("_c") frame := ast.NewIdent("_f") + fp := ast.NewIdent("_fp") yieldTypeExpr := make([]ast.Expr, 2) yieldTypeExpr[0] = typeExpr(color.Params().At(0).Type()) @@ -277,9 +278,9 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color }, }) - // _f := _c.Push() + // _f, _fp := _c.Push() gen.Body.List = append(gen.Body.List, &ast.AssignStmt{ - Lhs: []ast.Expr{frame}, + Lhs: []ast.Expr{frame, fp}, Tok: token.DEFINE, Rhs: []ast.Expr{ &ast.CallExpr{ @@ -379,7 +380,14 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color Cond: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Unwinding")}, }, - Body: &ast.BlockStmt{List: saveStmts}, + Body: &ast.BlockStmt{ + List: append(saveStmts, &ast.ExprStmt{ + X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Store")}, + Args: []ast.Expr{fp, frame}, + }, + }), + }, Else: &ast.BlockStmt{List: []ast.Stmt{ &ast.ExprStmt{X: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Pop")}}}}, }, diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 6a3d81f..031bb4d 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -15,13 +15,14 @@ import ( func Identity(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() if _f.IP > 0 { n = _f.Get(0).(int) } defer func() { if _c.Unwinding() { _f.Set(0, n) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -31,7 +32,7 @@ func Identity(n int) { func SquareGenerator(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int if _f.IP > 0 { n = _f.Get(0).(int) @@ -41,6 +42,7 @@ func SquareGenerator(n int) { if _c.Unwinding() { _f.Set(0, n) _f.Set(1, _o0) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -59,13 +61,14 @@ func SquareGenerator(n int) { func SquareGeneratorTwice(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() if _f.IP > 0 { n = _f.Get(0).(int) } defer func() { if _c.Unwinding() { _f.Set(0, n) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -82,7 +85,7 @@ func SquareGeneratorTwice(n int) { func SquareGeneratorTwiceLoop(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int if _f.IP > 0 { n = _f.Get(0).(int) @@ -92,6 +95,7 @@ func SquareGeneratorTwiceLoop(n int) { if _c.Unwinding() { _f.Set(0, n) _f.Set(1, _o0) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -110,7 +114,7 @@ func SquareGeneratorTwiceLoop(n int) { func EvenSquareGenerator(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int var _o1 int if _f.IP > 0 { @@ -123,6 +127,7 @@ func EvenSquareGenerator(n int) { _f.Set(0, n) _f.Set(1, _o0) _f.Set(2, _o1) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -150,7 +155,7 @@ func EvenSquareGenerator(n int) { func NestedLoops(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int var _o1 int var _o2 int @@ -166,6 +171,7 @@ func NestedLoops(n int) { _f.Set(1, _o0) _f.Set(2, _o1) _f.Set(3, _o2) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -202,7 +208,7 @@ func NestedLoops(n int) { func FizzBuzzIfGenerator(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int var _o1 int if _f.IP > 0 { @@ -216,6 +222,7 @@ func FizzBuzzIfGenerator(n int) { _f.Set(0, n) _f.Set(1, _o0) _f.Set(2, _o1) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -246,7 +253,7 @@ func FizzBuzzIfGenerator(n int) { func FizzBuzzSwitchGenerator(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int if _f.IP > 0 { n = _f.Get(0).(int) @@ -256,6 +263,7 @@ func FizzBuzzSwitchGenerator(n int) { if _c.Unwinding() { _f.Set(0, n) _f.Set(1, _o0) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -284,7 +292,7 @@ func FizzBuzzSwitchGenerator(n int) { func Shadowing(_ int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int var _o1 int var _o2 int @@ -344,6 +352,7 @@ func Shadowing(_ int) { _f.Set(7, _o7) _f.Set(8, _o8) _f.Set(9, _o11) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -541,7 +550,7 @@ func Shadowing(_ int) { func RangeSliceIndexGenerator(_ int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 []int var _o1 int if _f.IP > 0 { @@ -552,6 +561,7 @@ func RangeSliceIndexGenerator(_ int) { if _c.Unwinding() { _f.Set(0, _o0) _f.Set(1, _o1) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -577,7 +587,7 @@ func RangeSliceIndexGenerator(_ int) { func RangeArrayIndexValueGenerator(_ int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 [3]int var _o1 int var _o2 int @@ -591,6 +601,7 @@ func RangeArrayIndexValueGenerator(_ int) { _f.Set(0, _o0) _f.Set(1, _o1) _f.Set(2, _o2) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -627,7 +638,7 @@ func RangeArrayIndexValueGenerator(_ int) { func TypeSwitchingGenerator(_ int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 []any var _o1 int var _o2 any @@ -641,6 +652,7 @@ func TypeSwitchingGenerator(_ int) { _f.Set(0, _o0) _f.Set(1, _o1) _f.Set(2, _o2) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -695,7 +707,7 @@ func TypeSwitchingGenerator(_ int) { func LoopBreakAndContinue(_ int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int var _o1 int var _o2 int @@ -713,6 +725,7 @@ func LoopBreakAndContinue(_ int) { _f.Set(1, _o1) _f.Set(2, _o2) _f.Set(3, _o3) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -800,7 +813,7 @@ func LoopBreakAndContinue(_ int) { func RangeOverMaps(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 map[int]int var _o1 map[int]int var _o2 int @@ -929,6 +942,7 @@ func RangeOverMaps(n int) { _f.Set(36, _o35) _f.Set(37, _o36) _f.Set(38, _o37) + _c.Store(_fp, _f) } else { _c.Pop() } diff --git a/context_durable.go b/context_durable.go index d73e5bf..6b9a88d 100644 --- a/context_durable.go +++ b/context_durable.go @@ -5,8 +5,9 @@ package coroutine import "github.com/stealthrocket/coroutine/internal/serde" type serializedCoroutine struct { - entry func() - stack Stack + entry func() + stack Stack + resume bool } func init() { @@ -24,9 +25,10 @@ type Context[R, S any] struct { recv R send S - // Booleans managing the completion state of the coroutine. - done bool - stop bool + // Booleans managing the state of the coroutine. + done bool + stop bool + resume bool // Entry point of the coroutine, this is captured so the associated // generator can call into the coroutine to start or resume it at the @@ -39,8 +41,9 @@ type Context[R, S any] struct { // MarshalAppend appends a serialized Context to the provided buffer. func (c *Context[R, S]) MarshalAppend(b []byte) ([]byte, error) { s := serde.Serialize(&serializedCoroutine{ - entry: c.entry, - stack: c.Stack, + entry: c.entry, + stack: c.Stack, + resume: c.resume, }) return append(b, s...), nil } @@ -54,14 +57,15 @@ func (c *Context[R, S]) Unmarshal(b []byte) (int, error) { s := v.(*serializedCoroutine) c.entry = s.entry c.Stack = s.stack + c.resume = s.resume sn := start - len(b) return sn, nil } // TODO: do we have use cases for yielding more than one value? func (c *Context[R, S]) Yield(value R) S { - if frame := c.Top(); frame.Resume { - frame.Resume = false + if c.resume { + c.resume = false if c.stop { panic(unwind{}) } @@ -71,7 +75,7 @@ func (c *Context[R, S]) Yield(value R) S { panic("cannot yield from a coroutine that has been stopped") } var zero S - frame.Resume = true + c.resume = true c.send = zero c.recv = value panic(unwind{}) @@ -80,7 +84,7 @@ func (c *Context[R, S]) Yield(value R) S { // Unwinding returns true if the coroutine is currently unwinding its stack. func (c *Context[R, S]) Unwinding() bool { - return len(c.Frames) > 0 && c.Top().Resume + return c.resume } type unwind struct{} diff --git a/coroutine_durable.go b/coroutine_durable.go index 0939352..b1bd529 100644 --- a/coroutine_durable.go +++ b/coroutine_durable.go @@ -25,13 +25,14 @@ func (c Coroutine[R, S]) Next() (hasNext bool) { defer func() { clearContext(g) + switch err := recover(); err { + case nil: + case unwind{}: + default: + panic(err) + } + if c.ctx.Unwinding() { - switch err := recover(); err { - case nil: - case unwind{}: - default: - panic(err) - } stop := c.ctx.stop c.ctx.done, hasNext = stop, !stop } else { diff --git a/stack_durable.go b/stack_durable.go index d29a3c5..b1414dc 100644 --- a/stack_durable.go +++ b/stack_durable.go @@ -12,27 +12,24 @@ type Stack struct { Frames []Frame } -// Top returns the top of the call stack. -func (s *Stack) Top() *Frame { - if len(s.Frames) == 0 { - panic("no stack frames") - } - return &s.Frames[len(s.Frames)-1] -} - // Push prepares the stack for an impending function call. // -// The stack's frame pointer is incremented, and a Frame is pushed to the -// stack if the caller is on the topmost frame. +// The stack's frame pointer is incremented, and the stack is resized +// to make room for a new frame if the caller is on the topmost frame. // // If the caller is not on the topmost frame it means that a coroutine // is being resumed and the next frame is already present on the stack. -func (s *Stack) Push() *Frame { +// +// The Frame is returned by value rather than by reference, since the +// stack's underlying frame backing array might change. Callers +// intending to serialize the stack should call Store(fp, frame) for each +// frame during stack unwinding. +func (s *Stack) Push() (frame Frame, fp int) { if s.isTop() { s.Frames = append(s.Frames, Frame{}) } s.FP++ - return &s.Frames[s.FP] + return s.Frames[s.FP], s.FP } // Pop pops the topmost stack frame after a function call. @@ -44,6 +41,14 @@ func (s *Stack) Pop() { s.FP-- } +// Store stores a frame at the specified index. +func (s *Stack) Store(i int, f Frame) { + if i < 0 || i >= len(s.Frames) { + panic("invalid frame index") + } + s.Frames[i] = f +} + func (s *Stack) isTop() bool { return s.FP == len(s.Frames)-1 } @@ -59,8 +64,4 @@ type Frame struct { // Storage holds the Serializable objects on the frame. Storage - - // Resume is true if the function associated with the frame - // previously yielded. - Resume bool }