Skip to content

Commit

Permalink
coroc: support generics (#125)
Browse files Browse the repository at this point in the history
This PR adds initial support for generics.

The guide at
https://github.com/golang/proposal/blob/master/design/generics-implementation-dictionaries-go1.18.md
has been helpful in understanding how the Go compiler implements
generics (via a combination of stenciling and dictionaries).

Fixes #123.
  • Loading branch information
chriso authored Dec 15, 2023
2 parents b379c8b + be88525 commit edd7b06
Show file tree
Hide file tree
Showing 8 changed files with 667 additions and 92 deletions.
41 changes: 31 additions & 10 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color
// Generate the coroutine function. At this stage, use the same name
// as the source function (and require that the caller use build tags
// to disambiguate function calls).
fnType := funcTypeWithNamedResults(fn.Type)
fnType := funcTypeWithNamedResults(p, fn)
gen := &ast.FuncDecl{
Recv: fn.Recv,
Doc: &ast.CommentGroup{},
Expand Down Expand Up @@ -473,7 +473,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *
log.Printf("compiling function literal %s", p.Name)

gen := &ast.FuncLit{
Type: funcTypeWithNamedResults(fn.Type),
Type: funcTypeWithNamedResults(p, fn),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color),
}

Expand Down Expand Up @@ -534,8 +534,8 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
ctx := ast.NewIdent("_c")

yieldTypeExpr := make([]ast.Expr, 2)
yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type())
yieldTypeExpr[1] = typeExpr(p, color.Results().At(0).Type())
yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type(), nil)
yieldTypeExpr[1] = typeExpr(p, color.Results().At(0).Type(), nil)

coroutineIdent := ast.NewIdent("coroutine")
p.TypesInfo.Uses[coroutineIdent] = types.NewPkgName(token.NoPos, p.Types, "coroutine", scope.compiler.coroutinePkg.Types)
Expand Down Expand Up @@ -710,21 +710,42 @@ func isExpr(body *ast.BlockStmt) bool {
return false
}

func funcTypeWithNamedResults(t *ast.FuncType) *ast.FuncType {
func funcTypeWithNamedResults(p *packages.Package, n ast.Node) *ast.FuncType {
t := functionTypeOf(n)
signature := functionSignatureOf(p, n)
if signature == nil {
panic("missing type info for func decl or lit")
}
if t.Results == nil {
return t
}
underscore := ast.NewIdent("_")
funcType := *t
funcType.Results = &ast.FieldList{
List: slices.Clone(t.Results.List),
}
resultTypes := signature.Results()
if resultTypes == nil || resultTypes.Len() == 0 {
panic("result type count mismatch")
}
typePos := 0
for i, f := range t.Results.List {
if len(f.Names) == 0 {
field := *f
field.Names = []*ast.Ident{underscore}
funcType.Results.List[i] = &field
if len(f.Names) > 0 {
typePos += len(f.Names)
continue
}
if typePos >= resultTypes.Len() {
panic("result type count mismatch")
}
t := resultTypes.At(typePos)
underscore := ast.NewIdent("_")
p.TypesInfo.Defs[underscore] = t
field := *f
field.Names = []*ast.Ident{underscore}
funcType.Results.List[i] = &field
typePos++
}
if typePos != resultTypes.Len() {
panic("result type count mismatch")
}
return &funcType
}
21 changes: 13 additions & 8 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,23 +220,28 @@ func TestCoroutineYield(t *testing.T) {
coro: func() { StructClosure(3) },
yields: []int{-1, 10, 100, 1000, 11, 101, 1000, 12, 102, 1000},
},

{
name: "identity generic",
coro: func() { IdentityGeneric[int](11) },
name: "generic function",
coro: func() { IdentityGenericInt(11) },
yields: []int{11},
},

{
name: "identity generic (2)",
coro: func() { IdentityGenericInt(11) },
name: "generic struct",
coro: func() { IdentityGenericStructInt(11) },
yields: []int{11},
},

{
name: "identity generic (3)",
coro: func() { IdentityGenericStructInt(11) },
yields: []int{11},
name: "generic function closure",
coro: func() { IdentityGenericClosureInt(11) },
yields: []int{11, 11},
},

{
name: "generic struct closure",
coro: func() { IdentityGenericStructClosureInt(11) },
yields: []int{-1, 11, 100, 23, 12, 101, 45},
},
}

Expand Down
4 changes: 2 additions & 2 deletions compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r
} else {
for _, spec := range n.Specs {
valueSpec := spec.(*ast.ValueSpec)
valueType := typeExpr(p, info.TypeOf(valueSpec.Names[0]))
valueType := typeExpr(p, info.TypeOf(valueSpec.Names[0]), nil)
for _, ident := range valueSpec.Names {
if ident.Name != "_" {
frameType.Fields.List = append(frameType.Fields.List, &ast.Field{
Expand Down Expand Up @@ -132,7 +132,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r
}
frameType.Fields.List = append(frameType.Fields.List, &ast.Field{
Names: []*ast.Ident{name},
Type: typeExpr(p, t),
Type: typeExpr(p, t, nil),
})
}
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/desugar.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I
&ast.CallExpr{
Fun: d.builtin("make"),
Args: []ast.Expr{
typeExpr(d.pkg, keySliceType),
typeExpr(d.pkg, keySliceType, nil),
&ast.BasicLit{Kind: token.INT, Value: "0"},
&ast.CallExpr{Fun: d.builtin("len"), Args: []ast.Expr{x}},
},
Expand Down Expand Up @@ -476,7 +476,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I
Specs: []ast.Spec{
&ast.ValueSpec{
Names: []*ast.Ident{tmpLhs},
Type: typeExpr(d.pkg, lhsType),
Type: typeExpr(d.pkg, lhsType, nil),
},
},
}})
Expand Down
Loading

0 comments on commit edd7b06

Please sign in to comment.