From b9c58edbede3396929770c9241272d327e0486be Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Sat, 23 Sep 2023 16:18:10 -0700 Subject: [PATCH] Don't include package qualifier when type is in same package --- compiler/compile.go | 10 +++---- compiler/decls.go | 6 ++-- compiler/desugar.go | 9 +++--- compiler/desugar_test.go | 4 +-- compiler/function.go | 7 ++--- compiler/testdata/http/coroc_generated.go | 28 +++++++++--------- compiler/types.go | 35 +++++++++++------------ 7 files changed, 49 insertions(+), 50 deletions(-) diff --git a/compiler/compile.go b/compiler/compile.go index de00f0b..143ad7c 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -391,7 +391,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color * } func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, color *types.Signature) *ast.BlockStmt { - body = desugar(body, p.TypesInfo).(*ast.BlockStmt) + body = desugar(p.Types, body, p.TypesInfo).(*ast.BlockStmt) body = astutil.Apply(body, func(cursor *astutil.Cursor) bool { switch n := cursor.Node().(type) { @@ -416,8 +416,8 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body fp := ast.NewIdent("_fp") yieldTypeExpr := make([]ast.Expr, 2) - yieldTypeExpr[0] = typeExpr(color.Params().At(0).Type()) - yieldTypeExpr[1] = typeExpr(color.Results().At(0).Type()) + yieldTypeExpr[0] = typeExpr(p.Types, color.Params().At(0).Type()) + yieldTypeExpr[1] = typeExpr(p.Types, color.Results().At(0).Type()) // _c := coroutine.LoadContext[R, S]() gen.List = append(gen.List, &ast.AssignStmt{ @@ -460,7 +460,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body // declarations to the function prologue. We downgrade inline var decls and // assignments that use := to assignments that use =. Constant decls are // hoisted and also have their value assigned in the function prologue. - decls := extractDecls(body, p.TypesInfo) + decls := extractDecls(p.Types, body, p.TypesInfo) renameObjects(body, p.TypesInfo, decls, scope) for _, decl := range decls { gen.List = append(gen.List, &ast.DeclStmt{Decl: decl}) @@ -513,7 +513,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body Lhs: []ast.Expr{name}, Tok: token.ASSIGN, Rhs: []ast.Expr{ - &ast.TypeAssertExpr{X: value, Type: typeExpr(saveAndRestoreTypes[i])}, + &ast.TypeAssertExpr{X: value, Type: typeExpr(p.Types, saveAndRestoreTypes[i])}, }, }, }, diff --git a/compiler/decls.go b/compiler/decls.go index a1ef552..f09a98c 100644 --- a/compiler/decls.go +++ b/compiler/decls.go @@ -20,7 +20,7 @@ import ( // Note that declarations are extracted from all nested scopes within the // function body, so there may be duplicate identifiers. Identifiers can be // disambiguated using (*types.Info).ObjectOf(ident). -func extractDecls(tree ast.Node, info *types.Info) (decls []*ast.GenDecl) { +func extractDecls(p *types.Package, tree ast.Node, info *types.Info) (decls []*ast.GenDecl) { ast.Inspect(tree, func(node ast.Node) bool { switch n := node.(type) { case *ast.FuncLit: @@ -46,7 +46,7 @@ func extractDecls(tree ast.Node, info *types.Info) (decls []*ast.GenDecl) { Specs: []ast.Spec{ &ast.ValueSpec{ Names: []*ast.Ident{name}, - Type: typeExpr(info.TypeOf(name)), + Type: typeExpr(p, info.TypeOf(name)), }, }, }) @@ -80,7 +80,7 @@ func extractDecls(tree ast.Node, info *types.Info) (decls []*ast.GenDecl) { Specs: []ast.Spec{ &ast.ValueSpec{ Names: []*ast.Ident{name}, - Type: typeExpr(t), + Type: typeExpr(p, t), }, }, }) diff --git a/compiler/desugar.go b/compiler/desugar.go index 778d027..fe1c863 100644 --- a/compiler/desugar.go +++ b/compiler/desugar.go @@ -39,8 +39,8 @@ import ( // types.Info. If this gets unruly in the future, desugaring should be // performed after parsing AST's but before type checking so that this is // done automatically by the type checker. -func desugar(stmt ast.Stmt, info *types.Info) ast.Stmt { - d := desugarer{info: info} +func desugar(p *types.Package, stmt ast.Stmt, info *types.Info) ast.Stmt { + d := desugarer{pkg: p, info: info} stmt = d.desugar(stmt, nil, nil, nil) // Unused labels cause a compile error (label X defined and not used) @@ -56,6 +56,7 @@ func desugar(stmt ast.Stmt, info *types.Info) ast.Stmt { } type desugarer struct { + pkg *types.Package info *types.Info vars int labels int @@ -253,7 +254,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I &ast.CallExpr{ Fun: d.builtin("make"), Args: []ast.Expr{ - typeExpr(keySliceType), + typeExpr(d.pkg, keySliceType), &ast.BasicLit{Kind: token.INT, Value: "0"}, &ast.CallExpr{Fun: d.builtin("len"), Args: []ast.Expr{x}}, }, @@ -385,7 +386,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I Specs: []ast.Spec{ &ast.ValueSpec{ Names: []*ast.Ident{tmpLhs}, - Type: typeExpr(lhsType), + Type: typeExpr(d.pkg, lhsType), }, }, }}) diff --git a/compiler/desugar_test.go b/compiler/desugar_test.go index bd940c1..9ede13b 100644 --- a/compiler/desugar_test.go +++ b/compiler/desugar_test.go @@ -1204,7 +1204,7 @@ _l0: } return true }) - desugared := desugar(body, info) + desugared := desugar(nil, body, info) desugared = unnestBlocks(desugared) expect := strings.TrimSpace(test.expect) @@ -1220,7 +1220,7 @@ _l0: func formatNode(node ast.Node) string { fset := token.NewFileSet() - //ast.Print(fset, node) + // ast.Print(fset, node) var b bytes.Buffer if err := format.Node(&b, fset, node); err != nil { panic(err) diff --git a/compiler/function.go b/compiler/function.go index 2279dba..e957357 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -54,7 +54,6 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt if fn.TypeParams() != nil { return // ignore non-instantiated generic functions } - var register ast.Expr if len(fn.FreeVars) == 0 { register = &ast.IndexListExpr{ @@ -63,7 +62,7 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt Sel: ast.NewIdent("RegisterFunc"), }, Indices: []ast.Expr{ - newFuncType(fn.Signature), + newFuncType(pkg.Pkg, fn.Signature), }, } } else { @@ -77,7 +76,7 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt for i, freeVar := range fn.FreeVars { fields[i+1] = &ast.Field{ Names: []*ast.Ident{ast.NewIdent(freeVar.Name())}, - Type: typeExpr(freeVar.Type()), + Type: typeExpr(pkg.Pkg, freeVar.Type()), } } @@ -87,7 +86,7 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt Sel: ast.NewIdent("RegisterClosure"), }, Indices: []ast.Expr{ - newFuncType(fn.Signature), + newFuncType(pkg.Pkg, fn.Signature), &ast.StructType{ Fields: &ast.FieldList{ List: fields, diff --git a/compiler/testdata/http/coroc_generated.go b/compiler/testdata/http/coroc_generated.go index 60edbdc..7bcb0ea 100644 --- a/compiler/testdata/http/coroc_generated.go +++ b/compiler/testdata/http/coroc_generated.go @@ -41,11 +41,7 @@ import ( big "math/big" rand "math/rand" atomic "sync/atomic" - asn1_1 "encoding/asn1" - pem "encoding/pem" - hex "encoding/hex" - base64 "encoding/base64" - asn1 "vendor/golang.org/x/crypto/cryptobyte/asn1" + asn1_1 "vendor/golang.org/x/crypto/cryptobyte/asn1" dnsmessage "vendor/golang.org/x/net/dns/dnsmessage" route "vendor/golang.org/x/net/route" idna "vendor/golang.org/x/net/idna" @@ -57,6 +53,10 @@ import ( transform "vendor/golang.org/x/text/transform" norm "vendor/golang.org/x/text/unicode/norm" bidirule "vendor/golang.org/x/text/secure/bidirule" + asn1 "encoding/asn1" + pem "encoding/pem" + hex "encoding/hex" + base64 "encoding/base64" aes "crypto/aes" x509 "crypto/x509" tls "crypto/tls" @@ -459,15 +459,15 @@ func init() { serde.RegisterType[[]uint8]() serde.RegisterType[[]uintptr]() serde.RegisterType[aes.KeySizeError]() - serde.RegisterType[asn1.Tag]() - serde.RegisterType[asn1_1.BitString]() - serde.RegisterType[asn1_1.Enumerated]() - serde.RegisterType[asn1_1.Flag]() - serde.RegisterType[asn1_1.ObjectIdentifier]() - serde.RegisterType[asn1_1.RawContent]() - serde.RegisterType[asn1_1.RawValue]() - serde.RegisterType[asn1_1.StructuralError]() - serde.RegisterType[asn1_1.SyntaxError]() + serde.RegisterType[asn1.BitString]() + serde.RegisterType[asn1.Enumerated]() + serde.RegisterType[asn1.Flag]() + serde.RegisterType[asn1.ObjectIdentifier]() + serde.RegisterType[asn1.RawContent]() + serde.RegisterType[asn1.RawValue]() + serde.RegisterType[asn1.StructuralError]() + serde.RegisterType[asn1.SyntaxError]() + serde.RegisterType[asn1_1.Tag]() serde.RegisterType[atomic.Bool]() serde.RegisterType[atomic.Int32]() serde.RegisterType[atomic.Int64]() diff --git a/compiler/types.go b/compiler/types.go index a371912..ad598a8 100644 --- a/compiler/types.go +++ b/compiler/types.go @@ -9,7 +9,7 @@ import ( "strconv" ) -func typeExpr(typ types.Type) ast.Expr { +func typeExpr(p *types.Package, typ types.Type) ast.Expr { switch t := typ.(type) { case *types.Basic: switch t { @@ -18,22 +18,22 @@ func typeExpr(typ types.Type) ast.Expr { } return ast.NewIdent(t.String()) case *types.Slice: - return &ast.ArrayType{Elt: typeExpr(t.Elem())} + return &ast.ArrayType{Elt: typeExpr(p, t.Elem())} case *types.Array: return &ast.ArrayType{ Len: &ast.BasicLit{Kind: token.INT, Value: strconv.FormatInt(t.Len(), 10)}, - Elt: typeExpr(t.Elem()), + Elt: typeExpr(p, t.Elem()), } case *types.Map: return &ast.MapType{ - Key: typeExpr(t.Key()), - Value: typeExpr(t.Elem()), + Key: typeExpr(p, t.Key()), + Value: typeExpr(p, t.Elem()), } case *types.Struct: fields := make([]*ast.Field, t.NumFields()) for i := range fields { f := t.Field(i) - fields[i] = &ast.Field{Type: typeExpr(f.Type())} + fields[i] = &ast.Field{Type: typeExpr(p, f.Type())} if !f.Anonymous() { fields[i].Names = []*ast.Ident{ast.NewIdent(f.Name())} } @@ -43,13 +43,13 @@ func typeExpr(typ types.Type) ast.Expr { } return &ast.StructType{Fields: &ast.FieldList{List: fields}} case *types.Pointer: - return &ast.StarExpr{X: typeExpr(t.Elem())} + return &ast.StarExpr{X: typeExpr(p, t.Elem())} case *types.Interface: if t.Empty() { return ast.NewIdent("any") } case *types.Signature: - return newFuncType(t) + return newFuncType(p, t) case *types.Named: if t.TypeParams() != nil || t.TypeArgs() != nil { panic("not implemented: generic types") @@ -57,14 +57,13 @@ func typeExpr(typ types.Type) ast.Expr { obj := t.Obj() name := ast.NewIdent(obj.Name()) pkg := obj.Pkg() - if pkg == nil { + if pkg == nil || p == pkg { return name } - // TODO: this needs to be incorporated in the pass to find imports return &ast.SelectorExpr{X: ast.NewIdent(pkg.Name()), Sel: name} case *types.Chan: c := &ast.ChanType{ - Value: typeExpr(t.Elem()), + Value: typeExpr(p, t.Elem()), } switch t.Dir() { case types.SendRecv: @@ -79,24 +78,24 @@ func typeExpr(typ types.Type) ast.Expr { panic(fmt.Sprintf("not implemented: %T", typ)) } -func newFuncType(signature *types.Signature) *ast.FuncType { +func newFuncType(p *types.Package, signature *types.Signature) *ast.FuncType { return &ast.FuncType{ - Params: newFieldList(signature.Params()), - Results: newFieldList(signature.Results()), + Params: newFieldList(p, signature.Params()), + Results: newFieldList(p, signature.Results()), } } -func newFieldList(tuple *types.Tuple) *ast.FieldList { +func newFieldList(p *types.Package, tuple *types.Tuple) *ast.FieldList { return &ast.FieldList{ - List: newFields(tuple), + List: newFields(p, tuple), } } -func newFields(tuple *types.Tuple) []*ast.Field { +func newFields(p *types.Package, tuple *types.Tuple) []*ast.Field { fields := make([]*ast.Field, tuple.Len()) for i := range fields { fields[i] = &ast.Field{ - Type: typeExpr(tuple.At(i).Type()), + Type: typeExpr(p, tuple.At(i).Type()), } } return fields