Skip to content

Commit

Permalink
Merge pull request #19 from RaduBerinde/fail-for-bad-directive
Browse files Browse the repository at this point in the history
output error if a directive is invalid
  • Loading branch information
yuzefovich authored Dec 23, 2024
2 parents eef0d55 + 634c287 commit 778fb58
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 67 deletions.
69 changes: 32 additions & 37 deletions gcassert.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func stringToDirective(s string) (assertDirective, error) {
case "noescape":
return noescape, nil
}
return noDirective, errors.New(fmt.Sprintf("no such directive %s", s))
return noDirective, errors.New(fmt.Sprintf("unknown directive %q", s))
}

// passInfo contains info on a passed directive for directives that have
Expand Down Expand Up @@ -76,35 +76,40 @@ type assertVisitor struct {
// some kind that were marked with //gcassert:inline by the user.
mustInlineFuncs map[types.Object]struct{}
fileSet *token.FileSet
cwd string

p *packages.Package

errOutput io.Writer
}

func newAssertVisitor(
commentMap ast.CommentMap,
fileSet *token.FileSet,
cwd string,
p *packages.Package,
mustInlineFuncs map[types.Object]struct{},
errOutput io.Writer,
) assertVisitor {
return assertVisitor{
commentMap: commentMap,
fileSet: fileSet,
cwd: cwd,
directiveMap: make(map[int]lineInfo),
mustInlineFuncs: mustInlineFuncs,
p: p,
errOutput: errOutput,
}
}

func (v assertVisitor) Visit(node ast.Node) (w ast.Visitor) {
func (v *assertVisitor) Visit(node ast.Node) ast.Visitor {
if node == nil {
return w
return nil
}
pos := node.Pos()
lineNumber := v.fileSet.Position(pos).Line
pos := v.fileSet.Position(node.Pos())

m := v.commentMap[node]
for _, g := range m {
COMMENT_LIST:
for _, c := range g.List {
matches := gcAssertRegex.FindStringSubmatch(c.Text)
if len(matches) == 0 {
Expand All @@ -114,28 +119,29 @@ func (v assertVisitor) Visit(node ast.Node) (w ast.Visitor) {
// gcassert directive(s).
directiveStrings := strings.Split(matches[1], ",")

lineInfo := v.directiveMap[lineNumber]
lineInfo := v.directiveMap[pos.Line]
lineInfo.n = node
for _, s := range directiveStrings {
directive, err := stringToDirective(s)
if err != nil {
printAssertionFailure(v.cwd, v.fileSet, node, v.errOutput, err.Error())
continue
}
if directive == inline {
switch n := node.(type) {
case *ast.FuncDecl:
// Add the Object that this FuncDecl's ident is connected
// to to our map of must-inline functions.
// to our map of must-inline functions.
obj := v.p.TypesInfo.Defs[n.Name]
if obj != nil {
v.mustInlineFuncs[obj] = struct{}{}
}
continue COMMENT_LIST
continue
}
}
lineInfo.directives = append(lineInfo.directives, directive)
v.directiveMap[pos.Line] = lineInfo
}
v.directiveMap[lineNumber] = lineInfo
}
}
return v
Expand All @@ -149,7 +155,7 @@ func GCAssert(w io.Writer, paths ...string) error {

// GCAssertCwd performs the same operation as GCAssert, but runs `go build` in
// the provided working directory `cwd`. If `cwd` is the empty string, then
/// `go build` will be run in the current working directory.
// `go build` will be run in the current working directory.
func GCAssertCwd(w io.Writer, cwd string, paths ...string) error {
var err error
if cwd == "" {
Expand All @@ -166,7 +172,7 @@ func GCAssertCwd(w io.Writer, cwd string, paths ...string) error {
packages.NeedTypesInfo | packages.NeedTypes,
Fset: fileSet,
}, paths...)
directiveMap, err := parseDirectives(pkgs, fileSet)
directiveMap, err := parseDirectives(pkgs, fileSet, cwd, w)
if err != nil {
return err
}
Expand Down Expand Up @@ -244,19 +250,15 @@ func GCAssertCwd(w io.Writer, cwd string, paths ...string) error {
// Print out the user's code lineNo that failed the assertion,
// the assertion itself, and the compiler output that
// proved that the assertion failed.
if err := printAssertionFailure(cwd, fileSet, info, w, message); err != nil {
return err
}
printAssertionFailure(cwd, fileSet, info.n, w, message)
}
case inline:
if strings.HasPrefix(message, "inlining call to") {
info.passedDirective[i] = true
}
case noescape:
if strings.HasSuffix(message, "escapes to heap:") {
if err := printAssertionFailure(cwd, fileSet, info, w, message); err != nil {
return err
}
printAssertionFailure(cwd, fileSet, info.n, w, message)
}
}
}
Expand Down Expand Up @@ -291,21 +293,15 @@ func GCAssertCwd(w io.Writer, cwd string, paths ...string) error {
// each inlining directive, check if there was matching compiler
// output and fail if not.
if !d.passed {
if err := printAssertionFailure(
cwd, fileSet, info, w, "call was not inlined"); err != nil {
return err
}
printAssertionFailure(cwd, fileSet, info.n, w, "call was not inlined")
}
}
for i, d := range info.directives {
if d != inline {
continue
}
if !info.passedDirective[i] {
if err := printAssertionFailure(
cwd, fileSet, info, w, "call was not inlined"); err != nil {
return err
}
printAssertionFailure(cwd, fileSet, info.n, w, "call was not inlined")
}
}
}
Expand All @@ -317,31 +313,30 @@ func GCAssertCwd(w io.Writer, cwd string, paths ...string) error {
return nil
}

func printAssertionFailure(cwd string, fileSet *token.FileSet, info lineInfo, w io.Writer, message string) error {
func printAssertionFailure(cwd string, fileSet *token.FileSet, n ast.Node, w io.Writer, message string) {
var buf strings.Builder
_ = printer.Fprint(&buf, fileSet, info.n)
pos := fileSet.Position(info.n.Pos())
_ = printer.Fprint(&buf, fileSet, n)
pos := fileSet.Position(n.Pos())
relPath, err := filepath.Rel(cwd, pos.Filename)
if err != nil {
return err
relPath = pos.Filename
}
fmt.Fprintf(w, "%s:%d:\t%s: %s\n", relPath, pos.Line, buf.String(), message)
return nil
}

// directiveMap maps filepath to line number to lineInfo
type directiveMap map[string]map[int]lineInfo

func parseDirectives(pkgs []*packages.Package, fileSet *token.FileSet) (directiveMap, error) {
func parseDirectives(pkgs []*packages.Package, fileSet *token.FileSet, cwd string, errOutput io.Writer) (directiveMap, error) {
fileDirectiveMap := make(directiveMap)
mustInlineFuncs := make(map[types.Object]struct{})
for _, pkg := range pkgs {
for i, file := range pkg.Syntax {
commentMap := ast.NewCommentMap(fileSet, file, file.Comments)

v := newAssertVisitor(commentMap, fileSet, pkg, mustInlineFuncs)
v := newAssertVisitor(commentMap, fileSet, cwd, pkg, mustInlineFuncs, errOutput)
// First: find all lines of code annotated with our gcassert directives.
ast.Walk(v, file)
ast.Walk(&v, file)

file := pkg.CompiledGoFiles[i]
if len(v.directiveMap) > 0 {
Expand All @@ -353,7 +348,7 @@ func parseDirectives(pkgs []*packages.Package, fileSet *token.FileSet) (directiv
// Do another pass to find all callsites of funcs marked with inline.
for _, pkg := range pkgs {
for i, file := range pkg.Syntax {
v := &inlinedDeclVisitor{assertVisitor: newAssertVisitor(nil, fileSet, pkg, mustInlineFuncs)}
v := &inlinedDeclVisitor{assertVisitor: newAssertVisitor(nil, fileSet, cwd, pkg, mustInlineFuncs, errOutput)}
filePath := pkg.CompiledGoFiles[i]
v.directiveMap = fileDirectiveMap[filePath]
if v.directiveMap == nil {
Expand All @@ -372,9 +367,9 @@ type inlinedDeclVisitor struct {
assertVisitor
}

func (v *inlinedDeclVisitor) Visit(node ast.Node) (w ast.Visitor) {
func (v *inlinedDeclVisitor) Visit(node ast.Node) ast.Visitor {
if node == nil {
return w
return nil
}
pos := node.Pos()
lineNumber := v.fileSet.Position(pos).Line
Expand Down
42 changes: 31 additions & 11 deletions gcassert_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gcassert

import (
"bytes"
"go/token"
"os"
"path/filepath"
Expand All @@ -21,15 +22,24 @@ func TestParseDirectives(t *testing.T) {
if err != nil {
t.Fatal(err)
}
absMap, err := parseDirectives(pkgs, fileSet)
cwd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}

cwd, err := os.Getwd()
var errOut bytes.Buffer
absMap, err := parseDirectives(pkgs, fileSet, cwd, &errOut)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, `testdata/bad_directive.go:4: //gcassert:foo
func badDirective1() {}: unknown directive "foo"
testdata/bad_directive.go:8: badDirective1(): unknown directive "bar"
testdata/bad_directive.go:12: //gcassert:inline,afterinline
func badDirective3() {
badDirective2()
}: unknown directive "afterinline"
`, errOut.String())

// Convert the map into relative paths for ease of testing, and remove
// the syntax node so we don't have to test that as well.
relMap := make(directiveMap, len(absMap))
Expand All @@ -46,6 +56,9 @@ func TestParseDirectives(t *testing.T) {
}

expectedMap := directiveMap{
"testdata/bad_directive.go": {
8: {directives: []assertDirective{bce, inline}},
},
"testdata/bce.go": {
8: {directives: []assertDirective{bce}},
11: {directives: []assertDirective{bce, inline}},
Expand All @@ -63,11 +76,11 @@ func TestParseDirectives(t *testing.T) {
58: {inlinableCallsites: []passInfo{{colNo: 35}}},
},
"testdata/noescape.go": {
21: {directives: []assertDirective{noescape}},
28: {directives: []assertDirective{noescape}},
35: {directives: []assertDirective{noescape}},
42: {directives: []assertDirective{noescape}},
45: {directives: []assertDirective{noescape}},
11: {directives: []assertDirective{noescape}},
18: {directives: []assertDirective{noescape}},
25: {directives: []assertDirective{noescape}},
33: {directives: []assertDirective{noescape}},
36: {directives: []assertDirective{noescape}},
},
"testdata/issue5.go": {
4: {inlinableCallsites: []passInfo{{colNo: 14}}},
Expand All @@ -81,15 +94,22 @@ func TestGCAssert(t *testing.T) {
if err != nil {
t.Fatal(err)
}
expectedOutput := `testdata/noescape.go:21: foo := foo{a: 1, b: 2}: foo escapes to heap:
testdata/noescape.go:35: // This annotation should fail, because f will escape to the heap.
expectedOutput := `testdata/bad_directive.go:4: //gcassert:foo
func badDirective1() {}: unknown directive "foo"
testdata/bad_directive.go:8: badDirective1(): unknown directive "bar"
testdata/bad_directive.go:12: //gcassert:inline,afterinline
func badDirective3() {
badDirective2()
}: unknown directive "afterinline"
testdata/noescape.go:11: foo := foo{a: 1, b: 2}: foo escapes to heap:
testdata/noescape.go:25: // This annotation should fail, because f will escape to the heap.
//
//gcassert:noescape
func (f foo) setA(a int) *foo {
f.a = a
return &f
}: f escapes to heap:
testdata/noescape.go:45: : a escapes to heap:
testdata/noescape.go:36: : a escapes to heap:
testdata/bce.go:8: fmt.Println(ints[5]): Found IsInBounds
testdata/bce.go:23: fmt.Println(ints[1:7]): Found IsSliceInBounds
testdata/bce.go:17: sum += notInlinable(ints[i]): call was not inlined
Expand Down
14 changes: 14 additions & 0 deletions testdata/bad_directive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package gcassert

//gcassert:foo
func badDirective1() {}

func badDirective2() {
//gcassert:bce,bar,inline
badDirective1()
}

//gcassert:inline,afterinline
func badDirective3() {
badDirective2()
}
29 changes: 10 additions & 19 deletions testdata/noescape.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
// Copyright 2021 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package gcassert

type foo struct {
Expand All @@ -18,32 +8,33 @@ type foo struct {
func returnsStackVarPtr() *foo {
// this should fail
//gcassert:noescape
foo := foo{a: 1, b:2}
foo := foo{a: 1, b: 2}
return &foo
}

func returnsStackVar() foo {
// this should succeed
//gcassert:noescape
foo := foo{a: 1, b:2}
foo := foo{a: 1, b: 2}
return foo
}

// This annotation should fail, because f will escape to the heap.
//
//gcassert:noescape
func (f foo) setA(a int) *foo {
f.a = a
return &f
f.a = a
return &f
}

// This annotation should pass, because f does not escape.
//
//gcassert:noescape
func (f foo) returnA(
// This annotation should fail, because a will escape to the heap.
//gcassert:noescape
a int,
b int,
// This annotation should fail, because a will escape to the heap.
//gcassert:noescape
a int,
b int,
) *int {
return &a
return &a
}

0 comments on commit 778fb58

Please sign in to comment.