Skip to content

Commit

Permalink
🐛 Fix check control resolver bug
Browse files Browse the repository at this point in the history
When we have controls that map to check, the resolver creates a
reporting job from the query by its code id to query by its mrn. And
then maps that to the control. However, there was a bug in this
translation where we would pick a random check mrn from the bundle that
had that code id. Since its possible for multiple checks to have the
same id in a bundle, and not all of those checks are part of an active
policy, this would end up picking a mrn from a policy that wasn't
active.
  • Loading branch information
jaym committed Aug 8, 2023
1 parent 66ec8e3 commit 6e327cf
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 27 deletions.
47 changes: 22 additions & 25 deletions policy/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ type resolverCache struct {
baseChecksum string
assetFiltersChecksum string
assetFilters map[string]struct{}
codeIdToMrn map[string][]string

// assigned queries, listed by their UUID (i.e. policy context)
executionQueries map[string]*ExecutionQuery
Expand Down Expand Up @@ -507,6 +508,7 @@ func (s *LocalServices) tryResolve(ctx context.Context, bundleMrn string, assetF
assetFiltersChecksum: assetFiltersChecksum,
assetFilters: assetFiltersMap,
executionQueries: map[string]*ExecutionQuery{},
codeIdToMrn: map[string][]string{},
dataQueries: map[string]struct{}{},
propsCache: explorer.NewPropsCache(),
queriesByMsum: map[string]*explorer.Mquery{},
Expand Down Expand Up @@ -583,8 +585,7 @@ func (s *LocalServices) tryResolve(ctx context.Context, bundleMrn string, assetF
return nil, err
}

queries := bundleMap.QueryMap()
if err := s.jobsToControls(cacheFrameworkJobs, resolvedFramework, collectorJob, queries); err != nil {
if err := s.jobsToControls(cacheFrameworkJobs, resolvedFramework, collectorJob); err != nil {
logCtx.Error().
Err(err).
Str("bundle", bundleMrn).
Expand Down Expand Up @@ -976,6 +977,7 @@ func (cache *policyResolverCache) addCheckJob(ctx context.Context, check *explor
DeprecatedV7Spec: map[string]*DeprecatedV7_ScoringSpec{},
// ^^
}
cache.global.codeIdToMrn[check.CodeId] = append(cache.global.codeIdToMrn[check.CodeId], check.Mrn)
cache.global.reportingJobsByUUID[uuid] = queryJob
cache.global.reportingJobsByMsum[check.Checksum] = append(cache.global.reportingJobsByMsum[check.Checksum], queryJob)
cache.childJobsByMrn[check.Mrn] = append(cache.childJobsByMrn[check.Mrn], queryJob)
Expand Down Expand Up @@ -1496,7 +1498,7 @@ func (s *LocalServices) jobsToFrameworksInner(cache *frameworkResolverCache, res
return nil
}

func (s *LocalServices) jobsToControls(cache *frameworkResolverCache, framework *ResolvedFramework, job *CollectorJob, querymap map[string]*explorer.Mquery) error {
func (s *LocalServices) jobsToControls(cache *frameworkResolverCache, framework *ResolvedFramework, job *CollectorJob) error {
nuJobs := map[string]*ReportingJob{}

// try to find all framework groups of type IGNORE or DISABLE for this and depending frameworks
Expand Down Expand Up @@ -1526,11 +1528,10 @@ func (s *LocalServices) jobsToControls(cache *frameworkResolverCache, framework

rjByMrn := map[string]*ReportingJob{}
for _, rj := range job.ReportingJobs {
query := querymap[rj.QrId]
if query == nil {
continue
queryMrns := cache.codeIdToMrn[rj.QrId]
for _, queryMrn := range queryMrns {
rjByMrn[queryMrn] = rj
}
rjByMrn[query.Mrn] = rj
}

mrns := framework.TopologicalSort()
Expand All @@ -1550,26 +1551,22 @@ func (s *LocalServices) jobsToControls(cache *frameworkResolverCache, framework
if !ok {
continue
}
query, ok := querymap[rj.QrId]
if !ok {
continue
}
queryMrns := cache.codeIdToMrn[rj.QrId]

for _, queryMrn := range queryMrns {
uuid := cache.relativeChecksum(queryMrn)
queryJob := &ReportingJob{
Uuid: uuid,
QrId: queryMrn,
ChildJobs: map[string]*explorer.Impact{},
Type: ReportingJob_CHECK,
}
nuJobs[uuid] = queryJob

// Create a reporting job from the query code id to one with the mrn.
// This isn't 100% correct. We don't keep track of all the queries that
// have the same code id.
uuid := cache.relativeChecksum(query.Mrn)
queryJob := &ReportingJob{
Uuid: uuid,
QrId: query.Mrn,
ChildJobs: map[string]*explorer.Impact{},
Type: ReportingJob_CHECK,
queryJob.ChildJobs[rj.Uuid] = nil
rj.Notify = append(rj.Notify, queryJob.Uuid)
continue
}
nuJobs[uuid] = queryJob

queryJob.ChildJobs[rj.Uuid] = nil
rj.Notify = append(rj.Notify, queryJob.Uuid)
continue
case ResolvedFrameworkNodeTypeControl:
// skip controls which are part of a FrameworkGroup with type DISABLE
if group, ok := frameworkGroupByControlMrn[mrn]; ok {
Expand Down
19 changes: 17 additions & 2 deletions policy/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,16 @@ policies:
mql: 1 == 1
- uid: check-pass-2
mql: 2 == 2
- uid: policy-inactive
groups:
- filters: "false"
checks:
- uid: inactive-fail
mql: 1 == 2
- uid: inactive-pass
mql: 1 == 1
- uid: inactive-pass-2
mql: 2 == 2
frameworks:
- uid: framework1
name: framework1
Expand Down Expand Up @@ -391,7 +400,7 @@ framework_maps:
b := parseBundle(t, bundleStr)

srv := initResolver(t, []*testAsset{
{asset: "asset1", policies: []string{policyMrn("policy1")}, frameworks: []string{frameworkMrn("parent-framework")}},
{asset: "asset1", policies: []string{policyMrn("policy1"), policyMrn("policy-inactive")}, frameworks: []string{frameworkMrn("parent-framework")}},
}, []*policy.Bundle{b})

bundle, err := srv.GetBundle(context.Background(), &policy.Mrn{Mrn: "asset1"})
Expand Down Expand Up @@ -421,6 +430,8 @@ framework_maps:
}

for _, rj := range rjTester.rjIdToReportingJob {
_, ok := rjTester.queryIdToReportingJob[rj.QrId]
require.False(t, ok)
rjTester.queryIdToReportingJob[rj.QrId] = rj
}

Expand All @@ -441,6 +452,10 @@ framework_maps:
rjTester.requireReportsTo(controlMrn("control4"), frameworkMrn("framework1"))
rjTester.requireReportsTo(frameworkMrn("framework1"), frameworkMrn("parent-framework"))
rjTester.requireReportsTo(frameworkMrn("parent-framework"), "root")

require.Nil(t, rjTester.queryIdToReportingJob[queryMrn("inactive-fail")])
require.Nil(t, rjTester.queryIdToReportingJob[queryMrn("inactive-pass")])
require.Nil(t, rjTester.queryIdToReportingJob[queryMrn("inactive-pass-2")])
})

t.Run("test checksumming", func(t *testing.T) {
Expand Down

0 comments on commit 6e327cf

Please sign in to comment.