diff --git a/compiler/compile.go b/compiler/compile.go index 4ecfcb9..d9d148a 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -371,10 +371,11 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color // as the source function (and require that the caller use build tags // to disambiguate function calls). gen := &ast.FuncDecl{ + Recv: fn.Recv, Doc: &ast.CommentGroup{}, Name: fn.Name, Type: funcTypeWithNamedResults(fn.Type), - Body: scope.compileFuncBody(p, fn.Type, fn.Body, color), + Body: scope.compileFuncBody(p, fn.Type, fn.Body, fn.Recv, color), } // If the function declaration contains function literals, we have to @@ -407,7 +408,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color * gen := &ast.FuncLit{ Type: funcTypeWithNamedResults(fn.Type), - Body: scope.compileFuncBody(p, fn.Type, fn.Body, color), + Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color), } if !isExpr(gen.Body) { @@ -416,7 +417,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color * return gen } -func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, color *types.Signature) *ast.BlockStmt { +func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, recv *ast.FieldList, color *types.Signature) *ast.BlockStmt { var defers *ast.Ident mayYield := findCalls(body, p.TypesInfo) @@ -516,7 +517,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body // declarations to the function prologue. We downgrade inline var decls and // assignments that use := to assignments that use =. Constant decls are // hoisted and also have their value assigned in the function prologue. - decls, frameType, frameInit := extractDecls(p, typ, body, defers, p.TypesInfo) + decls, frameType, frameInit := extractDecls(p, typ, body, recv, defers, p.TypesInfo) renameObjects(body, p.TypesInfo, decls, frameName, frameType, frameInit, scope) for _, decl := range decls { diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index ab1dd64..1ff8b33 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -186,6 +186,12 @@ func TestCoroutineYield(t *testing.T) { coro: YieldingDurations, yields: []int{100, 101, 102, 103, 104, 105, 106, 107, 108, 109}, }, + + { + name: "methods", + coro: func() { var s MethodGeneratorState; s.MethodGenerator(5) }, + yields: []int{0, 1, 2, 3, 4, 5}, + }, } // This emulates the installation of function type information by the diff --git a/compiler/decls.go b/compiler/decls.go index 41e38a1..b48f12b 100644 --- a/compiler/decls.go +++ b/compiler/decls.go @@ -22,10 +22,27 @@ import ( // Note that declarations are extracted from all nested scopes within the // function body, so there may be duplicate identifiers. Identifiers can be // disambiguated using (*types.Info).ObjectOf(ident). -func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, defers *ast.Ident, info *types.Info) (decls []*ast.GenDecl, frameType *ast.StructType, frameInit *ast.CompositeLit) { +func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, recv *ast.FieldList, defers *ast.Ident, info *types.Info) (decls []*ast.GenDecl, frameType *ast.StructType, frameInit *ast.CompositeLit) { frameType = &ast.StructType{Fields: &ast.FieldList{}} frameInit = &ast.CompositeLit{Type: frameType} + if recv != nil { + for _, field := range recv.List { + for _, ident := range field.Names { + if ident.Name != "_" { + frameType.Fields.List = append(frameType.Fields.List, &ast.Field{ + Names: []*ast.Ident{ident}, + Type: field.Type, + }) + frameInit.Elts = append(frameInit.Elts, &ast.KeyValueExpr{ + Key: ident, + Value: ident, + }) + } + } + } + } + if typ.Params != nil { for _, field := range typ.Params.List { for _, ident := range field.Names { diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index dfbbded..c76c483 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -525,3 +525,11 @@ func RangeYieldAndDeferAssign(n int) { YieldAndDeferAssign(&i, i, i+1) } } + +type MethodGeneratorState struct{ i int } + +func (s *MethodGeneratorState) MethodGenerator(n int) { + for s.i = 0; s.i <= n; s.i++ { + coroutine.Yield[int, any](s.i) + } +} diff --git a/compiler/testdata/coroutine_generated.go b/compiler/testdata/coroutine_generated.go index 570d09d..5526102 100644 --- a/compiler/testdata/coroutine_generated.go +++ b/compiler/testdata/coroutine_generated.go @@ -3083,6 +3083,49 @@ func RangeYieldAndDeferAssign(n int) { } } } + +type MethodGeneratorState struct { + i int +} + +//go:noinline +func (s *MethodGeneratorState) MethodGenerator(n int) { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _f0 *struct { + X0 *MethodGeneratorState + X1 int + } + if _f.IP == 0 { + _f0 = &struct { + X0 *MethodGeneratorState + X1 int + }{X0: s, X1: n} + } else { + _f0 = _f.Get(0).(*struct { + X0 *MethodGeneratorState + X1 int + }) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _f0) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _f0.X0.i = 0 + _f.IP = 2 + fallthrough + case _f.IP < 3: + for ; _f0.X0.i <= _f0.X1; _f0.X0.i, _f.IP = _f0.X0.i+1, 2 { + coroutine.Yield[int, any](_f0.X0.i) + } + } +} func init() { _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator") @@ -3090,6 +3133,7 @@ func init() { _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzSwitchGenerator") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Identity") _types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue") + _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.MethodGenerator") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops") _types.RegisterFunc[func(n int, do func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.Range") _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureCapturingPointers")