diff --git a/internal/printer/nodes.go b/internal/printer/nodes.go index 2d98fc76..c4568bac 100644 --- a/internal/printer/nodes.go +++ b/internal/printer/nodes.go @@ -321,8 +321,21 @@ func (p *printer) exprList(prev0 token.Pos, list []ast.Expr, depth int, mode exp } } -func (p *printer) parameters(fields *ast.FieldList) { - p.print(fields.Opening, token.LPAREN) +type paramMode int + +const ( + funcParam paramMode = iota + funcTParam + typeTParam +) + +func (p *printer) parameters(fields *ast.FieldList, mode paramMode) { + openTok, closeTok := token.LPAREN, token.RPAREN + if mode != funcParam { + openTok, closeTok = token.LBRACK, token.RBRACK + } + + p.print(fields.Opening, openTok) if len(fields.List) > 0 { prevLine := p.lineFor(fields.Opening) ws := indent @@ -376,36 +389,79 @@ func (p *printer) parameters(fields *ast.FieldList) { if closing := p.lineFor(fields.Closing); 0 < prevLine && prevLine < closing { p.print(token.COMMA) p.linebreak(closing, 0, ignore, true) + } else if mode == typeTParam && fields.NumFields() == 1 && combinesWithName(fields.List[0].Type) { + // A type parameter list [P T] where the name P and the type expression T syntactically + // combine to another valid (value) expression requires a trailing comma, as in [P *T,] + // (or an enclosing interface as in [P interface(*T)]), so that the type parameter list + // is not parsed as an array length [P*T]. + p.print(token.COMMA) } // unindent if we indented if ws == ignore { p.print(unindent) } } - p.print(fields.Closing, token.RPAREN) + p.print(fields.Closing, closeTok) } -func (p *printer) signature(params, result *ast.FieldList) { - if params != nil { - if params.Opening != token.NoPos { - p.parameters(params) +// combinesWithName reports whether a name followed by the expression x +// syntactically combines to another valid (value) expression. For instance +// using *T for x, "name *T" syntactically appears as the expression x*T. +// On the other hand, using P|Q or *P|~Q for x, "name P|Q" or name *P|~Q" +// cannot be combined into a valid (value) expression. +func combinesWithName(x ast.Expr) bool { + switch x := x.(type) { + case *ast.StarExpr: + // name *x.X combines to name*x.X if x.X is not a type element + return !isTypeElem(x.X) + case *ast.BinaryExpr: + return combinesWithName(x.X) && !isTypeElem(x.Y) + case *ast.ParenExpr: + // name(x) combines but we are making sure at + // the call site that x is never parenthesized. + panic("unexpected parenthesized expression") + } + return false +} + +// isTypeElem reports whether x is a (possibly parenthesized) type element expression. +// The result is false if x could be a type element OR an ordinary (value) expression. +func isTypeElem(x ast.Expr) bool { + switch x := x.(type) { + case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType: + return true + case *ast.BinaryExpr: + return isTypeElem(x.X) || isTypeElem(x.Y) + case *ast.ParenExpr: + return isTypeElem(x.X) + } + return false +} + +func (p *printer) signature(sig *ast.FuncType) { + if sig.TypeParams != nil { + p.parameters(sig.TypeParams, funcTParam) + } + if sig.Params != nil { + if sig.Params.Opening != token.NoPos { + p.parameters(sig.Params, funcParam) } } else { p.print(token.LPAREN, token.RPAREN) } - n := result.NumFields() + n := sig.Results.NumFields() if n > 0 { // result != nil p.print(blank) p.print(token.ARROW) p.print(blank) - if n == 1 && result.List[0].Names == nil { + if n == 1 && sig.Results.List[0].Names == nil { // single anonymous result; no ()'s - p.expr(stripParensAlways(result.List[0].Type)) + p.expr(stripParensAlways(sig.Results.List[0].Type)) return } - p.parameters(result) + p.parameters(sig.Results, funcParam) } } @@ -479,7 +535,7 @@ func (p *printer) fieldList(fields *ast.FieldList, isStruct, isIncomplete bool) if ftyp, isFtyp := f.Type.(*ast.FuncType); isFtyp { // method p.expr(f.Names[0]) - p.signature(ftyp.Params, ftyp.Results) + p.signature(ftyp) } else { // embedded interface p.expr(f.Type) @@ -557,7 +613,7 @@ func (p *printer) fieldList(fields *ast.FieldList, isStruct, isIncomplete bool) if ftyp, isFtyp := f.Type.(*ast.FuncType); isFtyp { // method p.expr(f.Names[0]) - p.signature(ftyp.Params, ftyp.Results) + p.signature(ftyp) } else { // embedded interface p.expr(f.Type) @@ -949,7 +1005,7 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) { case *ast.FuncType: p.print(token.FUNC) - p.signature(x.Params, x.Results) + p.signature(x) case *ast.InterfaceType: p.print(token.INTERFACE) @@ -1507,6 +1563,9 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool) { case *ast.TypeSpec: p.setComment(s.Doc) p.expr(s.Name) + if s.TypeParams != nil { + p.parameters(s.TypeParams, typeTParam) + } if n == 1 { p.print(blank) } else { @@ -1711,12 +1770,12 @@ func (p *printer) funcDecl(d *ast.FuncDecl) { p.print(thisTypeIdent.Name) p.print(".") } else { - p.parameters(d.Recv) // method: print receiver + p.parameters(d.Recv, funcParam) // method: print receiver p.print(blank) } } p.expr(d.Name) - p.signature(d.Type.Params, d.Type.Results) + p.signature(d.Type) p.funcBody(p.distanceFrom(d.Pos()), vtab, d.Body) }