Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🧹 Support controls implemented by controls #676

Merged
merged 1 commit into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 112 additions & 13 deletions policy/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ import (
"go.mondoo.com/cnquery/sortx"
)

type ResolvedFrameworkNodeType int

const (
ResolvedFrameworkNodeTypeFramework ResolvedFrameworkNodeType = iota
ResolvedFrameworkNodeTypeControl
ResolvedFrameworkNodeTypePolicy
ResolvedFrameworkNodeTypeCheck
)

type ResolvedFrameworkNode struct {
Mrn string
Type ResolvedFrameworkNodeType
}

type ResolvedFramework struct {
Mrn string
GraphContentChecksum string
Expand All @@ -24,6 +38,7 @@ type ResolvedFramework struct {
// E.g. ReportSources[controlA] = [check123, check45]
// E.g. ReportSources[frameworkX] = [controlA, ...]
ReportSources map[string][]string
Nodes map[string]ResolvedFrameworkNode
}

// Compile takes a framework and prepares it to be stored and further
Expand Down Expand Up @@ -477,7 +492,7 @@ func (c *ControlMap) refreshMRNs(ownerMRN string, cache *bundleCache) error {
}
control.Mrn, ok = cache.uid2mrn[control.Uid]
if !ok {
return errors.New("cannot find policy '" + control.Uid + "' in this bundle, which is referenced by control " + c.Mrn)
return errors.New("cannot find control '" + control.Uid + "' in this bundle, which is referenced by control " + c.Mrn)
}
control.Uid = ""
}
Expand All @@ -490,20 +505,38 @@ func ResolveFramework(mrn string, frameworks map[string]*Framework) *ResolvedFra
Mrn: mrn,
ReportTargets: map[string][]string{},
ReportSources: map[string][]string{},
Nodes: map[string]ResolvedFrameworkNode{},
}

for _, framework := range frameworks {
for i := range framework.FrameworkMaps {
fmap := framework.FrameworkMaps[i]

for _, ctl := range fmap.Controls {
res.addReportLink(framework.Mrn, ctl.Mrn)
res.addReportLink(
ResolvedFrameworkNode{
Mrn: framework.Mrn,
Type: ResolvedFrameworkNodeTypeFramework,
},
ResolvedFrameworkNode{
Mrn: ctl.Mrn,
Type: ResolvedFrameworkNodeTypeControl,
})
res.addControl(ctl)
}
}
// FIXME: why do these not show up in the framework map
for _, depFramework := range framework.Dependencies {
res.addReportLink(framework.Mrn, depFramework.Mrn)
res.addReportLink(
ResolvedFrameworkNode{
Mrn: framework.Mrn,
Type: ResolvedFrameworkNodeTypeFramework,
},
ResolvedFrameworkNode{
Mrn: depFramework.Mrn,
Type: ResolvedFrameworkNodeTypeFramework,
},
)
}
}

Expand All @@ -512,28 +545,94 @@ func ResolveFramework(mrn string, frameworks map[string]*Framework) *ResolvedFra

func (r *ResolvedFramework) addControl(control *ControlMap) {
for i := range control.Checks {
r.addReportLink(control.Mrn, control.Checks[i].Mrn)
r.addReportLink(
ResolvedFrameworkNode{
Mrn: control.Mrn,
Type: ResolvedFrameworkNodeTypeControl,
},
ResolvedFrameworkNode{
Mrn: control.Checks[i].Mrn,
Type: ResolvedFrameworkNodeTypeCheck,
},
)
}
for i := range control.Policies {
r.addReportLink(control.Mrn, control.Policies[i].Mrn)
r.addReportLink(
ResolvedFrameworkNode{
Mrn: control.Mrn,
Type: ResolvedFrameworkNodeTypeControl,
},
ResolvedFrameworkNode{
Mrn: control.Policies[i].Mrn,
Type: ResolvedFrameworkNodeTypePolicy,
},
)
}
for i := range control.Controls {
r.addReportLink(control.Mrn, control.Controls[i].Mrn)
r.addReportLink(
ResolvedFrameworkNode{
Mrn: control.Mrn,
Type: ResolvedFrameworkNodeTypeControl,
},
ResolvedFrameworkNode{
Mrn: control.Controls[i].Mrn,
Type: ResolvedFrameworkNodeTypeControl,
},
)
}
}

func (r *ResolvedFramework) addReportLink(parent, child string) {
existing, ok := r.ReportTargets[child]
func (r *ResolvedFramework) addReportLink(parent, child ResolvedFrameworkNode) {
r.Nodes[parent.Mrn] = parent
r.Nodes[child.Mrn] = child
existing, ok := r.ReportTargets[child.Mrn]
if !ok {
r.ReportTargets[child] = []string{parent}
r.ReportTargets[child.Mrn] = []string{parent.Mrn}
} else {
r.ReportTargets[child] = append(existing, parent)
r.ReportTargets[child.Mrn] = append(existing, parent.Mrn)
}

existing, ok = r.ReportSources[parent]
existing, ok = r.ReportSources[parent.Mrn]
if !ok {
r.ReportSources[parent] = []string{child}
r.ReportSources[parent.Mrn] = []string{child.Mrn}
} else {
r.ReportSources[parent] = append(existing, child)
r.ReportSources[parent.Mrn] = append(existing, child.Mrn)
}
}

func (r *ResolvedFramework) TopologicalSort() []string {
sorted := []string{}
visited := map[string]struct{}{}

nodes := make([]string, len(r.Nodes))
i := 0
for node := range r.Nodes {
nodes[i] = node
i++
}

sort.Strings(nodes)

for _, node := range nodes {
r.visit(node, visited, &sorted)
}

// reverse the list
for i := len(sorted)/2 - 1; i >= 0; i-- {
opp := len(sorted) - 1 - i
sorted[i], sorted[opp] = sorted[opp], sorted[i]
}

return sorted
}

func (r *ResolvedFramework) visit(node string, visited map[string]struct{}, sorted *[]string) {
if _, ok := visited[node]; ok {
return
}
visited[node] = struct{}{}
for _, child := range r.ReportTargets[node] {
r.visit(child, visited, sorted)
}
*sorted = append(*sorted, node)
}
54 changes: 54 additions & 0 deletions policy/frameworks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package policy

import (
"testing"
)

func TestResolvedFrameworkTopologicalSort(t *testing.T) {
framework := &ResolvedFramework{
ReportTargets: map[string][]string{},
ReportSources: map[string][]string{},
Nodes: map[string]ResolvedFrameworkNode{},
}

framework.addReportLink(ResolvedFrameworkNode{Mrn: "z"}, ResolvedFrameworkNode{Mrn: "c"})
framework.addReportLink(ResolvedFrameworkNode{Mrn: "y"}, ResolvedFrameworkNode{Mrn: "x"})
framework.addReportLink(ResolvedFrameworkNode{Mrn: "a"}, ResolvedFrameworkNode{Mrn: "b"})
framework.addReportLink(ResolvedFrameworkNode{Mrn: "b"}, ResolvedFrameworkNode{Mrn: "c"})
framework.addReportLink(ResolvedFrameworkNode{Mrn: "c"}, ResolvedFrameworkNode{Mrn: "d"})
framework.addReportLink(ResolvedFrameworkNode{Mrn: "c"}, ResolvedFrameworkNode{Mrn: "e"})
framework.addReportLink(ResolvedFrameworkNode{Mrn: "b"}, ResolvedFrameworkNode{Mrn: "e"})

sorted := framework.TopologicalSort()

requireComesAfter(t, sorted, "z", "c")
requireComesAfter(t, sorted, "y", "x")
requireComesAfter(t, sorted, "a", "b")
requireComesAfter(t, sorted, "b", "c")
requireComesAfter(t, sorted, "c", "d")
requireComesAfter(t, sorted, "c", "e")
requireComesAfter(t, sorted, "b", "e")
}

func requireComesAfter(t *testing.T, sorted []string, a, b string) {
t.Helper()
aIdx := -1
bIdx := -1
for i, v := range sorted {
if v == a {
aIdx = i
}
if v == b {
bIdx = i
}
}
if aIdx == -1 {
t.Errorf("Expected %s to be in sorted list", a)
}
if bIdx == -1 {
t.Errorf("Expected %s to be in sorted list", b)
}
if aIdx < bIdx {
t.Errorf("Expected %s to come after %s", a, b)
}
}
100 changes: 77 additions & 23 deletions policy/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1524,45 +1524,99 @@ func (s *LocalServices) jobsToControls(cache *frameworkResolverCache, framework
}
}

rjByMrn := map[string]*ReportingJob{}
for _, rj := range job.ReportingJobs {
query, ok := querymap[rj.QrId]
if !ok {
log.Warn().Str("mrn", framework.Mrn)
query := querymap[rj.QrId]
if query == nil {
continue
}
rjByMrn[query.Mrn] = rj
}

targets, ok := framework.ReportTargets[query.Mrn]
mrns := framework.TopologicalSort()
for _, mrn := range mrns {
node, ok := framework.Nodes[mrn]
if !ok {
continue
}
targets := framework.ReportTargets[mrn]
var curJob *ReportingJob
switch node.Type {
case ResolvedFrameworkNodeTypeCheck:
if len(targets) == 0 {
continue
}
rj, ok := rjByMrn[mrn]
if !ok {
continue
}
query, ok := querymap[rj.QrId]
if !ok {
continue
}

// Create a reporting job from the query code id to one with the mrn.
// This isn't 100% correct. We don't keep track of all the queries that
// have the same code id.
uuid := cache.relativeChecksum(query.Mrn)
queryJob := &ReportingJob{
Uuid: uuid,
QrId: query.Mrn,
ChildJobs: map[string]*explorer.Impact{},
Type: ReportingJob_CHECK,
}
nuJobs[uuid] = queryJob
// Create a reporting job from the query code id to one with the mrn.
// This isn't 100% correct. We don't keep track of all the queries that
// have the same code id.
uuid := cache.relativeChecksum(query.Mrn)
queryJob := &ReportingJob{
Uuid: uuid,
QrId: query.Mrn,
ChildJobs: map[string]*explorer.Impact{},
Type: ReportingJob_CHECK,
}
nuJobs[uuid] = queryJob

for i := range targets {
controlMrn := targets[i]
queryJob.ChildJobs[rj.Uuid] = nil
rj.Notify = append(rj.Notify, queryJob.Uuid)
continue
case ResolvedFrameworkNodeTypeControl:
// skip controls which are part of a FrameworkGroup with type DISABLE
if group, ok := frameworkGroupByControlMrn[controlMrn]; ok {
if group, ok := frameworkGroupByControlMrn[mrn]; ok {
if group.Type == GroupType_DISABLE {
continue
}
}
controlJob := ensureControlJob(cache, nuJobs, controlMrn, framework, frameworkGroupByControlMrn)

queryJob.ChildJobs[rj.Uuid] = nil
rj.Notify = append(rj.Notify, queryJob.Uuid)
// Avoid adding controls which don't have any active children
shouldAdd := false
for _, child := range framework.ReportSources[mrn] {
if _, ok := nuJobs[cache.relativeChecksum(child)]; ok {
shouldAdd = true
break
}
}

if !shouldAdd {
continue
}

controlJob := ensureControlJob(cache, nuJobs, mrn, framework, frameworkGroupByControlMrn)
// addedControlJobs[mrn] = controlJob
curJob = controlJob
case ResolvedFrameworkNodeTypeFramework:
curJob = job.ReportingJobs[cache.relativeChecksum(mrn)]
}

controlJob.ChildJobs[queryJob.Uuid] = nil
queryJob.Notify = append(queryJob.Notify, controlJob.Uuid)
// Ensure that child jobs notify their parents
for _, child := range framework.ReportSources[mrn] {
childJob, ok := nuJobs[cache.relativeChecksum(child)]
if !ok {
continue
}
if _, ok := curJob.ChildJobs[childJob.Uuid]; !ok {
curJob.ChildJobs[childJob.Uuid] = nil
}
shouldAdd := true
for _, parent := range childJob.Notify {
if parent == curJob.Uuid {
shouldAdd = false
break
}
}
if shouldAdd {
childJob.Notify = append(childJob.Notify, curJob.Uuid)
}
}
}

Expand Down
7 changes: 7 additions & 0 deletions policy/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,8 @@ frameworks:
title: control2
- uid: control3
title: control3
- uid: control4
title: control4
- uid: framework2
name: framework2
groups:
Expand Down Expand Up @@ -380,6 +382,9 @@ framework_maps:
checks:
- uid: check-pass-2
- uid: check-fail
- uid: control4
controls:
- uid: control1
`

t.Run("resolve with correct filters", func(t *testing.T) {
Expand Down Expand Up @@ -431,7 +436,9 @@ framework_maps:
rjTester.requireReportsTo(queryMrn("check-fail"), controlMrn("control2"))

rjTester.requireReportsTo(controlMrn("control1"), frameworkMrn("framework1"))
rjTester.requireReportsTo(controlMrn("control1"), controlMrn("control4"))
rjTester.requireReportsTo(controlMrn("control2"), frameworkMrn("framework1"))
rjTester.requireReportsTo(controlMrn("control4"), frameworkMrn("framework1"))
rjTester.requireReportsTo(frameworkMrn("framework1"), frameworkMrn("parent-framework"))
rjTester.requireReportsTo(frameworkMrn("parent-framework"), "root")
})
Expand Down