diff --git a/compiler/compile.go b/compiler/compile.go index a38c35d..9b3d927 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -435,7 +435,7 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color // Generate the coroutine function. At this stage, use the same name // as the source function (and require that the caller use build tags // to disambiguate function calls). - fnType := funcTypeWithNamedResults(fn.Type) + fnType := funcTypeWithNamedResults(p, fn) gen := &ast.FuncDecl{ Recv: fn.Recv, Doc: &ast.CommentGroup{}, @@ -473,7 +473,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color * log.Printf("compiling function literal %s", p.Name) gen := &ast.FuncLit{ - Type: funcTypeWithNamedResults(fn.Type), + Type: funcTypeWithNamedResults(p, fn), Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color), } @@ -534,8 +534,8 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body ctx := ast.NewIdent("_c") yieldTypeExpr := make([]ast.Expr, 2) - yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type()) - yieldTypeExpr[1] = typeExpr(p, color.Results().At(0).Type()) + yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type(), nil) + yieldTypeExpr[1] = typeExpr(p, color.Results().At(0).Type(), nil) coroutineIdent := ast.NewIdent("coroutine") p.TypesInfo.Uses[coroutineIdent] = types.NewPkgName(token.NoPos, p.Types, "coroutine", scope.compiler.coroutinePkg.Types) @@ -710,21 +710,42 @@ func isExpr(body *ast.BlockStmt) bool { return false } -func funcTypeWithNamedResults(t *ast.FuncType) *ast.FuncType { +func funcTypeWithNamedResults(p *packages.Package, n ast.Node) *ast.FuncType { + t := functionTypeOf(n) + signature := functionSignatureOf(p, n) + if signature == nil { + panic("missing type info for func decl or lit") + } if t.Results == nil { return t } - underscore := ast.NewIdent("_") funcType := *t funcType.Results = &ast.FieldList{ List: slices.Clone(t.Results.List), } + resultTypes := signature.Results() + if resultTypes == nil || resultTypes.Len() == 0 { + panic("result type count mismatch") + } + typePos := 0 for i, f := range t.Results.List { - if len(f.Names) == 0 { - field := *f - field.Names = []*ast.Ident{underscore} - funcType.Results.List[i] = &field + if len(f.Names) > 0 { + typePos += len(f.Names) + continue + } + if typePos >= resultTypes.Len() { + panic("result type count mismatch") } + t := resultTypes.At(typePos) + underscore := ast.NewIdent("_") + p.TypesInfo.Defs[underscore] = t + field := *f + field.Names = []*ast.Ident{underscore} + funcType.Results.List[i] = &field + typePos++ + } + if typePos != resultTypes.Len() { + panic("result type count mismatch") } return &funcType } diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index a278da2..03fac61 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -220,23 +220,28 @@ func TestCoroutineYield(t *testing.T) { coro: func() { StructClosure(3) }, yields: []int{-1, 10, 100, 1000, 11, 101, 1000, 12, 102, 1000}, }, - { - name: "identity generic", - coro: func() { IdentityGeneric[int](11) }, + name: "generic function", + coro: func() { IdentityGenericInt(11) }, yields: []int{11}, }, { - name: "identity generic (2)", - coro: func() { IdentityGenericInt(11) }, + name: "generic struct", + coro: func() { IdentityGenericStructInt(11) }, yields: []int{11}, }, { - name: "identity generic (3)", - coro: func() { IdentityGenericStructInt(11) }, - yields: []int{11}, + name: "generic function closure", + coro: func() { IdentityGenericClosureInt(11) }, + yields: []int{11, 11}, + }, + + { + name: "generic struct closure", + coro: func() { IdentityGenericStructClosureInt(11) }, + yields: []int{-1, 11, 100, 23, 12, 101, 45}, }, } diff --git a/compiler/decls.go b/compiler/decls.go index eb67bfd..e61f50e 100644 --- a/compiler/decls.go +++ b/compiler/decls.go @@ -96,7 +96,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r } else { for _, spec := range n.Specs { valueSpec := spec.(*ast.ValueSpec) - valueType := typeExpr(p, info.TypeOf(valueSpec.Names[0])) + valueType := typeExpr(p, info.TypeOf(valueSpec.Names[0]), nil) for _, ident := range valueSpec.Names { if ident.Name != "_" { frameType.Fields.List = append(frameType.Fields.List, &ast.Field{ @@ -132,7 +132,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r } frameType.Fields.List = append(frameType.Fields.List, &ast.Field{ Names: []*ast.Ident{name}, - Type: typeExpr(p, t), + Type: typeExpr(p, t, nil), }) } } diff --git a/compiler/desugar.go b/compiler/desugar.go index 888a729..2764b4e 100644 --- a/compiler/desugar.go +++ b/compiler/desugar.go @@ -314,7 +314,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I &ast.CallExpr{ Fun: d.builtin("make"), Args: []ast.Expr{ - typeExpr(d.pkg, keySliceType), + typeExpr(d.pkg, keySliceType, nil), &ast.BasicLit{Kind: token.INT, Value: "0"}, &ast.CallExpr{Fun: d.builtin("len"), Args: []ast.Expr{x}}, }, @@ -476,7 +476,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I Specs: []ast.Spec{ &ast.ValueSpec{ Names: []*ast.Ident{tmpLhs}, - Type: typeExpr(d.pkg, lhsType), + Type: typeExpr(d.pkg, lhsType, nil), }, }, }}) diff --git a/compiler/function.go b/compiler/function.go index 7a4a4f8..75ed118 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -13,6 +13,7 @@ import ( "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" + "golang.org/x/tools/go/ssa" ) type functype struct { @@ -58,7 +59,7 @@ type funcvar struct { typ ast.Expr } -func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *funcscope, colors map[ast.Node]*types.Signature, functypes map[string]functype) { +func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *funcscope, colors map[ast.Node]*types.Signature, functypes map[string]functype, g *genericInstance) { type function struct { node ast.Node scope *funcscope @@ -82,12 +83,27 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func return v } - signature := functionTypeOf(fn) - recv := functionRecvOf(fn) + // The function syntax may be generic, requiring translation of type param + // placeholders to known type args. + var typeArg func(*types.TypeParam) types.Type + if g != nil { + typeArg = g.typeArgOf + } + + signature := copyFunctionType(functionTypeOf(fn)) + signature.TypeParams = nil + + recv := copyFieldList(functionRecvOf(fn)) for _, fields := range []*ast.FieldList{recv, signature.Params, signature.Results} { if fields != nil { for _, field := range fields.List { for _, name := range field.Names { + typ := p.TypesInfo.TypeOf(name) + _, ellipsis := field.Type.(*ast.Ellipsis) + field.Type = typeExpr(p, typ, typeArg) + if a, ok := field.Type.(*ast.ArrayType); ok && a.Len == nil && ellipsis { + field.Type = &ast.Ellipsis{Elt: a.Elt} + } scope.insert(name, field.Type) } } @@ -110,11 +126,16 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func switch s := spec.(type) { case *ast.ValueSpec: for _, name := range s.Names { - typ := s.Type + typ := p.TypesInfo.TypeOf(name) if typ == nil { - typ = typeExpr(p, p.TypesInfo.TypeOf(name)) + // FIXME: this means that TypesInfo was not updated when syntax was + // generated or mutated. The following workaround is required as a + // result. + e := substituteTypeArgs(p, s.Type, typeArg) + scope.insert(name, e) + } else { + scope.insert(name, typeExpr(p, typ, typeArg)) } - scope.insert(name, typ) } } } @@ -151,10 +172,11 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func signature: signature, } if len(freeVars) > 0 { - fields := make([]*ast.Field, 1+len(freeVars)) - fields[0] = &ast.Field{ - Type: ast.NewIdent("uintptr"), - Names: []*ast.Ident{ast.NewIdent("F")}, + fields := []*ast.Field{ + { + Type: ast.NewIdent("uintptr"), + Names: []*ast.Ident{ast.NewIdent("F")}, + }, } for i, freeVar := range freeVars { fieldName := ast.NewIdent(fmt.Sprintf("X%d", i)) @@ -170,10 +192,17 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func // and pointers will be less than 128 bytes on all platforms, which // means that the stack frame pointer is always captured by value. - fields[i+1] = &ast.Field{ + fields = append(fields, &ast.Field{ Type: fieldType, Names: []*ast.Ident{fieldName}, - } + }) + } + if g != nil { + // Append a field for the dictionary. + fields = append(fields, &ast.Field{ + Type: ast.NewIdent("uintptr"), + Names: []*ast.Ident{ast.NewIdent("D")}, + }) } functype.closure = &ast.StructType{ Fields: &ast.FieldList{List: fields}, @@ -193,7 +222,7 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func for i, anonFunc := range anonFuncs[index:] { anonFuncName := anonFuncLinkName(name, index+i+1) - collectFunctypes(p, anonFuncName, anonFunc.node, anonFunc.scope, colors, functypes) + collectFunctypes(p, anonFuncName, anonFunc.node, anonFunc.scope, colors, functypes, g) } } } @@ -245,13 +274,28 @@ func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors ma obj := p.TypesInfo.ObjectOf(d.Name).(*types.Func) fn := c.prog.FuncValue(obj) if fn.TypeParams() != nil { - // TODO: support generics. Generate type func/closure type information for each instance from: instances := c.generics[fn] - log.Printf("warning: cannot register runtime type information for generic function %s", fn) - continue + instances := c.generics[fn] + if len(instances) == 0 { + // This can occur when a generic function is never instantiated/used, + // or when it's instantiated in a package not known to the compiler. + log.Printf("warning: cannot register runtime type information for generic function %s", fn) + continue + } + for _, instance := range instances { + g := newGenericInstance(fn, instance) + if g.partial() { + // Skip instances where not all type params have concrete types. + // I'm not sure why these are generated in the SSA program. + continue + } + scope := &funcscope{vars: map[string]*funcvar{}} + name := g.gcshapePath() + collectFunctypes(p, name, d, scope, colors, functypes, g) + } } else { scope := &funcscope{vars: map[string]*funcvar{}} name := functionPath(p, d) - collectFunctypes(p, name, d, scope, colors, functypes) + collectFunctypes(p, name, d, scope, colors, functypes, nil) } } } @@ -330,6 +374,17 @@ func functionTypeOf(fn ast.Node) *ast.FuncType { } } +func functionSignatureOf(p *packages.Package, fn ast.Node) *types.Signature { + switch f := fn.(type) { + case *ast.FuncDecl: + return p.TypesInfo.Defs[f.Name].Type().(*types.Signature) + case *ast.FuncLit: + return p.TypesInfo.TypeOf(f).(*types.Signature) + default: + panic("node is neither *ast.FuncDecl or *ast.FuncLit") + } +} + func functionRecvOf(fn ast.Node) *ast.FieldList { switch f := fn.(type) { case *ast.FuncDecl: @@ -341,6 +396,30 @@ func functionRecvOf(fn ast.Node) *ast.FieldList { } } +func copyFunctionType(f *ast.FuncType) *ast.FuncType { + return &ast.FuncType{ + TypeParams: copyFieldList(f.TypeParams), + Params: copyFieldList(f.Params), + Results: copyFieldList(f.Results), + } +} + +func copyFieldList(f *ast.FieldList) *ast.FieldList { + if f == nil { + return nil + } + list := make([]*ast.Field, len(f.List)) + for i := range f.List { + list[i] = copyField(f.List[i]) + } + return &ast.FieldList{List: list} +} + +func copyField(f *ast.Field) *ast.Field { + c := *f + return &c +} + func functionBodyOf(fn ast.Node) *ast.BlockStmt { switch f := fn.(type) { case *ast.FuncDecl: @@ -351,3 +430,157 @@ func functionBodyOf(fn ast.Node) *ast.BlockStmt { panic("node is neither *ast.FuncDecl or *ast.FuncLit") } } + +type genericInstance struct { + origin *ssa.Function + instance *ssa.Function + + recvPtr bool + recvType *types.Named + + typeArgs map[*types.TypeParam]types.Type +} + +func newGenericInstance(origin, instance *ssa.Function) *genericInstance { + g := &genericInstance{origin: origin, instance: instance} + + if recv := g.instance.Signature.Recv(); recv != nil { + switch t := recv.Type().(type) { + case *types.Pointer: + g.recvPtr = true + switch pt := t.Elem().(type) { + case *types.Named: + g.recvType = pt + default: + panic(fmt.Sprintf("not implemented: %T", t)) + } + + case *types.Named: + g.recvType = t + default: + panic(fmt.Sprintf("not implemented: %T", t)) + } + } + + g.typeArgs = map[*types.TypeParam]types.Type{} + if g.recvType != nil { + g.scanRecvTypeArgs(func(p *types.TypeParam, _ int, arg types.Type) { + g.typeArgs[p] = arg + }) + } + g.scanTypeArgs(func(p *types.TypeParam, _ int, arg types.Type) { + g.typeArgs[p] = arg + }) + + return g +} + +func (g *genericInstance) typeArgOf(param *types.TypeParam) types.Type { + arg, ok := g.typeArgs[param] + if !ok { + panic(fmt.Sprintf("not type arg found for %s", param)) + } + return arg +} + +func (g *genericInstance) partial() bool { + sig := g.instance.Signature + params := sig.Params() + for i := 0; i < params.Len(); i++ { + if _, ok := params.At(i).Type().(*types.TypeParam); ok { + return true + } + } + results := sig.Results() + for i := 0; i < results.Len(); i++ { + if _, ok := results.At(i).Type().(*types.TypeParam); ok { + return true + } + } + return false +} + +func (g *genericInstance) scanRecvTypeArgs(fn func(*types.TypeParam, int, types.Type)) { + typeParams := g.instance.Signature.RecvTypeParams() + typeArgs := g.recvType.TypeArgs() + for i := 0; i < typeArgs.Len(); i++ { + arg := typeArgs.At(i) + param := typeParams.At(i) + + fn(param, i, arg) + } +} + +func (g *genericInstance) scanTypeArgs(fn func(*types.TypeParam, int, types.Type)) { + params := g.origin.TypeParams() + args := g.instance.TypeArgs() + + for i := 0; i < params.Len(); i++ { + fn(params.At(i), i, args[i]) + } +} + +func (g *genericInstance) gcshapePath() string { + var path strings.Builder + + path.WriteString(g.origin.Pkg.Pkg.Path()) + + if g.recvType != nil { + path.WriteByte('.') + if g.recvPtr { + path.WriteString("(*") + } + path.WriteString(g.recvType.Obj().Name()) + + if g.recvType.TypeParams() != nil { + path.WriteByte('[') + g.scanRecvTypeArgs(func(_ *types.TypeParam, i int, arg types.Type) { + if i > 0 { + path.WriteString(",") + } + writeGoShape(&path, arg) + }) + path.WriteByte(']') + } + + if g.recvPtr { + path.WriteByte(')') + } + } + + path.WriteByte('.') + path.WriteString(g.instance.Object().(*types.Func).Name()) + + if g.origin.Signature.TypeParams() != nil { + path.WriteByte('[') + g.scanTypeArgs(func(_ *types.TypeParam, i int, arg types.Type) { + if i > 0 { + path.WriteString(",") + } + writeGoShape(&path, arg) + }) + path.WriteByte(']') + } + + return path.String() +} + +func writeGoShape(b *strings.Builder, tt types.Type) { + b.WriteString("go.shape.") + + switch t := tt.Underlying().(type) { + case *types.Basic: + b.WriteString(t.Name()) + case *types.Pointer: + // All pointers resolve to *uint8. + b.WriteString("*uint8") + case *types.Interface: + if t.Empty() { + b.WriteString("interface{}") + } else { + panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t)) + } + default: + panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t)) + } +} diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 97fc7fd..958d28e 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -558,26 +558,6 @@ func ReturnNamedValue() (out int) { return } -func IdentityGeneric[T any](n T) { - coroutine.Yield[T, any](n) -} - -type IdentityGenericStruct[T any] struct { - n T -} - -func (i *IdentityGenericStruct[T]) Run() { - coroutine.Yield[T, any](i.n) -} - -func IdentityGenericInt(n int) { - IdentityGeneric[int](n) -} - -func IdentityGenericStructInt(n int) { - (&IdentityGenericStruct[int]{n: n}).Run() -} - type Box struct { x int } @@ -604,3 +584,65 @@ func StructClosure(n int) { fn(1000) } } + +func IdentityGeneric[T any](n T) { + coroutine.Yield[T, any](n) +} + +func IdentityGenericInt(n int) { + IdentityGeneric[int](n) +} + +func IdentityGenericClosure[T any](n T) { + fn := buildClosure(n) + fn() + fn() +} + +// TODO: add this go:noinline directive automatically (once stealthrocket/coroutine#84 is fixed) +// +//go:noinline +func buildClosure[T any](n T) func() { + return func() { + coroutine.Yield[T, any](n) + } +} + +func IdentityGenericClosureInt(n int) { + IdentityGenericClosure[int](n) +} + +type integer interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +type IdentityGenericStruct[T integer] struct { + n T +} + +func (i *IdentityGenericStruct[T]) Run() { + coroutine.Yield[T, any](i.n) +} + +func (i *IdentityGenericStruct[T]) Closure(n T) func(T) { + // Force compilation of this method. Remove once #84 is fixed. + coroutine.Yield[T, any](-1) + + return func(x T) { + coroutine.Yield[T, any](i.n) + i.n++ + coroutine.Yield[T, any](n) + n++ + coroutine.Yield[T, any](x) + } +} + +func IdentityGenericStructInt(n int) { + (&IdentityGenericStruct[int]{n: n}).Run() +} + +func IdentityGenericStructClosureInt(n int) { + fn := (&IdentityGenericStruct[int]{n: n}).Closure(100) + fn(23) + fn(45) +} diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 0e4b610..ae33652 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -3237,22 +3237,6 @@ func ReturnNamedValue() (_fn0 int) { panic("unreachable") } -//go:noinline -func IdentityGeneric[T any](n T) { coroutine.Yield[T, any](n) } - -type IdentityGenericStruct[T any] struct { - n T -} - -//go:noinline -func (i *IdentityGenericStruct[T]) Run() { coroutine.Yield[T, any](i.n) } - -//go:noinline -func IdentityGenericInt(n int) { IdentityGeneric[int](n) } - -//go:noinline -func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() } - type Box struct { x int } @@ -3391,6 +3375,192 @@ func StructClosure(_fn0 int) { } } } + +//go:noinline +func IdentityGeneric[T any](n T) { coroutine.Yield[T, any](n) } + +//go:noinline +func IdentityGenericInt(n int) { IdentityGeneric[int](n) } + +//go:noinline +func IdentityGenericClosure[T any](_fn0 T) { + _c := coroutine.LoadContext[T, any]() + var _f0 *struct { + IP int + X0 T + X1 func() + } = coroutine.Push[struct { + IP int + X0 T + X1 func() + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + X0 T + X1 func() + }{X0: _fn0} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f0.IP < 2: + _f0.X1 = buildClosure(_f0.X0) + _f0.IP = 2 + fallthrough + case _f0.IP < 3: + _f0.X1() + _f0.IP = 3 + fallthrough + case _f0.IP < 4: + _f0.X1() + } +} + +// TODO: add this go:noinline directive automatically (once stealthrocket/coroutine#84 is fixed) +// +//go:noinline +func buildClosure[T any](n T) func() { + return func() { + coroutine.Yield[T, any](n) + } +} + +//go:noinline +func IdentityGenericClosureInt(n int) { IdentityGenericClosure[int](n) } + +type integer interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +type IdentityGenericStruct[T integer] struct { + n T +} + +//go:noinline +func (i *IdentityGenericStruct[T]) Run() { coroutine.Yield[T, any](i.n) } + +//go:noinline +func (_fn0 *IdentityGenericStruct[T]) Closure(_fn1 T) (_ func(T)) { + _c := coroutine.LoadContext[int, any]() + var _f1 *struct { + IP int + X0 *IdentityGenericStruct[T] + X1 T + } = coroutine.Push[struct { + IP int + X0 *IdentityGenericStruct[T] + X1 T + }](&_c.Stack) + if _f1.IP == 0 { + *_f1 = struct { + IP int + X0 *IdentityGenericStruct[T] + X1 T + }{X0: _fn0, X1: _fn1} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f1.IP < 2: + + coroutine.Yield[T, any](-1) + _f1.IP = 2 + fallthrough + case _f1.IP < 3: + + return func(_fn0 T) { + _c := coroutine.LoadContext[int, any]() + var _f0 *struct { + IP int + X0 T + } = coroutine.Push[struct { + IP int + X0 T + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + X0 T + }{X0: _fn0} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f0.IP < 2: + coroutine.Yield[T, any](_f1.X0.n) + _f0.IP = 2 + fallthrough + case _f0.IP < 3: + _f1.X0. + n++ + _f0.IP = 3 + fallthrough + case _f0.IP < 4: + coroutine.Yield[T, any](_f1.X1) + _f0.IP = 4 + fallthrough + case _f0.IP < 5: + _f1.X1++ + _f0.IP = 5 + fallthrough + case _f0.IP < 6: + coroutine.Yield[T, any](_f0.X0) + } + } + } + panic("unreachable") +} + +//go:noinline +func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() } + +//go:noinline +func IdentityGenericStructClosureInt(_fn0 int) { + _c := coroutine.LoadContext[int, any]() + var _f0 *struct { + IP int + X0 int + X1 func(int) + } = coroutine.Push[struct { + IP int + X0 int + X1 func(int) + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + X0 int + X1 func(int) + }{X0: _fn0} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f0.IP < 2: + _f0.X1 = (&IdentityGenericStruct[int]{n: _f0.X0}).Closure(100) + _f0.IP = 2 + fallthrough + case _f0.IP < 3: + _f0.X1(23) + _f0.IP = 3 + fallthrough + case _f0.IP < 4: + _f0.X1(45) + } +} func init() { _types.RegisterFunc[func(_fn1 int) (_ func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure") _types.RegisterClosure[func(_fn0 int), struct { @@ -3401,14 +3571,29 @@ func init() { X1 int } }]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure.func2") + _types.RegisterFunc[func(_fn1 int) (_ func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.(*IdentityGenericStruct[go.shape.int]).Closure") + _types.RegisterClosure[func(_fn0 int), struct { + F uintptr + X0 *struct { + IP int + X0 *IdentityGenericStruct[int] + X1 int + } + D uintptr + }]("github.com/stealthrocket/coroutine/compiler/testdata.(*IdentityGenericStruct[go.shape.int]).Closure.func2") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.(*IdentityGenericStruct[go.shape.int]).Run") _types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.(*MethodGeneratorState).MethodGenerator") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzIfGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzSwitchGenerator") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Identity") + _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericClosureInt") + _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericClosure[go.shape.int]") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericInt") + _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericStructClosureInt") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericStructInt") + _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGeneric[go.shape.int]") _types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops") _types.RegisterFunc[func(_fn0 int, _fn1 func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.Range") @@ -3533,5 +3718,11 @@ func init() { _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.YieldingExpressionDesugaring") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.a") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.b") + _types.RegisterFunc[func(n int) func()]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.int]") + _types.RegisterClosure[func(), struct { + F uintptr + X0 int + D uintptr + }]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.int].func1") _types.RegisterFunc[func(_fn0 ...int)]("github.com/stealthrocket/coroutine/compiler/testdata.varArgs") } diff --git a/compiler/types.go b/compiler/types.go index eae8442..1d433b2 100644 --- a/compiler/types.go +++ b/compiler/types.go @@ -10,7 +10,10 @@ import ( "golang.org/x/tools/go/packages" ) -func typeExpr(p *packages.Package, typ types.Type) ast.Expr { +// typeExpr converts a types.Type to an ast.Expr. +// +// If typeArg is provided, it's used to resolve type parameters. +func typeExpr(p *packages.Package, typ types.Type, typeArg func(*types.TypeParam) types.Type) ast.Expr { switch t := typ.(type) { case *types.Basic: switch t { @@ -19,22 +22,22 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { } return ast.NewIdent(t.String()) case *types.Slice: - return &ast.ArrayType{Elt: typeExpr(p, t.Elem())} + return &ast.ArrayType{Elt: typeExpr(p, t.Elem(), typeArg)} case *types.Array: return &ast.ArrayType{ Len: &ast.BasicLit{Kind: token.INT, Value: strconv.FormatInt(t.Len(), 10)}, - Elt: typeExpr(p, t.Elem()), + Elt: typeExpr(p, t.Elem(), typeArg), } case *types.Map: return &ast.MapType{ - Key: typeExpr(p, t.Key()), - Value: typeExpr(p, t.Elem()), + Key: typeExpr(p, t.Key(), typeArg), + Value: typeExpr(p, t.Elem(), typeArg), } case *types.Struct: fields := make([]*ast.Field, t.NumFields()) for i := range fields { f := t.Field(i) - fields[i] = &ast.Field{Type: typeExpr(p, f.Type())} + fields[i] = &ast.Field{Type: typeExpr(p, f.Type(), typeArg)} if !f.Anonymous() { fields[i].Names = []*ast.Ident{ast.NewIdent(f.Name())} } @@ -44,13 +47,13 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { } return &ast.StructType{Fields: &ast.FieldList{List: fields}} case *types.Pointer: - return &ast.StarExpr{X: typeExpr(p, t.Elem())} + return &ast.StarExpr{X: typeExpr(p, t.Elem(), typeArg)} case *types.Interface: if t.Empty() { return ast.NewIdent("any") } case *types.Signature: - return newFuncType(p, t) + return newFuncType(p, t, typeArg) case *types.Named: obj := t.Obj() name := ast.NewIdent(obj.Name()) @@ -70,7 +73,7 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { if typeArgs := t.TypeArgs(); typeArgs != nil { indices := make([]ast.Expr, typeArgs.Len()) for i := range indices { - indices[i] = typeExpr(p, typeArgs.At(i)) + indices[i] = typeExpr(p, typeArgs.At(i), typeArg) } namedExpr = &ast.IndexListExpr{ X: namedExpr, @@ -81,7 +84,7 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { case *types.Chan: c := &ast.ChanType{ - Value: typeExpr(p, t.Elem()), + Value: typeExpr(p, t.Elem(), typeArg), } switch t.Dir() { case types.SendRecv: @@ -94,6 +97,9 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { return c case *types.TypeParam: + if typeArg != nil { + return typeExpr(p, typeArg(t), typeArg) + } obj := t.Obj() ident := ast.NewIdent(obj.Name()) p.TypesInfo.Defs[ident] = obj @@ -102,25 +108,102 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { panic(fmt.Sprintf("not implemented: %T", typ)) } -func newFuncType(p *packages.Package, signature *types.Signature) *ast.FuncType { +func newFuncType(p *packages.Package, signature *types.Signature, typeArg func(*types.TypeParam) types.Type) *ast.FuncType { return &ast.FuncType{ - Params: newFieldList(p, signature.Params()), - Results: newFieldList(p, signature.Results()), + Params: newFieldList(p, signature.Params(), typeArg), + Results: newFieldList(p, signature.Results(), typeArg), } } -func newFieldList(p *packages.Package, tuple *types.Tuple) *ast.FieldList { +func newFieldList(p *packages.Package, tuple *types.Tuple, typeArg func(*types.TypeParam) types.Type) *ast.FieldList { return &ast.FieldList{ - List: newFields(p, tuple), + List: newFields(p, tuple, typeArg), } } -func newFields(p *packages.Package, tuple *types.Tuple) []*ast.Field { +func newFields(p *packages.Package, tuple *types.Tuple, typeArg func(*types.TypeParam) types.Type) []*ast.Field { fields := make([]*ast.Field, tuple.Len()) for i := range fields { fields[i] = &ast.Field{ - Type: typeExpr(p, tuple.At(i).Type()), + Type: typeExpr(p, tuple.At(i).Type(), typeArg), } } return fields } + +// substituteTypeArgs replaces all type parameter placeholders +// with type args. +// +// It returns a deep copy of the input expr. +func substituteTypeArgs(p *packages.Package, expr ast.Expr, typeArg func(*types.TypeParam) types.Type) ast.Expr { + if expr == nil { + return nil + } + switch e := expr.(type) { + case *ast.ArrayType: + return &ast.ArrayType{ + Elt: substituteTypeArgs(p, e.Elt, typeArg), + Len: substituteTypeArgs(p, e.Len, typeArg), + } + case *ast.MapType: + return &ast.MapType{ + Key: substituteTypeArgs(p, e.Key, typeArg), + Value: substituteTypeArgs(p, e.Value, typeArg), + } + case *ast.FuncType: + return &ast.FuncType{ + TypeParams: substituteFieldList(p, e.TypeParams, typeArg), + Params: substituteFieldList(p, e.Params, typeArg), + Results: substituteFieldList(p, e.Results, typeArg), + } + case *ast.ChanType: + return &ast.ChanType{ + Dir: e.Dir, + Value: substituteTypeArgs(p, e.Value, typeArg), + } + case *ast.StructType: + return &ast.StructType{ + Fields: substituteFieldList(p, e.Fields, typeArg), + } + case *ast.StarExpr: + return &ast.StarExpr{ + X: substituteTypeArgs(p, e.X, typeArg), + } + case *ast.SelectorExpr: + return &ast.SelectorExpr{ + X: substituteTypeArgs(p, e.X, typeArg), + Sel: e.Sel, + } + case *ast.IndexExpr: + return &ast.IndexExpr{ + X: substituteTypeArgs(p, e.X, typeArg), + Index: substituteTypeArgs(p, e.Index, typeArg), + } + case *ast.Ident: + t := p.TypesInfo.TypeOf(e) + tp, ok := t.(*types.TypeParam) + if !ok { + return e + } + return typeExpr(p, typeArg(tp), typeArg) + case *ast.BasicLit: + return e + default: + panic(fmt.Sprintf("not implemented: %T", e)) + } +} + +func substituteFieldList(p *packages.Package, f *ast.FieldList, typeArg func(*types.TypeParam) types.Type) *ast.FieldList { + if f == nil || f.List == nil { + return f + } + fields := make([]*ast.Field, len(f.List)) + for i, field := range f.List { + fields[i] = &ast.Field{ + Names: field.Names, + Type: substituteTypeArgs(p, field.Type, typeArg), + Tag: field.Tag, + } + } + return &ast.FieldList{List: fields} +}