Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Unify traversal through Field and FieldAddr instructions #201

Merged
merged 8 commits into from
Dec 22, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions internal/pkg/fieldpropagator/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,8 @@ func analyzeResults(pass *analysis.Pass, conf *config.Config, tf fieldtags.Resul
continue
}

deref := utils.Dereference(fa.X.Type())
path, typeName := utils.DecomposeType(deref)
fieldName := utils.FieldName(fa)
if conf.IsSourceField(path, typeName, fieldName) || tf.IsSourceFieldAddr(fa) {
xt, field := fa.X.Type(), fa.Field
if conf.IsSourceField(utils.DecomposeField(xt, field)) || tf.IsSourceField(xt, field) {
pass.ExportObjectFact(meth.Object(), &isFieldPropagator{})
}
}
Expand Down
12 changes: 7 additions & 5 deletions internal/pkg/fieldtags/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ import (
"reflect"

"github.com/google/go-flow-levee/internal/pkg/config"
"github.com/google/go-flow-levee/internal/pkg/utils"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/go/ssa"
)

// ResultType is a map from types.Object to bool.
Expand Down Expand Up @@ -89,11 +89,13 @@ func run(pass *analysis.Pass) (interface{}, error) {
return ResultType(result), nil
}

// IsSourceFieldAddr determines whether a ssa.FieldAddr is a source, that is whether it refers to a field previously identified as a source.
func (r ResultType) IsSourceFieldAddr(fa *ssa.FieldAddr) bool {
// IsSourceField determines whether a field on a type is a source field,
// using the type of the struct holding the field as well as the index
// of the field.
func (r ResultType) IsSourceField(t types.Type, field int) bool {
// incantation plundered from the docstring for ssa.FieldAddr.Field
field := fa.X.Type().Underlying().(*types.Pointer).Elem().Underlying().(*types.Struct).Field(fa.Field)
return r.IsSource(field)
fieldVar := utils.Dereference(t).Underlying().(*types.Struct).Field(field)
return r.IsSource(fieldVar)
}

// IsSource determines whether a types.Var is a source, that is whether it refers to a field previously identified as a source.
Expand Down
22 changes: 13 additions & 9 deletions internal/pkg/levee/propagation/propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,11 @@ func (prop *Propagation) visit(n ssa.Node, maxInstrReached map[*ssa.BasicBlock]i
}
}

case *ssa.Field:
prop.visitField(n, maxInstrReached, lastBlockVisited, t.X.Type(), t.Field)

case *ssa.FieldAddr:
deref := utils.Dereference(t.X.Type())
typPath, typName := utils.DecomposeType(deref)
fieldName := utils.FieldName(t)
if !prop.config.IsSourceField(typPath, typName, fieldName) && !prop.taggedFields.IsSourceFieldAddr(t) {
return
}
prop.visitReferrers(n, maxInstrReached, lastBlockVisited)
prop.visitOperands(n, maxInstrReached, lastBlockVisited)
prop.visitField(n, maxInstrReached, lastBlockVisited, t.X.Type(), t.Field)

// Everything but the actual integer Index should be visited.
case *ssa.Index:
Expand Down Expand Up @@ -213,7 +209,7 @@ func (prop *Propagation) visit(n ssa.Node, maxInstrReached map[*ssa.BasicBlock]i
prop.visitOperands(n, maxInstrReached, lastBlockVisited)

// These nodes are both Instructions and Values, and currently have no special restrictions.
case *ssa.Field, *ssa.MakeInterface, *ssa.Select, *ssa.Slice, *ssa.TypeAssert, *ssa.UnOp:
case *ssa.MakeInterface, *ssa.Select, *ssa.Slice, *ssa.TypeAssert, *ssa.UnOp:
prop.visitReferrers(n, maxInstrReached, lastBlockVisited)
prop.visitOperands(n, maxInstrReached, lastBlockVisited)

Expand All @@ -225,6 +221,14 @@ func (prop *Propagation) visit(n ssa.Node, maxInstrReached map[*ssa.BasicBlock]i
}
}

func (prop *Propagation) visitField(n ssa.Node, maxInstrReached map[*ssa.BasicBlock]int, lastBlockVisited *ssa.BasicBlock, t types.Type, field int) {
if !prop.config.IsSourceField(utils.DecomposeField(t, field)) && !prop.taggedFields.IsSourceField(t, field) {
return
}
prop.visitReferrers(n, maxInstrReached, lastBlockVisited)
prop.visitOperands(n, maxInstrReached, lastBlockVisited)
}

func (prop *Propagation) visitReferrers(n ssa.Node, maxInstrReached map[*ssa.BasicBlock]int, lastBlockVisited *ssa.BasicBlock) {
if n.Referrers() == nil {
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ func TestDirectFieldAccess(c *core.Source) {
core.Sinkf("ID: %v", c.ID)
}

func TestInlinedDirectFieldAccess() {
// This pattern is unlikely to occur in real code.
// The intent is to get Field instructions in the SSA
// so that we can validate that those are handled correctly.
core.Sinkf("Data: %v", core.Source{}.Data) // want "a source has reached a sink"
mlevesquedion marked this conversation as resolved.
Show resolved Hide resolved
core.Sinkf("ID: %v", core.Source{}.ID)
}

func TestProtoStyleFieldAccessorSanitizedPII(c *core.Source) {
core.Sinkf("Source data: %v", core.Sanitize(c.GetData()))
}
Expand Down
24 changes: 11 additions & 13 deletions internal/pkg/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,6 @@ func Dereference(t types.Type) types.Type {
}
}

// FieldName returns the name of the field identified by the FieldAddr.
// It is the responsibility of the caller to ensure that the returned value is a non-empty string.
func FieldName(fa *ssa.FieldAddr) string {
// fa.Type() refers to the accessed field's type.
// fa.X.Type() refers to the surrounding struct's type.
d := Dereference(fa.X.Type())
st, ok := d.Underlying().(*types.Struct)
if !ok {
return ""
}
return st.Field(fa.Field).Name()
}

// DecomposeType returns the path and name of a Named type
// Returns empty strings if the type is not *types.Named
func DecomposeType(t types.Type) (path, name string) {
Expand All @@ -62,6 +49,17 @@ func DecomposeType(t types.Type) (path, name string) {
return path, n.Obj().Name()
}

// DecomposeField returns the decomposed type of the
// struct containing the field, as well as the field's name.
// If the referenced struct's type is not a named type,
// the type path and name will both be empty strings.
func DecomposeField(t types.Type, field int) (typePath, typeName, fieldName string) {
deref := Dereference(t)
typePath, typeName = DecomposeType(deref)
fieldName = deref.Underlying().(*types.Struct).Field(field).Name()
return
}

func unqualifiedName(v *types.Var) string {
packageQualifiedName := v.Type().String()
dotPos := strings.LastIndexByte(packageQualifiedName, '.')
Expand Down
45 changes: 31 additions & 14 deletions internal/pkg/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,38 +99,55 @@ func TestDereference(t *testing.T) {
}
}

func TestFieldName(t *testing.T) {
func TestDecomposeField(t *testing.T) {
dir := analysistest.TestData()

testCases := []struct {
pattern string
want string
pattern string
typePath string
typeName string
fieldName string
}{
{
pattern: "regular",
want: "name",
pattern: "regular",
typePath: "fields/regular",
typeName: "foo",
fieldName: "name",
},
{
pattern: "embedded",
want: "foo",
pattern: "embedded",
typePath: "fields/embedded",
typeName: "bar",
fieldName: "foo",
},
}

for _, tt := range testCases {
t.Run(tt.pattern, func(t *testing.T) {
r := analysistest.Run(t, dir, testAnalyzer, fmt.Sprintf("fieldname/%s", tt.pattern))
r := analysistest.Run(t, dir, testAnalyzer, fmt.Sprintf("fields/%s", tt.pattern))

if len(r) != 1 {
t.Fatalf("Got len(result) == %d, want 1", len(r))
}

a, ok := r[0].Result.(testAnalyzerResult)
if !ok {
t.Fatalf("Got result of type %T, wanted testAnalyzerResult", a)
if r[0].Err != nil {
t.Fatalf("Got unexpected error: %s", r[0].Err)
}

got := FieldName(a.fieldAddr[0])
if got != tt.want {
t.Fatalf("Got %s, want %s", got, tt.want)
res := r[0].Result.(testAnalyzerResult)

fa := res.fieldAddr[0]
fmt.Println(fa.Pos())

typePath, typeName, fieldName := DecomposeField(fa.X.Type(), fa.Field)
if typePath != tt.typePath {
t.Fatalf("Got typePath %s, want %s", typePath, tt.typePath)
}
if typeName != tt.typeName {
t.Fatalf("Got typeName %s, want %s", typeName, tt.typeName)
}
if fieldName != tt.fieldName {
t.Fatalf("Got fieldName %s, want %s", fieldName, tt.fieldName)
}
})
}
Expand Down