Skip to content

Commit

Permalink
Support basic methods (#73)
Browse files Browse the repository at this point in the history
This PR adds basic support for methods.
  • Loading branch information
chriso authored Sep 25, 2023
2 parents 8be7c63 + c9d1498 commit f4ad33e
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 7 deletions.
9 changes: 5 additions & 4 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,11 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color
// as the source function (and require that the caller use build tags
// to disambiguate function calls).
gen := &ast.FuncDecl{
Recv: fn.Recv,
Doc: &ast.CommentGroup{},
Name: fn.Name,
Type: funcTypeWithNamedResults(fn.Type),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, color),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, fn.Recv, color),
}

// If the function declaration contains function literals, we have to
Expand Down Expand Up @@ -407,7 +408,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *

gen := &ast.FuncLit{
Type: funcTypeWithNamedResults(fn.Type),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, color),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color),
}

if !isExpr(gen.Body) {
Expand All @@ -416,7 +417,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *
return gen
}

func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, color *types.Signature) *ast.BlockStmt {
func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, recv *ast.FieldList, color *types.Signature) *ast.BlockStmt {
var defers *ast.Ident

mayYield := findCalls(body, p.TypesInfo)
Expand Down Expand Up @@ -516,7 +517,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
// declarations to the function prologue. We downgrade inline var decls and
// assignments that use := to assignments that use =. Constant decls are
// hoisted and also have their value assigned in the function prologue.
decls, frameType, frameInit := extractDecls(p, typ, body, defers, p.TypesInfo)
decls, frameType, frameInit := extractDecls(p, typ, body, recv, defers, p.TypesInfo)
renameObjects(body, p.TypesInfo, decls, frameName, frameType, frameInit, scope)

for _, decl := range decls {
Expand Down
6 changes: 6 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,12 @@ func TestCoroutineYield(t *testing.T) {
coro: YieldingDurations,
yields: []int{100, 101, 102, 103, 104, 105, 106, 107, 108, 109},
},

{
name: "methods",
coro: func() { var s MethodGeneratorState; s.MethodGenerator(5) },
yields: []int{0, 1, 2, 3, 4, 5},
},
}

// This emulates the installation of function type information by the
Expand Down
19 changes: 18 additions & 1 deletion compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,27 @@ import (
// Note that declarations are extracted from all nested scopes within the
// function body, so there may be duplicate identifiers. Identifiers can be
// disambiguated using (*types.Info).ObjectOf(ident).
func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, defers *ast.Ident, info *types.Info) (decls []*ast.GenDecl, frameType *ast.StructType, frameInit *ast.CompositeLit) {
func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, recv *ast.FieldList, defers *ast.Ident, info *types.Info) (decls []*ast.GenDecl, frameType *ast.StructType, frameInit *ast.CompositeLit) {
frameType = &ast.StructType{Fields: &ast.FieldList{}}
frameInit = &ast.CompositeLit{Type: frameType}

if recv != nil {
for _, field := range recv.List {
for _, ident := range field.Names {
if ident.Name != "_" {
frameType.Fields.List = append(frameType.Fields.List, &ast.Field{
Names: []*ast.Ident{ident},
Type: field.Type,
})
frameInit.Elts = append(frameInit.Elts, &ast.KeyValueExpr{
Key: ident,
Value: ident,
})
}
}
}
}

if typ.Params != nil {
for _, field := range typ.Params.List {
for _, ident := range field.Names {
Expand Down
8 changes: 8 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,11 @@ func RangeYieldAndDeferAssign(n int) {
YieldAndDeferAssign(&i, i, i+1)
}
}

type MethodGeneratorState struct{ i int }

func (s *MethodGeneratorState) MethodGenerator(n int) {
for s.i = 0; s.i <= n; s.i++ {
coroutine.Yield[int, any](s.i)
}
}
44 changes: 44 additions & 0 deletions compiler/testdata/coroutine_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -3083,13 +3083,57 @@ func RangeYieldAndDeferAssign(n int) {
}
}
}

type MethodGeneratorState struct {
i int
}

//go:noinline
func (s *MethodGeneratorState) MethodGenerator(n int) {
_c := coroutine.LoadContext[int, any]()
_f, _fp := _c.Push()
var _f0 *struct {
X0 *MethodGeneratorState
X1 int
}
if _f.IP == 0 {
_f0 = &struct {
X0 *MethodGeneratorState
X1 int
}{X0: s, X1: n}
} else {
_f0 = _f.Get(0).(*struct {
X0 *MethodGeneratorState
X1 int
})
}
defer func() {
if _c.Unwinding() {
_f.Set(0, _f0)
_c.Store(_fp, _f)
} else {
_c.Pop()
}
}()
switch {
case _f.IP < 2:
_f0.X0.i = 0
_f.IP = 2
fallthrough
case _f.IP < 3:
for ; _f0.X0.i <= _f0.X1; _f0.X0.i, _f.IP = _f0.X0.i+1, 2 {
coroutine.Yield[int, any](_f0.X0.i)
}
}
}
func init() {
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzIfGenerator")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzSwitchGenerator")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Identity")
_types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.MethodGenerator")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops")
_types.RegisterFunc[func(n int, do func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.Range")
_types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureCapturingPointers")
Expand Down
5 changes: 3 additions & 2 deletions compiler/testdata/http/main_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ package main
import (
http "net/http"
coroutine "github.com/stealthrocket/coroutine"
coroutine "github.com/stealthrocket/coroutine"
fmt "fmt"
_types "github.com/stealthrocket/coroutine/types"
)
import _types "github.com/stealthrocket/coroutine/types"

type yieldingRoundTripper struct {
}
//go:noinline
func RoundTrip(req *http.Request) (_ *http.Response, _ error) {
func (*yieldingRoundTripper) RoundTrip(req *http.Request) (_ *http.Response, _ error) {
_c := coroutine.LoadContext[*http.Request, *http.Response]()
_f, _fp := _c.Push()
var _f0 *struct {
Expand Down

0 comments on commit f4ad33e

Please sign in to comment.