diff --git a/policy/framework.go b/policy/framework.go index 92a4731a..017cebb4 100644 --- a/policy/framework.go +++ b/policy/framework.go @@ -11,6 +11,20 @@ import ( "go.mondoo.com/cnquery/sortx" ) +type ResolvedFrameworkNodeType int + +const ( + ResolvedFrameworkNodeTypeFramework ResolvedFrameworkNodeType = iota + ResolvedFrameworkNodeTypeControl + ResolvedFrameworkNodeTypePolicy + ResolvedFrameworkNodeTypeCheck +) + +type ResolvedFrameworkNode struct { + Mrn string + Type ResolvedFrameworkNodeType +} + type ResolvedFramework struct { Mrn string GraphContentChecksum string @@ -24,6 +38,7 @@ type ResolvedFramework struct { // E.g. ReportSources[controlA] = [check123, check45] // E.g. ReportSources[frameworkX] = [controlA, ...] ReportSources map[string][]string + Nodes map[string]ResolvedFrameworkNode } // Compile takes a framework and prepares it to be stored and further @@ -477,7 +492,7 @@ func (c *ControlMap) refreshMRNs(ownerMRN string, cache *bundleCache) error { } control.Mrn, ok = cache.uid2mrn[control.Uid] if !ok { - return errors.New("cannot find policy '" + control.Uid + "' in this bundle, which is referenced by control " + c.Mrn) + return errors.New("cannot find control '" + control.Uid + "' in this bundle, which is referenced by control " + c.Mrn) } control.Uid = "" } @@ -490,6 +505,7 @@ func ResolveFramework(mrn string, frameworks map[string]*Framework) *ResolvedFra Mrn: mrn, ReportTargets: map[string][]string{}, ReportSources: map[string][]string{}, + Nodes: map[string]ResolvedFrameworkNode{}, } for _, framework := range frameworks { @@ -497,13 +513,30 @@ func ResolveFramework(mrn string, frameworks map[string]*Framework) *ResolvedFra fmap := framework.FrameworkMaps[i] for _, ctl := range fmap.Controls { - res.addReportLink(framework.Mrn, ctl.Mrn) + res.addReportLink( + ResolvedFrameworkNode{ + Mrn: framework.Mrn, + Type: ResolvedFrameworkNodeTypeFramework, + }, + ResolvedFrameworkNode{ + Mrn: ctl.Mrn, + Type: ResolvedFrameworkNodeTypeControl, + }) res.addControl(ctl) } } // FIXME: why do these not show up in the framework map for _, depFramework := range framework.Dependencies { - res.addReportLink(framework.Mrn, depFramework.Mrn) + res.addReportLink( + ResolvedFrameworkNode{ + Mrn: framework.Mrn, + Type: ResolvedFrameworkNodeTypeFramework, + }, + ResolvedFrameworkNode{ + Mrn: depFramework.Mrn, + Type: ResolvedFrameworkNodeTypeFramework, + }, + ) } } @@ -512,28 +545,94 @@ func ResolveFramework(mrn string, frameworks map[string]*Framework) *ResolvedFra func (r *ResolvedFramework) addControl(control *ControlMap) { for i := range control.Checks { - r.addReportLink(control.Mrn, control.Checks[i].Mrn) + r.addReportLink( + ResolvedFrameworkNode{ + Mrn: control.Mrn, + Type: ResolvedFrameworkNodeTypeControl, + }, + ResolvedFrameworkNode{ + Mrn: control.Checks[i].Mrn, + Type: ResolvedFrameworkNodeTypeCheck, + }, + ) } for i := range control.Policies { - r.addReportLink(control.Mrn, control.Policies[i].Mrn) + r.addReportLink( + ResolvedFrameworkNode{ + Mrn: control.Mrn, + Type: ResolvedFrameworkNodeTypeControl, + }, + ResolvedFrameworkNode{ + Mrn: control.Policies[i].Mrn, + Type: ResolvedFrameworkNodeTypePolicy, + }, + ) } for i := range control.Controls { - r.addReportLink(control.Mrn, control.Controls[i].Mrn) + r.addReportLink( + ResolvedFrameworkNode{ + Mrn: control.Mrn, + Type: ResolvedFrameworkNodeTypeControl, + }, + ResolvedFrameworkNode{ + Mrn: control.Controls[i].Mrn, + Type: ResolvedFrameworkNodeTypeControl, + }, + ) } } -func (r *ResolvedFramework) addReportLink(parent, child string) { - existing, ok := r.ReportTargets[child] +func (r *ResolvedFramework) addReportLink(parent, child ResolvedFrameworkNode) { + r.Nodes[parent.Mrn] = parent + r.Nodes[child.Mrn] = child + existing, ok := r.ReportTargets[child.Mrn] if !ok { - r.ReportTargets[child] = []string{parent} + r.ReportTargets[child.Mrn] = []string{parent.Mrn} } else { - r.ReportTargets[child] = append(existing, parent) + r.ReportTargets[child.Mrn] = append(existing, parent.Mrn) } - existing, ok = r.ReportSources[parent] + existing, ok = r.ReportSources[parent.Mrn] if !ok { - r.ReportSources[parent] = []string{child} + r.ReportSources[parent.Mrn] = []string{child.Mrn} } else { - r.ReportSources[parent] = append(existing, child) + r.ReportSources[parent.Mrn] = append(existing, child.Mrn) + } +} + +func (r *ResolvedFramework) TopologicalSort() []string { + sorted := []string{} + visited := map[string]struct{}{} + + nodes := make([]string, len(r.Nodes)) + i := 0 + for node := range r.Nodes { + nodes[i] = node + i++ + } + + sort.Strings(nodes) + + for _, node := range nodes { + r.visit(node, visited, &sorted) + } + + // reverse the list + for i := len(sorted)/2 - 1; i >= 0; i-- { + opp := len(sorted) - 1 - i + sorted[i], sorted[opp] = sorted[opp], sorted[i] + } + + return sorted +} + +func (r *ResolvedFramework) visit(node string, visited map[string]struct{}, sorted *[]string) { + if _, ok := visited[node]; ok { + return + } + visited[node] = struct{}{} + for _, child := range r.ReportTargets[node] { + r.visit(child, visited, sorted) } + *sorted = append(*sorted, node) } diff --git a/policy/frameworks_test.go b/policy/frameworks_test.go new file mode 100644 index 00000000..409a3e1b --- /dev/null +++ b/policy/frameworks_test.go @@ -0,0 +1,54 @@ +package policy + +import ( + "testing" +) + +func TestResolvedFrameworkTopologicalSort(t *testing.T) { + framework := &ResolvedFramework{ + ReportTargets: map[string][]string{}, + ReportSources: map[string][]string{}, + Nodes: map[string]ResolvedFrameworkNode{}, + } + + framework.addReportLink(ResolvedFrameworkNode{Mrn: "z"}, ResolvedFrameworkNode{Mrn: "c"}) + framework.addReportLink(ResolvedFrameworkNode{Mrn: "y"}, ResolvedFrameworkNode{Mrn: "x"}) + framework.addReportLink(ResolvedFrameworkNode{Mrn: "a"}, ResolvedFrameworkNode{Mrn: "b"}) + framework.addReportLink(ResolvedFrameworkNode{Mrn: "b"}, ResolvedFrameworkNode{Mrn: "c"}) + framework.addReportLink(ResolvedFrameworkNode{Mrn: "c"}, ResolvedFrameworkNode{Mrn: "d"}) + framework.addReportLink(ResolvedFrameworkNode{Mrn: "c"}, ResolvedFrameworkNode{Mrn: "e"}) + framework.addReportLink(ResolvedFrameworkNode{Mrn: "b"}, ResolvedFrameworkNode{Mrn: "e"}) + + sorted := framework.TopologicalSort() + + requireComesAfter(t, sorted, "z", "c") + requireComesAfter(t, sorted, "y", "x") + requireComesAfter(t, sorted, "a", "b") + requireComesAfter(t, sorted, "b", "c") + requireComesAfter(t, sorted, "c", "d") + requireComesAfter(t, sorted, "c", "e") + requireComesAfter(t, sorted, "b", "e") +} + +func requireComesAfter(t *testing.T, sorted []string, a, b string) { + t.Helper() + aIdx := -1 + bIdx := -1 + for i, v := range sorted { + if v == a { + aIdx = i + } + if v == b { + bIdx = i + } + } + if aIdx == -1 { + t.Errorf("Expected %s to be in sorted list", a) + } + if bIdx == -1 { + t.Errorf("Expected %s to be in sorted list", b) + } + if aIdx < bIdx { + t.Errorf("Expected %s to come after %s", a, b) + } +} diff --git a/policy/resolver.go b/policy/resolver.go index 28a1ded8..7d7b17b9 100644 --- a/policy/resolver.go +++ b/policy/resolver.go @@ -1524,45 +1524,99 @@ func (s *LocalServices) jobsToControls(cache *frameworkResolverCache, framework } } + rjByMrn := map[string]*ReportingJob{} for _, rj := range job.ReportingJobs { - query, ok := querymap[rj.QrId] - if !ok { - log.Warn().Str("mrn", framework.Mrn) + query := querymap[rj.QrId] + if query == nil { continue } + rjByMrn[query.Mrn] = rj + } - targets, ok := framework.ReportTargets[query.Mrn] + mrns := framework.TopologicalSort() + for _, mrn := range mrns { + node, ok := framework.Nodes[mrn] if !ok { continue } + targets := framework.ReportTargets[mrn] + var curJob *ReportingJob + switch node.Type { + case ResolvedFrameworkNodeTypeCheck: + if len(targets) == 0 { + continue + } + rj, ok := rjByMrn[mrn] + if !ok { + continue + } + query, ok := querymap[rj.QrId] + if !ok { + continue + } - // Create a reporting job from the query code id to one with the mrn. - // This isn't 100% correct. We don't keep track of all the queries that - // have the same code id. - uuid := cache.relativeChecksum(query.Mrn) - queryJob := &ReportingJob{ - Uuid: uuid, - QrId: query.Mrn, - ChildJobs: map[string]*explorer.Impact{}, - Type: ReportingJob_CHECK, - } - nuJobs[uuid] = queryJob + // Create a reporting job from the query code id to one with the mrn. + // This isn't 100% correct. We don't keep track of all the queries that + // have the same code id. + uuid := cache.relativeChecksum(query.Mrn) + queryJob := &ReportingJob{ + Uuid: uuid, + QrId: query.Mrn, + ChildJobs: map[string]*explorer.Impact{}, + Type: ReportingJob_CHECK, + } + nuJobs[uuid] = queryJob - for i := range targets { - controlMrn := targets[i] + queryJob.ChildJobs[rj.Uuid] = nil + rj.Notify = append(rj.Notify, queryJob.Uuid) + continue + case ResolvedFrameworkNodeTypeControl: // skip controls which are part of a FrameworkGroup with type DISABLE - if group, ok := frameworkGroupByControlMrn[controlMrn]; ok { + if group, ok := frameworkGroupByControlMrn[mrn]; ok { if group.Type == GroupType_DISABLE { continue } } - controlJob := ensureControlJob(cache, nuJobs, controlMrn, framework, frameworkGroupByControlMrn) - queryJob.ChildJobs[rj.Uuid] = nil - rj.Notify = append(rj.Notify, queryJob.Uuid) + // Avoid adding controls which don't have any active children + shouldAdd := false + for _, child := range framework.ReportSources[mrn] { + if _, ok := nuJobs[cache.relativeChecksum(child)]; ok { + shouldAdd = true + break + } + } + + if !shouldAdd { + continue + } + + controlJob := ensureControlJob(cache, nuJobs, mrn, framework, frameworkGroupByControlMrn) + // addedControlJobs[mrn] = controlJob + curJob = controlJob + case ResolvedFrameworkNodeTypeFramework: + curJob = job.ReportingJobs[cache.relativeChecksum(mrn)] + } - controlJob.ChildJobs[queryJob.Uuid] = nil - queryJob.Notify = append(queryJob.Notify, controlJob.Uuid) + // Ensure that child jobs notify their parents + for _, child := range framework.ReportSources[mrn] { + childJob, ok := nuJobs[cache.relativeChecksum(child)] + if !ok { + continue + } + if _, ok := curJob.ChildJobs[childJob.Uuid]; !ok { + curJob.ChildJobs[childJob.Uuid] = nil + } + shouldAdd := true + for _, parent := range childJob.Notify { + if parent == curJob.Uuid { + shouldAdd = false + break + } + } + if shouldAdd { + childJob.Notify = append(childJob.Notify, curJob.Uuid) + } } } diff --git a/policy/resolver_test.go b/policy/resolver_test.go index 3712ef3b..30297449 100644 --- a/policy/resolver_test.go +++ b/policy/resolver_test.go @@ -353,6 +353,8 @@ frameworks: title: control2 - uid: control3 title: control3 + - uid: control4 + title: control4 - uid: framework2 name: framework2 groups: @@ -380,6 +382,9 @@ framework_maps: checks: - uid: check-pass-2 - uid: check-fail + - uid: control4 + controls: + - uid: control1 ` t.Run("resolve with correct filters", func(t *testing.T) { @@ -431,7 +436,9 @@ framework_maps: rjTester.requireReportsTo(queryMrn("check-fail"), controlMrn("control2")) rjTester.requireReportsTo(controlMrn("control1"), frameworkMrn("framework1")) + rjTester.requireReportsTo(controlMrn("control1"), controlMrn("control4")) rjTester.requireReportsTo(controlMrn("control2"), frameworkMrn("framework1")) + rjTester.requireReportsTo(controlMrn("control4"), frameworkMrn("framework1")) rjTester.requireReportsTo(frameworkMrn("framework1"), frameworkMrn("parent-framework")) rjTester.requireReportsTo(frameworkMrn("parent-framework"), "root") })