Skip to content

Commit

Permalink
Substitute type args when converting types.Type to ast.Expr
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Dec 14, 2023
1 parent 2aefa02 commit 50dd032
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 40 deletions.
4 changes: 2 additions & 2 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,8 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
ctx := ast.NewIdent("_c")

yieldTypeExpr := make([]ast.Expr, 2)
yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type())
yieldTypeExpr[1] = typeExpr(p, color.Results().At(0).Type())
yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type(), nil)
yieldTypeExpr[1] = typeExpr(p, color.Results().At(0).Type(), nil)

coroutineIdent := ast.NewIdent("coroutine")
p.TypesInfo.Uses[coroutineIdent] = types.NewPkgName(token.NoPos, p.Types, "coroutine", scope.compiler.coroutinePkg.Types)
Expand Down
4 changes: 2 additions & 2 deletions compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r
} else {
for _, spec := range n.Specs {
valueSpec := spec.(*ast.ValueSpec)
valueType := typeExpr(p, info.TypeOf(valueSpec.Names[0]))
valueType := typeExpr(p, info.TypeOf(valueSpec.Names[0]), nil)
for _, ident := range valueSpec.Names {
if ident.Name != "_" {
frameType.Fields.List = append(frameType.Fields.List, &ast.Field{
Expand Down Expand Up @@ -132,7 +132,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r
}
frameType.Fields.List = append(frameType.Fields.List, &ast.Field{
Names: []*ast.Ident{name},
Type: typeExpr(p, t),
Type: typeExpr(p, t, nil),
})
}
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/desugar.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I
&ast.CallExpr{
Fun: d.builtin("make"),
Args: []ast.Expr{
typeExpr(d.pkg, keySliceType),
typeExpr(d.pkg, keySliceType, nil),
&ast.BasicLit{Kind: token.INT, Value: "0"},
&ast.CallExpr{Fun: d.builtin("len"), Args: []ast.Expr{x}},
},
Expand Down Expand Up @@ -476,7 +476,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I
Specs: []ast.Spec{
&ast.ValueSpec{
Names: []*ast.Ident{tmpLhs},
Type: typeExpr(d.pkg, lhsType),
Type: typeExpr(d.pkg, lhsType, nil),
},
},
}})
Expand Down
44 changes: 20 additions & 24 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,21 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
signature := copyFunctionType(functionTypeOf(fn))
signature.TypeParams = nil

var typeArg func(*types.TypeParam) types.Type
if g != nil {
typeArg = g.typeArgOf
}

recv := copyFieldList(functionRecvOf(fn))
for _, fields := range []*ast.FieldList{recv, signature.Params, signature.Results} {
if fields != nil {
for _, field := range fields.List {
for _, name := range field.Names {
typ := p.TypesInfo.TypeOf(name)
if g != nil {
if instanceType, ok := g.typeOfParam(typ); ok {
typ = instanceType
}
}
if typ != nil {
_, ellipsis := field.Type.(*ast.Ellipsis)
field.Type = typeExpr(p, typ)
if a, ok := field.Type.(*ast.ArrayType); ok && a.Len == nil && ellipsis {
field.Type = &ast.Ellipsis{Elt: a.Elt}
}
_, ellipsis := field.Type.(*ast.Ellipsis)
field.Type = typeExpr(p, typ, typeArg)
if a, ok := field.Type.(*ast.ArrayType); ok && a.Len == nil && ellipsis {
field.Type = &ast.Ellipsis{Elt: a.Elt}
}
scope.insert(name, field.Type)
}
Expand All @@ -127,15 +125,10 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
case *ast.ValueSpec:
for _, name := range s.Names {
typ := p.TypesInfo.TypeOf(name)
if g != nil {
if instanceType, ok := g.typeOfParam(typ); ok {
typ = instanceType
}
}
if typ == nil {
scope.insert(name, s.Type)
} else {
scope.insert(name, typeExpr(p, typ))
scope.insert(name, typeExpr(p, typ, typeArg))
}
}
}
Expand Down Expand Up @@ -439,7 +432,7 @@ type genericInstance struct {
recvPtr bool
recvType *types.Named

types map[types.Type]types.Type
typeArgs map[*types.TypeParam]types.Type
}

func newGenericInstance(origin, instance *ssa.Function) *genericInstance {
Expand All @@ -463,22 +456,25 @@ func newGenericInstance(origin, instance *ssa.Function) *genericInstance {
}
}

g.types = map[types.Type]types.Type{}
g.typeArgs = map[*types.TypeParam]types.Type{}
if g.recvType != nil {
g.scanRecvTypeArgs(func(p *types.TypeParam, _ int, arg types.Type) {
g.types[p.Obj().Type()] = arg
g.typeArgs[p] = arg
})
}
g.scanTypeArgs(func(p *types.TypeParam, _ int, arg types.Type) {
g.types[p.Obj().Type()] = arg
g.typeArgs[p] = arg
})

return g
}

func (g *genericInstance) typeOfParam(t types.Type) (types.Type, bool) {
v, ok := g.types[t]
return v, ok
func (g *genericInstance) typeArgOf(param *types.TypeParam) types.Type {
arg, ok := g.typeArgs[param]
if !ok {
panic(fmt.Sprintf("not type arg found for %s", param))
}
return arg
}

func (g *genericInstance) partial() bool {
Expand Down
26 changes: 16 additions & 10 deletions compiler/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import (
"golang.org/x/tools/go/packages"
)

func typeExpr(p *packages.Package, typ types.Type) ast.Expr {
// typeExpr converts a types.Type to an ast.Expr.
//
// If typeArg is provided, it's used to resolve type parameters.
func typeExpr(p *packages.Package, typ types.Type, typeArg func(*types.TypeParam) types.Type) ast.Expr {
switch t := typ.(type) {
case *types.Basic:
switch t {
Expand All @@ -19,22 +22,22 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr {
}
return ast.NewIdent(t.String())
case *types.Slice:
return &ast.ArrayType{Elt: typeExpr(p, t.Elem())}
return &ast.ArrayType{Elt: typeExpr(p, t.Elem(), typeArg)}
case *types.Array:
return &ast.ArrayType{
Len: &ast.BasicLit{Kind: token.INT, Value: strconv.FormatInt(t.Len(), 10)},
Elt: typeExpr(p, t.Elem()),
Elt: typeExpr(p, t.Elem(), typeArg),
}
case *types.Map:
return &ast.MapType{
Key: typeExpr(p, t.Key()),
Value: typeExpr(p, t.Elem()),
Key: typeExpr(p, t.Key(), typeArg),
Value: typeExpr(p, t.Elem(), typeArg),
}
case *types.Struct:
fields := make([]*ast.Field, t.NumFields())
for i := range fields {
f := t.Field(i)
fields[i] = &ast.Field{Type: typeExpr(p, f.Type())}
fields[i] = &ast.Field{Type: typeExpr(p, f.Type(), typeArg)}
if !f.Anonymous() {
fields[i].Names = []*ast.Ident{ast.NewIdent(f.Name())}
}
Expand All @@ -44,7 +47,7 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr {
}
return &ast.StructType{Fields: &ast.FieldList{List: fields}}
case *types.Pointer:
return &ast.StarExpr{X: typeExpr(p, t.Elem())}
return &ast.StarExpr{X: typeExpr(p, t.Elem(), typeArg)}
case *types.Interface:
if t.Empty() {
return ast.NewIdent("any")
Expand All @@ -70,7 +73,7 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr {
if typeArgs := t.TypeArgs(); typeArgs != nil {
indices := make([]ast.Expr, typeArgs.Len())
for i := range indices {
indices[i] = typeExpr(p, typeArgs.At(i))
indices[i] = typeExpr(p, typeArgs.At(i), typeArg)
}
namedExpr = &ast.IndexListExpr{
X: namedExpr,
Expand All @@ -81,7 +84,7 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr {

case *types.Chan:
c := &ast.ChanType{
Value: typeExpr(p, t.Elem()),
Value: typeExpr(p, t.Elem(), typeArg),
}
switch t.Dir() {
case types.SendRecv:
Expand All @@ -94,6 +97,9 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr {
return c

case *types.TypeParam:
if typeArg != nil {
return typeExpr(p, typeArg(t), typeArg)
}
obj := t.Obj()
ident := ast.NewIdent(obj.Name())
p.TypesInfo.Defs[ident] = obj
Expand All @@ -119,7 +125,7 @@ func newFields(p *packages.Package, tuple *types.Tuple) []*ast.Field {
fields := make([]*ast.Field, tuple.Len())
for i := range fields {
fields[i] = &ast.Field{
Type: typeExpr(p, tuple.At(i).Type()),
Type: typeExpr(p, tuple.At(i).Type(), nil),
}
}
return fields
Expand Down

0 comments on commit 50dd032

Please sign in to comment.