diff --git a/compiler/function.go b/compiler/function.go index be9e12b..4a962b4 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -83,14 +83,16 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func return v } - signature := copyFunctionType(functionTypeOf(fn)) - signature.TypeParams = nil - + // The function syntax may be generic, requiring translation of type param + // placeholders to known type args. var typeArg func(*types.TypeParam) types.Type if g != nil { typeArg = g.typeArgOf } + signature := copyFunctionType(functionTypeOf(fn)) + signature.TypeParams = nil + recv := copyFieldList(functionRecvOf(fn)) for _, fields := range []*ast.FieldList{recv, signature.Params, signature.Results} { if fields != nil { @@ -126,7 +128,11 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func for _, name := range s.Names { typ := p.TypesInfo.TypeOf(name) if typ == nil { - scope.insert(name, s.Type) + // FIXME: this means that TypesInfo was not updated when syntax was + // generated or mutated. The following workaround is required as a + // result. + e := substituteTypeArgs(p, s.Type, typeArg) + scope.insert(name, e) } else { scope.insert(name, typeExpr(p, typ, typeArg)) } diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 32eb38c..16db274 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -3577,8 +3577,8 @@ func init() { D uintptr X0 *struct { IP int - X0 *IdentityGenericStruct[T] - X1 T + X0 *IdentityGenericStruct[int] + X1 int } }]("github.com/stealthrocket/coroutine/compiler/testdata.(*IdentityGenericStruct[go.shape.int]).Closure.func2") _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.(*IdentityGenericStruct[go.shape.int]).Run") diff --git a/compiler/types.go b/compiler/types.go index fc43e19..1d433b2 100644 --- a/compiler/types.go +++ b/compiler/types.go @@ -130,3 +130,80 @@ func newFields(p *packages.Package, tuple *types.Tuple, typeArg func(*types.Type } return fields } + +// substituteTypeArgs replaces all type parameter placeholders +// with type args. +// +// It returns a deep copy of the input expr. +func substituteTypeArgs(p *packages.Package, expr ast.Expr, typeArg func(*types.TypeParam) types.Type) ast.Expr { + if expr == nil { + return nil + } + switch e := expr.(type) { + case *ast.ArrayType: + return &ast.ArrayType{ + Elt: substituteTypeArgs(p, e.Elt, typeArg), + Len: substituteTypeArgs(p, e.Len, typeArg), + } + case *ast.MapType: + return &ast.MapType{ + Key: substituteTypeArgs(p, e.Key, typeArg), + Value: substituteTypeArgs(p, e.Value, typeArg), + } + case *ast.FuncType: + return &ast.FuncType{ + TypeParams: substituteFieldList(p, e.TypeParams, typeArg), + Params: substituteFieldList(p, e.Params, typeArg), + Results: substituteFieldList(p, e.Results, typeArg), + } + case *ast.ChanType: + return &ast.ChanType{ + Dir: e.Dir, + Value: substituteTypeArgs(p, e.Value, typeArg), + } + case *ast.StructType: + return &ast.StructType{ + Fields: substituteFieldList(p, e.Fields, typeArg), + } + case *ast.StarExpr: + return &ast.StarExpr{ + X: substituteTypeArgs(p, e.X, typeArg), + } + case *ast.SelectorExpr: + return &ast.SelectorExpr{ + X: substituteTypeArgs(p, e.X, typeArg), + Sel: e.Sel, + } + case *ast.IndexExpr: + return &ast.IndexExpr{ + X: substituteTypeArgs(p, e.X, typeArg), + Index: substituteTypeArgs(p, e.Index, typeArg), + } + case *ast.Ident: + t := p.TypesInfo.TypeOf(e) + tp, ok := t.(*types.TypeParam) + if !ok { + return e + } + return typeExpr(p, typeArg(tp), typeArg) + case *ast.BasicLit: + return e + default: + panic(fmt.Sprintf("not implemented: %T", e)) + } +} + +func substituteFieldList(p *packages.Package, f *ast.FieldList, typeArg func(*types.TypeParam) types.Type) *ast.FieldList { + if f == nil || f.List == nil { + return f + } + fields := make([]*ast.Field, len(f.List)) + for i, field := range f.List { + fields[i] = &ast.Field{ + Names: field.Names, + Type: substituteTypeArgs(p, field.Type, typeArg), + Tag: field.Tag, + } + } + return &ast.FieldList{List: fields} +}