From 50dd032c57fa56e4e403a9bc7c8186c76cf6090f Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Fri, 15 Dec 2023 09:13:30 +1000 Subject: [PATCH] Substitute type args when converting types.Type to ast.Expr --- compiler/compile.go | 4 ++-- compiler/decls.go | 4 ++-- compiler/desugar.go | 4 ++-- compiler/function.go | 44 ++++++++++++++++++++------------------------ compiler/types.go | 26 ++++++++++++++++---------- 5 files changed, 42 insertions(+), 40 deletions(-) diff --git a/compiler/compile.go b/compiler/compile.go index a8c075d..9b3d927 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -534,8 +534,8 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body ctx := ast.NewIdent("_c") yieldTypeExpr := make([]ast.Expr, 2) - yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type()) - yieldTypeExpr[1] = typeExpr(p, color.Results().At(0).Type()) + yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type(), nil) + yieldTypeExpr[1] = typeExpr(p, color.Results().At(0).Type(), nil) coroutineIdent := ast.NewIdent("coroutine") p.TypesInfo.Uses[coroutineIdent] = types.NewPkgName(token.NoPos, p.Types, "coroutine", scope.compiler.coroutinePkg.Types) diff --git a/compiler/decls.go b/compiler/decls.go index eb67bfd..e61f50e 100644 --- a/compiler/decls.go +++ b/compiler/decls.go @@ -96,7 +96,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r } else { for _, spec := range n.Specs { valueSpec := spec.(*ast.ValueSpec) - valueType := typeExpr(p, info.TypeOf(valueSpec.Names[0])) + valueType := typeExpr(p, info.TypeOf(valueSpec.Names[0]), nil) for _, ident := range valueSpec.Names { if ident.Name != "_" { frameType.Fields.List = append(frameType.Fields.List, &ast.Field{ @@ -132,7 +132,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r } frameType.Fields.List = append(frameType.Fields.List, &ast.Field{ Names: []*ast.Ident{name}, - Type: typeExpr(p, t), + Type: typeExpr(p, t, nil), }) } } diff --git a/compiler/desugar.go b/compiler/desugar.go index 888a729..2764b4e 100644 --- a/compiler/desugar.go +++ b/compiler/desugar.go @@ -314,7 +314,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I &ast.CallExpr{ Fun: d.builtin("make"), Args: []ast.Expr{ - typeExpr(d.pkg, keySliceType), + typeExpr(d.pkg, keySliceType, nil), &ast.BasicLit{Kind: token.INT, Value: "0"}, &ast.CallExpr{Fun: d.builtin("len"), Args: []ast.Expr{x}}, }, @@ -476,7 +476,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I Specs: []ast.Spec{ &ast.ValueSpec{ Names: []*ast.Ident{tmpLhs}, - Type: typeExpr(d.pkg, lhsType), + Type: typeExpr(d.pkg, lhsType, nil), }, }, }}) diff --git a/compiler/function.go b/compiler/function.go index 16d659f..be9e12b 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -86,23 +86,21 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func signature := copyFunctionType(functionTypeOf(fn)) signature.TypeParams = nil + var typeArg func(*types.TypeParam) types.Type + if g != nil { + typeArg = g.typeArgOf + } + recv := copyFieldList(functionRecvOf(fn)) for _, fields := range []*ast.FieldList{recv, signature.Params, signature.Results} { if fields != nil { for _, field := range fields.List { for _, name := range field.Names { typ := p.TypesInfo.TypeOf(name) - if g != nil { - if instanceType, ok := g.typeOfParam(typ); ok { - typ = instanceType - } - } - if typ != nil { - _, ellipsis := field.Type.(*ast.Ellipsis) - field.Type = typeExpr(p, typ) - if a, ok := field.Type.(*ast.ArrayType); ok && a.Len == nil && ellipsis { - field.Type = &ast.Ellipsis{Elt: a.Elt} - } + _, ellipsis := field.Type.(*ast.Ellipsis) + field.Type = typeExpr(p, typ, typeArg) + if a, ok := field.Type.(*ast.ArrayType); ok && a.Len == nil && ellipsis { + field.Type = &ast.Ellipsis{Elt: a.Elt} } scope.insert(name, field.Type) } @@ -127,15 +125,10 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func case *ast.ValueSpec: for _, name := range s.Names { typ := p.TypesInfo.TypeOf(name) - if g != nil { - if instanceType, ok := g.typeOfParam(typ); ok { - typ = instanceType - } - } if typ == nil { scope.insert(name, s.Type) } else { - scope.insert(name, typeExpr(p, typ)) + scope.insert(name, typeExpr(p, typ, typeArg)) } } } @@ -439,7 +432,7 @@ type genericInstance struct { recvPtr bool recvType *types.Named - types map[types.Type]types.Type + typeArgs map[*types.TypeParam]types.Type } func newGenericInstance(origin, instance *ssa.Function) *genericInstance { @@ -463,22 +456,25 @@ func newGenericInstance(origin, instance *ssa.Function) *genericInstance { } } - g.types = map[types.Type]types.Type{} + g.typeArgs = map[*types.TypeParam]types.Type{} if g.recvType != nil { g.scanRecvTypeArgs(func(p *types.TypeParam, _ int, arg types.Type) { - g.types[p.Obj().Type()] = arg + g.typeArgs[p] = arg }) } g.scanTypeArgs(func(p *types.TypeParam, _ int, arg types.Type) { - g.types[p.Obj().Type()] = arg + g.typeArgs[p] = arg }) return g } -func (g *genericInstance) typeOfParam(t types.Type) (types.Type, bool) { - v, ok := g.types[t] - return v, ok +func (g *genericInstance) typeArgOf(param *types.TypeParam) types.Type { + arg, ok := g.typeArgs[param] + if !ok { + panic(fmt.Sprintf("not type arg found for %s", param)) + } + return arg } func (g *genericInstance) partial() bool { diff --git a/compiler/types.go b/compiler/types.go index eae8442..bbce9f4 100644 --- a/compiler/types.go +++ b/compiler/types.go @@ -10,7 +10,10 @@ import ( "golang.org/x/tools/go/packages" ) -func typeExpr(p *packages.Package, typ types.Type) ast.Expr { +// typeExpr converts a types.Type to an ast.Expr. +// +// If typeArg is provided, it's used to resolve type parameters. +func typeExpr(p *packages.Package, typ types.Type, typeArg func(*types.TypeParam) types.Type) ast.Expr { switch t := typ.(type) { case *types.Basic: switch t { @@ -19,22 +22,22 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { } return ast.NewIdent(t.String()) case *types.Slice: - return &ast.ArrayType{Elt: typeExpr(p, t.Elem())} + return &ast.ArrayType{Elt: typeExpr(p, t.Elem(), typeArg)} case *types.Array: return &ast.ArrayType{ Len: &ast.BasicLit{Kind: token.INT, Value: strconv.FormatInt(t.Len(), 10)}, - Elt: typeExpr(p, t.Elem()), + Elt: typeExpr(p, t.Elem(), typeArg), } case *types.Map: return &ast.MapType{ - Key: typeExpr(p, t.Key()), - Value: typeExpr(p, t.Elem()), + Key: typeExpr(p, t.Key(), typeArg), + Value: typeExpr(p, t.Elem(), typeArg), } case *types.Struct: fields := make([]*ast.Field, t.NumFields()) for i := range fields { f := t.Field(i) - fields[i] = &ast.Field{Type: typeExpr(p, f.Type())} + fields[i] = &ast.Field{Type: typeExpr(p, f.Type(), typeArg)} if !f.Anonymous() { fields[i].Names = []*ast.Ident{ast.NewIdent(f.Name())} } @@ -44,7 +47,7 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { } return &ast.StructType{Fields: &ast.FieldList{List: fields}} case *types.Pointer: - return &ast.StarExpr{X: typeExpr(p, t.Elem())} + return &ast.StarExpr{X: typeExpr(p, t.Elem(), typeArg)} case *types.Interface: if t.Empty() { return ast.NewIdent("any") @@ -70,7 +73,7 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { if typeArgs := t.TypeArgs(); typeArgs != nil { indices := make([]ast.Expr, typeArgs.Len()) for i := range indices { - indices[i] = typeExpr(p, typeArgs.At(i)) + indices[i] = typeExpr(p, typeArgs.At(i), typeArg) } namedExpr = &ast.IndexListExpr{ X: namedExpr, @@ -81,7 +84,7 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { case *types.Chan: c := &ast.ChanType{ - Value: typeExpr(p, t.Elem()), + Value: typeExpr(p, t.Elem(), typeArg), } switch t.Dir() { case types.SendRecv: @@ -94,6 +97,9 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { return c case *types.TypeParam: + if typeArg != nil { + return typeExpr(p, typeArg(t), typeArg) + } obj := t.Obj() ident := ast.NewIdent(obj.Name()) p.TypesInfo.Defs[ident] = obj @@ -119,7 +125,7 @@ func newFields(p *packages.Package, tuple *types.Tuple) []*ast.Field { fields := make([]*ast.Field, tuple.Len()) for i := range fields { fields[i] = &ast.Field{ - Type: typeExpr(p, tuple.At(i).Type()), + Type: typeExpr(p, tuple.At(i).Type(), nil), } } return fields