Skip to content

Commit

Permalink
Add imports to generated functypes file (#68)
Browse files Browse the repository at this point in the history
To properly add imports, the ident used in the generated types need to
be recorded in the package's TypeInfo. There is a risk of two files in
the same package trying to import different packages under the same
name. The compiler will stop and report the issue if that happens, but
we should consider adding a pre-pass to normalize all imports in the
package (remove dot imports, qualify them all with unique names, etc.)

Fixes #57
  • Loading branch information
pelletier authored Sep 24, 2023
2 parents e71df69 + a5e9391 commit 9eb606d
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 25 deletions.
22 changes: 15 additions & 7 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
return err
}

functypesFile := generateFunctypes(prog.Package(p.Types), colors)
functypesFile := generateFunctypes(p, prog.Package(p.Types), colors)
functypesPath := filepath.Join(packageDir, "coroutine_functypes.go")
if err := c.writeFile(functypesPath, functypesFile); err != nil {
return err
Expand All @@ -306,16 +306,20 @@ func addImports(p *packages.Package, gen *ast.File) *ast.File {
switch x := n.(type) {
case *ast.SelectorExpr:
ident, ok := x.X.(*ast.Ident)
if !ok {
if !ok || ident.Name == "" {
break
}

obj := p.TypesInfo.ObjectOf(ident)
pkgname, ok := obj.(*types.PkgName)
if !ok {
break
}

pkg := pkgname.Imported().Path()
if pkg == "" {
break
}

if existing, ok := imports[ident.Name]; ok && existing != pkg {
fmt.Println("existing:", ident.Name, existing)
Expand All @@ -327,6 +331,10 @@ func addImports(p *packages.Package, gen *ast.File) *ast.File {
return true
})

if len(imports) == 0 {
return gen
}

importspecs := make([]ast.Spec, 0, len(imports))
for name, path := range imports {
importspecs = append(importspecs, &ast.ImportSpec{
Expand Down Expand Up @@ -425,7 +433,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *
}

func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, color *types.Signature) *ast.BlockStmt {
body = desugar(p.Types, body, p.TypesInfo).(*ast.BlockStmt)
body = desugar(p, body).(*ast.BlockStmt)
body = astutil.Apply(body,
func(cursor *astutil.Cursor) bool {
switch n := cursor.Node().(type) {
Expand All @@ -450,8 +458,8 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
fp := ast.NewIdent("_fp")

yieldTypeExpr := make([]ast.Expr, 2)
yieldTypeExpr[0] = typeExpr(p.Types, color.Params().At(0).Type())
yieldTypeExpr[1] = typeExpr(p.Types, color.Results().At(0).Type())
yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type())
yieldTypeExpr[1] = typeExpr(p, color.Results().At(0).Type())

// _c := coroutine.LoadContext[R, S]()
gen.List = append(gen.List, &ast.AssignStmt{
Expand Down Expand Up @@ -494,7 +502,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
// declarations to the function prologue. We downgrade inline var decls and
// assignments that use := to assignments that use =. Constant decls are
// hoisted and also have their value assigned in the function prologue.
decls := extractDecls(p.Types, body, p.TypesInfo)
decls := extractDecls(p, body, p.TypesInfo)
renameObjects(body, p.TypesInfo, decls, scope)
for _, decl := range decls {
gen.List = append(gen.List, &ast.DeclStmt{Decl: decl})
Expand Down Expand Up @@ -547,7 +555,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
Lhs: []ast.Expr{name},
Tok: token.ASSIGN,
Rhs: []ast.Expr{
&ast.TypeAssertExpr{X: value, Type: typeExpr(p.Types, saveAndRestoreTypes[i])},
&ast.TypeAssertExpr{X: value, Type: typeExpr(p, saveAndRestoreTypes[i])},
},
},
},
Expand Down
6 changes: 6 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ func TestCoroutineYield(t *testing.T) {
-15, 15, 150, // type switch
},
},

{
name: "yield imported type time.Duration",
coro: YieldingDurations,
yields: []int{100, 101, 102, 103, 104, 105, 106, 107, 108, 109},
},
}

// This emulates the installation of function type information by the
Expand Down
3 changes: 2 additions & 1 deletion compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"go/types"

"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
)

// extractDecls extracts type, constant and variable declarations
Expand All @@ -20,7 +21,7 @@ import (
// Note that declarations are extracted from all nested scopes within the
// function body, so there may be duplicate identifiers. Identifiers can be
// disambiguated using (*types.Info).ObjectOf(ident).
func extractDecls(p *types.Package, tree ast.Node, info *types.Info) (decls []*ast.GenDecl) {
func extractDecls(p *packages.Package, tree ast.Node, info *types.Info) (decls []*ast.GenDecl) {
ast.Inspect(tree, func(node ast.Node) bool {
switch n := node.(type) {
case *ast.FuncLit:
Expand Down
6 changes: 4 additions & 2 deletions compiler/desugar.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strconv"

"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
)

// desugar recursively replaces sugared AST nodes with simpler constructs.
Expand Down Expand Up @@ -39,7 +40,8 @@ import (
// types.Info. If this gets unruly in the future, desugaring should be
// performed after parsing AST's but before type checking so that this is
// done automatically by the type checker.
func desugar(p *types.Package, stmt ast.Stmt, info *types.Info) ast.Stmt {
func desugar(p *packages.Package, stmt ast.Stmt) ast.Stmt {
info := p.TypesInfo
d := desugarer{pkg: p, info: info}
stmt = d.desugar(stmt, nil, nil, nil)

Expand All @@ -56,7 +58,7 @@ func desugar(p *types.Package, stmt ast.Stmt, info *types.Info) ast.Stmt {
}

type desugarer struct {
pkg *types.Package
pkg *packages.Package
info *types.Info
vars int
labels int
Expand Down
4 changes: 3 additions & 1 deletion compiler/desugar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"

"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
)

func TestDesugar(t *testing.T) {
Expand Down Expand Up @@ -1204,7 +1205,8 @@ _l0:
}
return true
})
desugared := desugar(nil, body, info)
p := &packages.Package{TypesInfo: info}
desugared := desugar(p, body)
desugared = unnestBlocks(desugared)

expect := strings.TrimSpace(test.expect)
Expand Down
21 changes: 13 additions & 8 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (
"slices"
"strconv"

"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa"
)

func generateFunctypes(pkg *ssa.Package, colors functionColors) *ast.File {
func generateFunctypes(p *packages.Package, pkg *ssa.Package, colors functionColors) *ast.File {
var names = make([]string, 0, len(pkg.Members))
for name := range pkg.Members {
names = append(names, name)
Expand All @@ -22,11 +23,11 @@ func generateFunctypes(pkg *ssa.Package, colors functionColors) *ast.File {
for _, name := range names {
if fn, ok := pkg.Members[name].(*ssa.Function); ok {
name := pkg.Pkg.Path() + "." + fn.Name()
generateFunctypesInit(pkg, fn, &init, name, colors)
generateFunctypesInit(p, pkg, fn, &init, name, colors)
}
}

return &ast.File{
gen := &ast.File{
Name: ast.NewIdent(pkg.Pkg.Name()),
Decls: []ast.Decl{
&ast.GenDecl{
Expand All @@ -48,9 +49,13 @@ func generateFunctypes(pkg *ssa.Package, colors functionColors) *ast.File {
},
},
}

gen = addImports(p, gen)

return gen
}

func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockStmt, name string, colors functionColors) {
func generateFunctypesInit(p *packages.Package, pkg *ssa.Package, fn *ssa.Function, init *ast.BlockStmt, name string, colors functionColors) {
if fn.TypeParams() != nil {
return // ignore non-instantiated generic functions
}
Expand All @@ -62,7 +67,7 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt
Sel: ast.NewIdent("RegisterFunc"),
},
Indices: []ast.Expr{
newFuncType(pkg.Pkg, fn.Signature),
newFuncType(p, fn.Signature),
},
}
} else {
Expand All @@ -76,7 +81,7 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt
for i, freeVar := range fn.FreeVars {
fields[i+1] = &ast.Field{
Names: []*ast.Ident{ast.NewIdent(freeVar.Name())},
Type: typeExpr(pkg.Pkg, freeVar.Type()),
Type: typeExpr(p, freeVar.Type()),
}
}

Expand All @@ -86,7 +91,7 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt
Sel: ast.NewIdent("RegisterClosure"),
},
Indices: []ast.Expr{
newFuncType(pkg.Pkg, fn.Signature),
newFuncType(p, fn.Signature),
&ast.StructType{
Fields: &ast.FieldList{
List: fields,
Expand Down Expand Up @@ -125,7 +130,7 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt
for _, anonFunc := range anonFuncs {
index++
anonFuncName := anonFuncLinkName(name, index)
generateFunctypesInit(pkg, anonFunc, init, anonFuncName, colors)
generateFunctypesInit(p, pkg, anonFunc, init, anonFuncName, colors)
}
}

Expand Down
14 changes: 14 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,17 @@ func b(v int) int {
coroutine.Yield[int, any](-v)
return v
}

func YieldingDurations() {
t := new(time.Duration)
*t = time.Duration(100)

f := func() {
i := int(t.Nanoseconds())
*t = time.Duration(i + 1)
coroutine.Yield[int, any](i)
}
for i := 0; i < 10; i++ {
f()
}
}
136 changes: 136 additions & 0 deletions compiler/testdata/coroutine_durable.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 9eb606d

Please sign in to comment.