Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement: Handle select instructions #235

Merged
merged 6 commits into from
Dec 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions internal/pkg/levee/propagation/propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func Dfs(n ssa.Node, conf *config.Config, taggedFields fieldtags.ResultType) Pro
}
maxInstrReached := map[*ssa.BasicBlock]int{}

record.visitReferrers(n, maxInstrReached, nil)
record.dfs(n, maxInstrReached, nil, false)
record.visitReferrers(n, maxInstrReached, nil)
PurelyApplied marked this conversation as resolved.
Show resolved Hide resolved

return record
}
Expand Down Expand Up @@ -194,6 +194,9 @@ func (prop *Propagation) visit(n ssa.Node, maxInstrReached map[*ssa.BasicBlock]i
case *ssa.MapUpdate:
prop.dfs(t.Map.(ssa.Node), maxInstrReached, lastBlockVisited, false)

case *ssa.Select:
prop.visitSelect(t, maxInstrReached, lastBlockVisited)

// The only Operand that can be tainted by a Send is the Chan.
// The Value can propagate taint to the Chan, but not receive it.
// Send has no referrers, it is only an Instruction, not a Value.
Expand All @@ -214,7 +217,7 @@ func (prop *Propagation) visit(n ssa.Node, maxInstrReached map[*ssa.BasicBlock]i
prop.visitOperands(n, maxInstrReached, lastBlockVisited)

// These nodes are both Instructions and Values, and currently have no special restrictions.
case *ssa.MakeInterface, *ssa.Select, *ssa.Slice, *ssa.TypeAssert, *ssa.UnOp:
case *ssa.MakeInterface, *ssa.Slice, *ssa.TypeAssert, *ssa.UnOp:
prop.visitReferrers(n, maxInstrReached, lastBlockVisited)
prop.visitOperands(n, maxInstrReached, lastBlockVisited)

Expand Down Expand Up @@ -252,6 +255,42 @@ func (prop *Propagation) visitOperands(n ssa.Node, maxInstrReached map[*ssa.Basi
}
}

func (prop *Propagation) visitSelect(sel *ssa.Select, maxInstrReached map[*ssa.BasicBlock]int, lastBlockVisited *ssa.BasicBlock) {
// Select returns a tuple whose first 2 elements are irrelevant for our
// analysis. Subsequent elements correspond to Recv states, which map
// 1:1 with Extracts.
// See the ssa package code for more details.
recvIndex := 0
extractIndex := map[*ssa.SelectState]int{}
for _, ss := range sel.States {
if ss.Dir == types.RecvOnly {
extractIndex[ss] = recvIndex + 2
recvIndex++
}
}

for _, s := range sel.States {
switch {
// If the sent value (Send) is tainted, propagate taint to the channel
case s.Dir == types.SendOnly && prop.marked[s.Send.(ssa.Node)]:
prop.dfs(s.Chan.(ssa.Node), maxInstrReached, lastBlockVisited, false)

// If the channel is tainted, propagate taint to the appropriate Extract
case s.Dir == types.RecvOnly && prop.marked[s.Chan.(ssa.Node)]:
if sel.Referrers() == nil {
continue
}
for _, r := range *sel.Referrers() {
e, ok := r.(*ssa.Extract)
if !ok || e.Index != extractIndex[s] {
continue
}
prop.dfs(e, maxInstrReached, lastBlockVisited, false)
}
}
}
}

func (prop *Propagation) canReach(start *ssa.BasicBlock, dest *ssa.BasicBlock) bool {
if start.Dominates(dest) {
return true
Expand Down
80 changes: 80 additions & 0 deletions internal/pkg/levee/testdata/src/example.com/tests/select/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,83 @@ func TestTaintedInForkedCall(objects chan interface{}) {
func PutSource(objects chan<- interface{}) {
objects <- core.Source{}
}

func TestRecvFromTaintedAndNonTaintedChans(sources <-chan *core.Source, innocs <-chan *core.Innocuous) {
select {
case s := <-sources:
core.Sink(sources) // want "a source has reached a sink"
core.Sink(s) // want "a source has reached a sink"
case i := <-innocs:
core.Sink(innocs)
core.Sink(i)
}
core.Sink(sources) // want "a source has reached a sink"
core.Sink(innocs)
}

func TestSendOnTaintedAndNonTaintedChans(i1 chan<- interface{}, i2 chan<- interface{}) {
select {
case i1 <- core.Source{}:
core.Sink(i1) // want "a source has reached a sink"
case i2 <- core.Innocuous{}:
core.Sink(i2)
}
core.Sink(i1) // want "a source has reached a sink"
core.Sink(i2)
}

func TestDaisyChain(srcs chan core.Source, i1, i2, i3, i4 chan interface{}) {
select {
case s := <-srcs:
i1 <- s
case z := <-i1:
i2 <- z
case y := <-i2:
i3 <- y
case x := <-i3:
i4 <- x
}
core.Sink(srcs) // want "a source has reached a sink"
core.Sink(i1) // want "a source has reached a sink"
core.Sink(i2)
core.Sink(i3)
core.Sink(i4)
}

func TestDaisyChainCasesInReverseOrder(srcs chan core.Source, i1, i2, i3, i4 chan interface{}) {
select {
case x := <-i3:
i4 <- x
case y := <-i2:
i3 <- y
case z := <-i1:
i2 <- z
case s := <-srcs:
i1 <- s
}
core.Sink(srcs) // want "a source has reached a sink"
core.Sink(i1) // want "a source has reached a sink"
core.Sink(i2)
core.Sink(i3)
core.Sink(i4)
}

func TestDaisyChainInLoop(srcs chan core.Source, i1, i2, i3, i4 chan interface{}) {
for i := 0; i < 4; i++ {
select {
case s := <-srcs:
i1 <- s
case z := <-i1:
i2 <- z
case y := <-i2:
i3 <- y
case x := <-i3:
i4 <- x
}
}
core.Sink(srcs) // want "a source has reached a sink"
core.Sink(i1) // want "a source has reached a sink"
core.Sink(i2) // want "a source has reached a sink"
core.Sink(i3) // want "a source has reached a sink"
core.Sink(i4) // want "a source has reached a sink"
}