Skip to content

Commit

Permalink
Make sure underscore idents added preserve type info
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Dec 14, 2023
1 parent 1f6b17d commit 2aefa02
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
37 changes: 29 additions & 8 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color
// Generate the coroutine function. At this stage, use the same name
// as the source function (and require that the caller use build tags
// to disambiguate function calls).
fnType := funcTypeWithNamedResults(fn.Type)
fnType := funcTypeWithNamedResults(p, fn)
gen := &ast.FuncDecl{
Recv: fn.Recv,
Doc: &ast.CommentGroup{},
Expand Down Expand Up @@ -473,7 +473,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *
log.Printf("compiling function literal %s", p.Name)

gen := &ast.FuncLit{
Type: funcTypeWithNamedResults(fn.Type),
Type: funcTypeWithNamedResults(p, fn),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color),
}

Expand Down Expand Up @@ -710,21 +710,42 @@ func isExpr(body *ast.BlockStmt) bool {
return false
}

func funcTypeWithNamedResults(t *ast.FuncType) *ast.FuncType {
func funcTypeWithNamedResults(p *packages.Package, n ast.Node) *ast.FuncType {
t := functionTypeOf(n)
signature := functionSignatureOf(p, n)
if signature == nil {
panic("missing type info for func decl or lit")
}
if t.Results == nil {
return t
}
underscore := ast.NewIdent("_")
funcType := *t
funcType.Results = &ast.FieldList{
List: slices.Clone(t.Results.List),
}
resultTypes := signature.Results()
if resultTypes == nil || resultTypes.Len() == 0 {
panic("result type count mismatch")
}
typePos := 0
for i, f := range t.Results.List {
if len(f.Names) == 0 {
field := *f
field.Names = []*ast.Ident{underscore}
funcType.Results.List[i] = &field
if len(f.Names) > 0 {
typePos += len(f.Names)
continue
}
if typePos >= resultTypes.Len() {
panic("result type count mismatch")
}
t := resultTypes.At(typePos)
underscore := ast.NewIdent("_")
p.TypesInfo.Defs[underscore] = t
field := *f
field.Names = []*ast.Ident{underscore}
funcType.Results.List[i] = &field
typePos++
}
if typePos != resultTypes.Len() {
panic("result type count mismatch")
}
return &funcType
}
11 changes: 11 additions & 0 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,17 @@ func functionTypeOf(fn ast.Node) *ast.FuncType {
}
}

func functionSignatureOf(p *packages.Package, fn ast.Node) *types.Signature {
switch f := fn.(type) {
case *ast.FuncDecl:
return p.TypesInfo.Defs[f.Name].Type().(*types.Signature)
case *ast.FuncLit:
return p.TypesInfo.TypeOf(f).(*types.Signature)
default:
panic("node is neither *ast.FuncDecl or *ast.FuncLit")
}
}

func functionRecvOf(fn ast.Node) *ast.FieldList {
switch f := fn.(type) {
case *ast.FuncDecl:
Expand Down

0 comments on commit 2aefa02

Please sign in to comment.