From 924901f2d66a2a24f927cbb48a631d001a323591 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Tue, 19 Sep 2023 08:19:28 +1000 Subject: [PATCH 1/4] Twice loop --- compiler/coroutine_test.go | 6 ++ compiler/testdata/coroutine.go | 6 ++ compiler/testdata/coroutine_durable.go | 102 +++++++++++-------------- 3 files changed, 58 insertions(+), 56 deletions(-) diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index add4d07..136f5bc 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -34,6 +34,12 @@ func TestCoroutineYield(t *testing.T) { yields: []int{1, 4, 9, 16, 1, 4, 9, 16}, }, + { + name: "square generator twice loop", + coro: func() { SquareGeneratorTwiceLoop(4) }, + yields: []int{1, 4, 9, 16, 1, 4, 9, 16}, + }, + { name: "even square generator", coro: func() { EvenSquareGenerator(6) }, diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 6f624f8..ed46bc6 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -25,6 +25,12 @@ func SquareGeneratorTwice(n int) { SquareGenerator(n) } +func SquareGeneratorTwiceLoop(n int) { + for i := 0; i < 2; i++ { + SquareGenerator(n) + } +} + func EvenSquareGenerator(n int) { for i := 1; i <= n; i++ { if mod2 := i % 2; mod2 == 0 { diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index a38ec6b..6a3d81f 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -80,6 +80,34 @@ func SquareGeneratorTwice(n int) { } } +func SquareGeneratorTwiceLoop(n int) { + _c := coroutine.LoadContext[int, any]() + _f := _c.Push() + var _o0 int + if _f.IP > 0 { + n = _f.Get(0).(int) + _o0 = _f.Get(1).(int) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, n) + _f.Set(1, _o0) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _o0 = 0 + _f.IP = 2 + fallthrough + case _f.IP < 3: + for ; _o0 < 2; _o0, _f.IP = _o0+1, 2 { + SquareGenerator(n) + } + } +} + func EvenSquareGenerator(n int) { _c := coroutine.LoadContext[int, any]() _f := _c.Push() @@ -1275,11 +1303,14 @@ func RangeOverMaps(n int) { func init() { serde.RegisterType[**byte]() serde.RegisterType[*[100000]uintptr]() + serde.RegisterType[*[1125899906842623]byte]() serde.RegisterType[*[131072]uint16]() serde.RegisterType[*[140737488355327]byte]() serde.RegisterType[*[16]byte]() serde.RegisterType[*[171]uint8]() serde.RegisterType[*[1]uintptr]() + serde.RegisterType[*[268435456]uintptr]() + serde.RegisterType[*[281474976710655]uint32]() serde.RegisterType[*[2]float32]() serde.RegisterType[*[2]float64]() serde.RegisterType[*[2]uint32]() @@ -1287,7 +1318,7 @@ func init() { serde.RegisterType[*[32]rune]() serde.RegisterType[*[32]uintptr]() serde.RegisterType[*[4]byte]() - serde.RegisterType[*[512]uintptr]() + serde.RegisterType[*[562949953421311]uint16]() serde.RegisterType[*[65536]uintptr]() serde.RegisterType[*[70368744177663]uint16]() serde.RegisterType[*[8]byte]() @@ -1321,47 +1352,48 @@ func init() { serde.RegisterType[[129]uint8]() serde.RegisterType[[131072]uintptr]() serde.RegisterType[[14]byte]() + serde.RegisterType[[15]uint64]() serde.RegisterType[[16384]byte]() serde.RegisterType[[16384]uint8]() serde.RegisterType[[16]byte]() serde.RegisterType[[16]uint64]() - serde.RegisterType[[16]uintptr]() serde.RegisterType[[17]string]() serde.RegisterType[[1]byte]() - serde.RegisterType[[1]uint32]() serde.RegisterType[[1]uint64]() serde.RegisterType[[1]uint8]() serde.RegisterType[[1]uintptr]() serde.RegisterType[[20]byte]() serde.RegisterType[[21]byte]() + serde.RegisterType[[23]uint64]() serde.RegisterType[[249]uint8]() serde.RegisterType[[24]byte]() + serde.RegisterType[[24]uint32]() serde.RegisterType[[252]uintptr]() serde.RegisterType[[253]uintptr]() serde.RegisterType[[256]uint64]() - serde.RegisterType[[29]uint64]() - serde.RegisterType[[2]int32]() - serde.RegisterType[[2]uint32]() serde.RegisterType[[2]uint64]() serde.RegisterType[[2]uintptr]() serde.RegisterType[[32]byte]() serde.RegisterType[[32]string]() - serde.RegisterType[[32]uint32]() serde.RegisterType[[32]uintptr]() serde.RegisterType[[33]float64]() serde.RegisterType[[3]byte]() serde.RegisterType[[3]int]() + serde.RegisterType[[3]uint16]() + serde.RegisterType[[3]uint32]() serde.RegisterType[[4096]byte]() - serde.RegisterType[[40]int8]() + serde.RegisterType[[40]byte]() + serde.RegisterType[[44]byte]() serde.RegisterType[[4]byte]() serde.RegisterType[[4]float64]() serde.RegisterType[[4]string]() + serde.RegisterType[[4]uint16]() + serde.RegisterType[[4]uint32]() serde.RegisterType[[4]uint64]() serde.RegisterType[[4]uintptr]() serde.RegisterType[[50]uintptr]() serde.RegisterType[[512]byte]() serde.RegisterType[[512]uintptr]() - serde.RegisterType[[56]int8]() serde.RegisterType[[5]byte]() serde.RegisterType[[5]uint]() serde.RegisterType[[61]struct { @@ -1371,9 +1403,10 @@ func init() { }]() serde.RegisterType[[64488]byte]() serde.RegisterType[[64]byte]() - serde.RegisterType[[64]uint64]() serde.RegisterType[[64]uintptr]() serde.RegisterType[[65528]byte]() + serde.RegisterType[[65]uint32]() + serde.RegisterType[[65]uintptr]() serde.RegisterType[[68]struct { Size uint32 Mallocs uint64 @@ -1384,14 +1417,14 @@ func init() { serde.RegisterType[[68]uint64]() serde.RegisterType[[68]uint8]() serde.RegisterType[[6]uintptr]() - serde.RegisterType[[7]uint64]() - serde.RegisterType[[88]byte]() + serde.RegisterType[[8192]byte]() serde.RegisterType[[8]byte]() - serde.RegisterType[[8]int8]() serde.RegisterType[[8]string]() + serde.RegisterType[[8]uint64]() serde.RegisterType[[8]uint8]() serde.RegisterType[[96]byte]() serde.RegisterType[[9]string]() + serde.RegisterType[[9]uintptr]() serde.RegisterType[[][]int32]() serde.RegisterType[[]byte]() serde.RegisterType[[]float64]() @@ -1447,53 +1480,10 @@ func init() { needed bool alignme uint64 }]() - serde.RegisterType[struct { - fd int32 - cmd int32 - arg int32 - ret int32 - errno int32 - }]() serde.RegisterType[struct { fill uint64 capacity uint64 }]() - serde.RegisterType[struct { - fn uintptr - a1 uintptr - a2 uintptr - a3 uintptr - a4 uintptr - a5 uintptr - a6 uintptr - r1 uintptr - r2 uintptr - err uintptr - }]() - serde.RegisterType[struct { - fn uintptr - a1 uintptr - a2 uintptr - a3 uintptr - a4 uintptr - a5 uintptr - f1 float64 - r1 uintptr - }]() - serde.RegisterType[struct { - fn uintptr - a1 uintptr - a2 uintptr - a3 uintptr - r1 uintptr - r2 uintptr - err uintptr - }]() - serde.RegisterType[struct { - t int64 - numer uint32 - denom uint32 - }]() serde.RegisterType[struct { tick uint64 i int From 86d9f3db9d0a69bd7d8a1426d27a6c833cbd7080 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Tue, 19 Sep 2023 08:45:07 +1000 Subject: [PATCH 2/4] Don't suppress actual panics --- coroutine_durable.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/coroutine_durable.go b/coroutine_durable.go index ddfa7b3..0939352 100644 --- a/coroutine_durable.go +++ b/coroutine_durable.go @@ -26,7 +26,12 @@ func (c Coroutine[R, S]) Next() (hasNext bool) { clearContext(g) if c.ctx.Unwinding() { - recover() + switch err := recover(); err { + case nil: + case unwind{}: + default: + panic(err) + } stop := c.ctx.stop c.ctx.done, hasNext = stop, !stop } else { From 2bea6f5c58837d1f4fa94fea11175b02c2004558 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Tue, 19 Sep 2023 09:07:04 +1000 Subject: [PATCH 3/4] Fix stop check --- context_durable.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/context_durable.go b/context_durable.go index 56275ea..d73e5bf 100644 --- a/context_durable.go +++ b/context_durable.go @@ -60,9 +60,6 @@ func (c *Context[R, S]) Unmarshal(b []byte) (int, error) { // TODO: do we have use cases for yielding more than one value? func (c *Context[R, S]) Yield(value R) S { - if c.stop { - panic("cannot yield from a coroutine that has been stopped") - } if frame := c.Top(); frame.Resume { frame.Resume = false if c.stop { @@ -70,6 +67,9 @@ func (c *Context[R, S]) Yield(value R) S { } return c.send } else { + if c.stop { + panic("cannot yield from a coroutine that has been stopped") + } var zero S frame.Resume = true c.send = zero From 7c5c7624ba84091cbaa78f8eadce373bed4bcee1 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Tue, 19 Sep 2023 09:18:08 +1000 Subject: [PATCH 4/4] Modify frames by value then store when unwinding --- compiler/compile.go | 14 +++++++-- compiler/testdata/coroutine_durable.go | 42 +++++++++++++++++--------- context_durable.go | 26 +++++++++------- coroutine_durable.go | 13 ++++---- stack_durable.go | 33 ++++++++++---------- 5 files changed, 78 insertions(+), 50 deletions(-) diff --git a/compiler/compile.go b/compiler/compile.go index dbc5898..403d0b0 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -255,6 +255,7 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color ctx := ast.NewIdent("_c") frame := ast.NewIdent("_f") + fp := ast.NewIdent("_fp") yieldTypeExpr := make([]ast.Expr, 2) yieldTypeExpr[0] = typeExpr(color.Params().At(0).Type()) @@ -277,9 +278,9 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color }, }) - // _f := _c.Push() + // _f, _fp := _c.Push() gen.Body.List = append(gen.Body.List, &ast.AssignStmt{ - Lhs: []ast.Expr{frame}, + Lhs: []ast.Expr{frame, fp}, Tok: token.DEFINE, Rhs: []ast.Expr{ &ast.CallExpr{ @@ -379,7 +380,14 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color Cond: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Unwinding")}, }, - Body: &ast.BlockStmt{List: saveStmts}, + Body: &ast.BlockStmt{ + List: append(saveStmts, &ast.ExprStmt{ + X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Store")}, + Args: []ast.Expr{fp, frame}, + }, + }), + }, Else: &ast.BlockStmt{List: []ast.Stmt{ &ast.ExprStmt{X: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Pop")}}}}, }, diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 6a3d81f..031bb4d 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -15,13 +15,14 @@ import ( func Identity(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() if _f.IP > 0 { n = _f.Get(0).(int) } defer func() { if _c.Unwinding() { _f.Set(0, n) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -31,7 +32,7 @@ func Identity(n int) { func SquareGenerator(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int if _f.IP > 0 { n = _f.Get(0).(int) @@ -41,6 +42,7 @@ func SquareGenerator(n int) { if _c.Unwinding() { _f.Set(0, n) _f.Set(1, _o0) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -59,13 +61,14 @@ func SquareGenerator(n int) { func SquareGeneratorTwice(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() if _f.IP > 0 { n = _f.Get(0).(int) } defer func() { if _c.Unwinding() { _f.Set(0, n) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -82,7 +85,7 @@ func SquareGeneratorTwice(n int) { func SquareGeneratorTwiceLoop(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int if _f.IP > 0 { n = _f.Get(0).(int) @@ -92,6 +95,7 @@ func SquareGeneratorTwiceLoop(n int) { if _c.Unwinding() { _f.Set(0, n) _f.Set(1, _o0) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -110,7 +114,7 @@ func SquareGeneratorTwiceLoop(n int) { func EvenSquareGenerator(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int var _o1 int if _f.IP > 0 { @@ -123,6 +127,7 @@ func EvenSquareGenerator(n int) { _f.Set(0, n) _f.Set(1, _o0) _f.Set(2, _o1) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -150,7 +155,7 @@ func EvenSquareGenerator(n int) { func NestedLoops(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int var _o1 int var _o2 int @@ -166,6 +171,7 @@ func NestedLoops(n int) { _f.Set(1, _o0) _f.Set(2, _o1) _f.Set(3, _o2) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -202,7 +208,7 @@ func NestedLoops(n int) { func FizzBuzzIfGenerator(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int var _o1 int if _f.IP > 0 { @@ -216,6 +222,7 @@ func FizzBuzzIfGenerator(n int) { _f.Set(0, n) _f.Set(1, _o0) _f.Set(2, _o1) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -246,7 +253,7 @@ func FizzBuzzIfGenerator(n int) { func FizzBuzzSwitchGenerator(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int if _f.IP > 0 { n = _f.Get(0).(int) @@ -256,6 +263,7 @@ func FizzBuzzSwitchGenerator(n int) { if _c.Unwinding() { _f.Set(0, n) _f.Set(1, _o0) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -284,7 +292,7 @@ func FizzBuzzSwitchGenerator(n int) { func Shadowing(_ int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int var _o1 int var _o2 int @@ -344,6 +352,7 @@ func Shadowing(_ int) { _f.Set(7, _o7) _f.Set(8, _o8) _f.Set(9, _o11) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -541,7 +550,7 @@ func Shadowing(_ int) { func RangeSliceIndexGenerator(_ int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 []int var _o1 int if _f.IP > 0 { @@ -552,6 +561,7 @@ func RangeSliceIndexGenerator(_ int) { if _c.Unwinding() { _f.Set(0, _o0) _f.Set(1, _o1) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -577,7 +587,7 @@ func RangeSliceIndexGenerator(_ int) { func RangeArrayIndexValueGenerator(_ int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 [3]int var _o1 int var _o2 int @@ -591,6 +601,7 @@ func RangeArrayIndexValueGenerator(_ int) { _f.Set(0, _o0) _f.Set(1, _o1) _f.Set(2, _o2) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -627,7 +638,7 @@ func RangeArrayIndexValueGenerator(_ int) { func TypeSwitchingGenerator(_ int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 []any var _o1 int var _o2 any @@ -641,6 +652,7 @@ func TypeSwitchingGenerator(_ int) { _f.Set(0, _o0) _f.Set(1, _o1) _f.Set(2, _o2) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -695,7 +707,7 @@ func TypeSwitchingGenerator(_ int) { func LoopBreakAndContinue(_ int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 int var _o1 int var _o2 int @@ -713,6 +725,7 @@ func LoopBreakAndContinue(_ int) { _f.Set(1, _o1) _f.Set(2, _o2) _f.Set(3, _o3) + _c.Store(_fp, _f) } else { _c.Pop() } @@ -800,7 +813,7 @@ func LoopBreakAndContinue(_ int) { func RangeOverMaps(n int) { _c := coroutine.LoadContext[int, any]() - _f := _c.Push() + _f, _fp := _c.Push() var _o0 map[int]int var _o1 map[int]int var _o2 int @@ -929,6 +942,7 @@ func RangeOverMaps(n int) { _f.Set(36, _o35) _f.Set(37, _o36) _f.Set(38, _o37) + _c.Store(_fp, _f) } else { _c.Pop() } diff --git a/context_durable.go b/context_durable.go index d73e5bf..6b9a88d 100644 --- a/context_durable.go +++ b/context_durable.go @@ -5,8 +5,9 @@ package coroutine import "github.com/stealthrocket/coroutine/internal/serde" type serializedCoroutine struct { - entry func() - stack Stack + entry func() + stack Stack + resume bool } func init() { @@ -24,9 +25,10 @@ type Context[R, S any] struct { recv R send S - // Booleans managing the completion state of the coroutine. - done bool - stop bool + // Booleans managing the state of the coroutine. + done bool + stop bool + resume bool // Entry point of the coroutine, this is captured so the associated // generator can call into the coroutine to start or resume it at the @@ -39,8 +41,9 @@ type Context[R, S any] struct { // MarshalAppend appends a serialized Context to the provided buffer. func (c *Context[R, S]) MarshalAppend(b []byte) ([]byte, error) { s := serde.Serialize(&serializedCoroutine{ - entry: c.entry, - stack: c.Stack, + entry: c.entry, + stack: c.Stack, + resume: c.resume, }) return append(b, s...), nil } @@ -54,14 +57,15 @@ func (c *Context[R, S]) Unmarshal(b []byte) (int, error) { s := v.(*serializedCoroutine) c.entry = s.entry c.Stack = s.stack + c.resume = s.resume sn := start - len(b) return sn, nil } // TODO: do we have use cases for yielding more than one value? func (c *Context[R, S]) Yield(value R) S { - if frame := c.Top(); frame.Resume { - frame.Resume = false + if c.resume { + c.resume = false if c.stop { panic(unwind{}) } @@ -71,7 +75,7 @@ func (c *Context[R, S]) Yield(value R) S { panic("cannot yield from a coroutine that has been stopped") } var zero S - frame.Resume = true + c.resume = true c.send = zero c.recv = value panic(unwind{}) @@ -80,7 +84,7 @@ func (c *Context[R, S]) Yield(value R) S { // Unwinding returns true if the coroutine is currently unwinding its stack. func (c *Context[R, S]) Unwinding() bool { - return len(c.Frames) > 0 && c.Top().Resume + return c.resume } type unwind struct{} diff --git a/coroutine_durable.go b/coroutine_durable.go index 0939352..b1bd529 100644 --- a/coroutine_durable.go +++ b/coroutine_durable.go @@ -25,13 +25,14 @@ func (c Coroutine[R, S]) Next() (hasNext bool) { defer func() { clearContext(g) + switch err := recover(); err { + case nil: + case unwind{}: + default: + panic(err) + } + if c.ctx.Unwinding() { - switch err := recover(); err { - case nil: - case unwind{}: - default: - panic(err) - } stop := c.ctx.stop c.ctx.done, hasNext = stop, !stop } else { diff --git a/stack_durable.go b/stack_durable.go index d29a3c5..b1414dc 100644 --- a/stack_durable.go +++ b/stack_durable.go @@ -12,27 +12,24 @@ type Stack struct { Frames []Frame } -// Top returns the top of the call stack. -func (s *Stack) Top() *Frame { - if len(s.Frames) == 0 { - panic("no stack frames") - } - return &s.Frames[len(s.Frames)-1] -} - // Push prepares the stack for an impending function call. // -// The stack's frame pointer is incremented, and a Frame is pushed to the -// stack if the caller is on the topmost frame. +// The stack's frame pointer is incremented, and the stack is resized +// to make room for a new frame if the caller is on the topmost frame. // // If the caller is not on the topmost frame it means that a coroutine // is being resumed and the next frame is already present on the stack. -func (s *Stack) Push() *Frame { +// +// The Frame is returned by value rather than by reference, since the +// stack's underlying frame backing array might change. Callers +// intending to serialize the stack should call Store(fp, frame) for each +// frame during stack unwinding. +func (s *Stack) Push() (frame Frame, fp int) { if s.isTop() { s.Frames = append(s.Frames, Frame{}) } s.FP++ - return &s.Frames[s.FP] + return s.Frames[s.FP], s.FP } // Pop pops the topmost stack frame after a function call. @@ -44,6 +41,14 @@ func (s *Stack) Pop() { s.FP-- } +// Store stores a frame at the specified index. +func (s *Stack) Store(i int, f Frame) { + if i < 0 || i >= len(s.Frames) { + panic("invalid frame index") + } + s.Frames[i] = f +} + func (s *Stack) isTop() bool { return s.FP == len(s.Frames)-1 } @@ -59,8 +64,4 @@ type Frame struct { // Storage holds the Serializable objects on the frame. Storage - - // Resume is true if the function associated with the frame - // previously yielded. - Resume bool }