From 8fddd77617819a71ab62542b7516e9a021ca5d55 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Thu, 14 Dec 2023 16:19:12 +1000 Subject: [PATCH 1/4] Create a test to highlight the issue --- compiler/coroutine_test.go | 6 ++++++ compiler/testdata/coroutine.go | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index d6e92c8..a050592 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -215,6 +215,12 @@ func TestCoroutineYield(t *testing.T) { result: 42, }, + { + name: "closure capturing receiver and param", + coro: func() { StructClosure(0, 3) }, + yields: []int{-1, 0, 1, 2}, + }, + { name: "identity generic", coro: func() { IdentityGeneric[int](11) }, diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 92d7218..cbeb9d4 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -577,3 +577,23 @@ func IdentityGenericInt(n int) { func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() } + +type Box struct { + n int +} + +func (b *Box) Closure() func() { + coroutine.Yield[int, any](-1) + return func() { + coroutine.Yield[int, any](b.n) + b.n++ + } +} + +func StructClosure(n, count int) { + box := Box{n} + fn := box.Closure() + for i := 0; i < count; i++ { + fn() + } +} From 3ef4d26aae2824c2c704df557911412485809e4e Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Thu, 14 Dec 2023 16:20:33 +1000 Subject: [PATCH 2/4] Make sure receivers are accounted for when registering func/closure types --- compiler/function.go | 44 ++++++++- compiler/testdata/coroutine_durable.go | 129 ++++++++++++++++++++++++- 2 files changed, 170 insertions(+), 3 deletions(-) diff --git a/compiler/function.go b/compiler/function.go index e468bd0..49222eb 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -9,6 +9,7 @@ import ( "maps" "slices" "strconv" + "strings" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" @@ -82,7 +83,8 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func } signature := functionTypeOf(fn) - for _, fields := range []*ast.FieldList{signature.Params, signature.Results} { + recv := functionRecvOf(fn) + for _, fields := range []*ast.FieldList{recv, signature.Params, signature.Results} { if fields != nil { for _, field := range fields.List { for _, name := range field.Names { @@ -205,7 +207,34 @@ func packagePath(p *packages.Package) string { } func functionPath(p *packages.Package, f *ast.FuncDecl) string { - return packagePath(p) + "." + f.Name.Name + var b strings.Builder + b.WriteString(packagePath(p)) + if f.Recv != nil { + b.WriteByte('.') + // https://go.dev/ref/spec#Method_declarations + // > That parameter section must declare a single non-variadic parameter + recvType := p.TypesInfo.Defs[f.Name].Type().(*types.Signature).Recv().Type() + isptr := false + if ptr, ok := recvType.(*types.Pointer); ok { + recvType = ptr.Elem() + isptr = true + } + if isptr { + b.WriteString("(*") + } + switch t := recvType.(type) { + case *types.Named: + b.WriteString(t.Obj().Name()) + default: + panic(fmt.Sprintf("not implemented: %T", t)) + } + if isptr { + b.WriteByte(')') + } + } + b.WriteByte('.') + b.WriteString(f.Name.Name) + return b.String() } func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors map[ast.Node]*types.Signature) { @@ -302,6 +331,17 @@ func functionTypeOf(fn ast.Node) *ast.FuncType { } } +func functionRecvOf(fn ast.Node) *ast.FieldList { + switch f := fn.(type) { + case *ast.FuncDecl: + return f.Recv + case *ast.FuncLit: + return nil + default: + panic("node is neither *ast.FuncDecl or *ast.FuncLit") + } +} + func functionBodyOf(fn ast.Node) *ast.BlockStmt { switch f := fn.(type) { case *ast.FuncDecl: diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index f13ce47..8689b57 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -3252,7 +3252,134 @@ func IdentityGenericInt(n int) { IdentityGeneric[int](n) } //go:noinline func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() } + +type Box struct { + n int +} + +//go:noinline +func (_fn0 *Box) Closure() (_ func()) { + _c := coroutine.LoadContext[int, any]() + var _f1 *struct { + IP int + X0 *Box + } = coroutine.Push[struct { + IP int + X0 *Box + }](&_c.Stack) + if _f1.IP == 0 { + *_f1 = struct { + IP int + X0 *Box + }{X0: _fn0} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f1.IP < 2: + coroutine.Yield[int, any](-1) + _f1.IP = 2 + fallthrough + case _f1.IP < 3: + return func() { + _c := coroutine.LoadContext[int, any]() + var _f0 *struct { + IP int + } = coroutine.Push[struct { + IP int + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + }{} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f0.IP < 2: + coroutine.Yield[int, any](_f1.X0.n) + _f0.IP = 2 + fallthrough + case _f0.IP < 3: + _f1.X0. + n++ + } + } + } + panic("unreachable") +} + +//go:noinline +func StructClosure(_fn0, _fn1 int) { + _c := coroutine.LoadContext[int, any]() + var _f0 *struct { + IP int + X0 int + X1 int + X2 Box + X3 func() + X4 int + } = coroutine.Push[struct { + IP int + X0 int + X1 int + X2 Box + X3 func() + X4 int + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + X0 int + X1 int + X2 Box + X3 func() + X4 int + }{X0: _fn0, X1: _fn1} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f0.IP < 2: + _f0.X2 = Box{_f0.X0} + _f0.IP = 2 + fallthrough + case _f0.IP < 3: + _f0.X3 = _f0.X2.Closure() + _f0.IP = 3 + fallthrough + case _f0.IP < 5: + switch { + case _f0.IP < 4: + _f0.X4 = 0 + _f0.IP = 4 + fallthrough + case _f0.IP < 5: + for ; _f0.X4 < _f0.X1; _f0.X4, _f0.IP = _f0.X4+1, 4 { + _f0.X3() + } + } + } +} func init() { + _types.RegisterFunc[func() (_ func())]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure") + _types.RegisterClosure[func(), struct { + F uintptr + X0 *struct { + IP int + X0 *Box + } + }]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure.func2") + _types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.(*MethodGeneratorState).MethodGenerator") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzIfGenerator") @@ -3261,7 +3388,6 @@ func init() { _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericInt") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericStructInt") _types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue") - _types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.MethodGenerator") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops") _types.RegisterFunc[func(_fn0 int, _fn1 func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.Range") _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureCapturingPointers") @@ -3357,6 +3483,7 @@ func init() { _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwice") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwiceLoop") + _types.RegisterFunc[func(_fn0, _fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.StructClosure") _types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.TypeSwitchingGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.VarArgs") _types.RegisterFunc[func(_fn0 *int, _fn1, _fn2 int)]("github.com/stealthrocket/coroutine/compiler/testdata.YieldAndDeferAssign") From ad6ce047a98ad1f342ec3fbe7c41e36e43b8ba9c Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Thu, 14 Dec 2023 16:23:42 +1000 Subject: [PATCH 3/4] Write a test where both the receiver and a param are captured --- compiler/coroutine_test.go | 4 +- compiler/testdata/coroutine.go | 27 +++++---- compiler/testdata/coroutine_durable.go | 80 ++++++++++++++++---------- 3 files changed, 70 insertions(+), 41 deletions(-) diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index a050592..a278da2 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -217,8 +217,8 @@ func TestCoroutineYield(t *testing.T) { { name: "closure capturing receiver and param", - coro: func() { StructClosure(0, 3) }, - yields: []int{-1, 0, 1, 2}, + coro: func() { StructClosure(3) }, + yields: []int{-1, 10, 100, 1000, 11, 101, 1000, 12, 102, 1000}, }, { diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index cbeb9d4..97fc7fd 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -579,21 +579,28 @@ func IdentityGenericStructInt(n int) { } type Box struct { - n int + x int } -func (b *Box) Closure() func() { +func (b *Box) Closure(y int) func(int) { + // Force compilation of this method and the closure within. + // Remove once #84 is fixed. coroutine.Yield[int, any](-1) - return func() { - coroutine.Yield[int, any](b.n) - b.n++ + + return func(z int) { + coroutine.Yield[int, any](b.x) + coroutine.Yield[int, any](y) + coroutine.Yield[int, any](z) + b.x++ + y++ + z++ // mutation is lost } } -func StructClosure(n, count int) { - box := Box{n} - fn := box.Closure() - for i := 0; i < count; i++ { - fn() +func StructClosure(n int) { + box := Box{10} + fn := box.Closure(100) + for i := 0; i < n; i++ { + fn(1000) } } diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 8689b57..0e4b610 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -3254,24 +3254,27 @@ func IdentityGenericInt(n int) { IdentityGeneric[int](n) } func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() } type Box struct { - n int + x int } //go:noinline -func (_fn0 *Box) Closure() (_ func()) { +func (_fn0 *Box) Closure(_fn1 int) (_ func(int)) { _c := coroutine.LoadContext[int, any]() var _f1 *struct { IP int X0 *Box + X1 int } = coroutine.Push[struct { IP int X0 *Box + X1 int }](&_c.Stack) if _f1.IP == 0 { *_f1 = struct { IP int X0 *Box - }{X0: _fn0} + X1 int + }{X0: _fn0, X1: _fn1} } defer func() { if !_c.Unwinding() { @@ -3280,21 +3283,26 @@ func (_fn0 *Box) Closure() (_ func()) { }() switch { case _f1.IP < 2: + coroutine.Yield[int, any](-1) _f1.IP = 2 fallthrough case _f1.IP < 3: - return func() { + + return func(_fn0 int) { _c := coroutine.LoadContext[int, any]() var _f0 *struct { IP int + X0 int } = coroutine.Push[struct { IP int + X0 int }](&_c.Stack) if _f0.IP == 0 { *_f0 = struct { IP int - }{} + X0 int + }{X0: _fn0} } defer func() { if !_c.Unwinding() { @@ -3303,12 +3311,28 @@ func (_fn0 *Box) Closure() (_ func()) { }() switch { case _f0.IP < 2: - coroutine.Yield[int, any](_f1.X0.n) + coroutine.Yield[int, any](_f1.X0.x) _f0.IP = 2 fallthrough case _f0.IP < 3: + coroutine.Yield[int, any](_f1.X1) + _f0.IP = 3 + fallthrough + case _f0.IP < 4: + coroutine.Yield[int, any](_f0.X0) + _f0.IP = 4 + fallthrough + case _f0.IP < 5: _f1.X0. - n++ + x++ + _f0.IP = 5 + fallthrough + case _f0.IP < 6: + _f1.X1++ + _f0.IP = 6 + fallthrough + case _f0.IP < 7: + _f0.X0++ } } } @@ -3316,32 +3340,29 @@ func (_fn0 *Box) Closure() (_ func()) { } //go:noinline -func StructClosure(_fn0, _fn1 int) { +func StructClosure(_fn0 int) { _c := coroutine.LoadContext[int, any]() var _f0 *struct { IP int X0 int - X1 int - X2 Box - X3 func() - X4 int + X1 Box + X2 func(int) + X3 int } = coroutine.Push[struct { IP int X0 int - X1 int - X2 Box - X3 func() - X4 int + X1 Box + X2 func(int) + X3 int }](&_c.Stack) if _f0.IP == 0 { *_f0 = struct { IP int X0 int - X1 int - X2 Box - X3 func() - X4 int - }{X0: _fn0, X1: _fn1} + X1 Box + X2 func(int) + X3 int + }{X0: _fn0} } defer func() { if !_c.Unwinding() { @@ -3350,33 +3371,34 @@ func StructClosure(_fn0, _fn1 int) { }() switch { case _f0.IP < 2: - _f0.X2 = Box{_f0.X0} + _f0.X1 = Box{10} _f0.IP = 2 fallthrough case _f0.IP < 3: - _f0.X3 = _f0.X2.Closure() + _f0.X2 = _f0.X1.Closure(100) _f0.IP = 3 fallthrough case _f0.IP < 5: switch { case _f0.IP < 4: - _f0.X4 = 0 + _f0.X3 = 0 _f0.IP = 4 fallthrough case _f0.IP < 5: - for ; _f0.X4 < _f0.X1; _f0.X4, _f0.IP = _f0.X4+1, 4 { - _f0.X3() + for ; _f0.X3 < _f0.X0; _f0.X3, _f0.IP = _f0.X3+1, 4 { + _f0.X2(1000) } } } } func init() { - _types.RegisterFunc[func() (_ func())]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure") - _types.RegisterClosure[func(), struct { + _types.RegisterFunc[func(_fn1 int) (_ func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure") + _types.RegisterClosure[func(_fn0 int), struct { F uintptr X0 *struct { IP int X0 *Box + X1 int } }]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure.func2") _types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.(*MethodGeneratorState).MethodGenerator") @@ -3483,7 +3505,7 @@ func init() { _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwice") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwiceLoop") - _types.RegisterFunc[func(_fn0, _fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.StructClosure") + _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.StructClosure") _types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.TypeSwitchingGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.VarArgs") _types.RegisterFunc[func(_fn0 *int, _fn1, _fn2 int)]("github.com/stealthrocket/coroutine/compiler/testdata.YieldAndDeferAssign") From b6322fa93f6bb58c83d8b074e0cf313299d1a6a0 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Thu, 14 Dec 2023 16:26:53 +1000 Subject: [PATCH 4/4] Remove outdated comment --- compiler/function.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/compiler/function.go b/compiler/function.go index 49222eb..7a4a4f8 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -210,15 +210,14 @@ func functionPath(p *packages.Package, f *ast.FuncDecl) string { var b strings.Builder b.WriteString(packagePath(p)) if f.Recv != nil { - b.WriteByte('.') - // https://go.dev/ref/spec#Method_declarations - // > That parameter section must declare a single non-variadic parameter - recvType := p.TypesInfo.Defs[f.Name].Type().(*types.Signature).Recv().Type() + signature := p.TypesInfo.Defs[f.Name].Type().(*types.Signature) + recvType := signature.Recv().Type() isptr := false if ptr, ok := recvType.(*types.Pointer); ok { recvType = ptr.Elem() isptr = true } + b.WriteByte('.') if isptr { b.WriteString("(*") }