Skip to content

Commit

Permalink
fix(terraform): fix merging of context variables (#1475)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikpivkin authored Oct 10, 2023
1 parent 78aed65 commit 657bb31
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pkg/scanners/terraform/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (c *Context) Set(val cty.Value, parts ...string) {
func mergeVars(src cty.Value, parts []string, value cty.Value) cty.Value {

if len(parts) == 0 {
if value.IsKnown() && value.Type().IsObjectType() && !value.IsNull() && value.LengthInt() > 0 && src.IsKnown() && src.Type().IsObjectType() && !src.IsNull() && src.LengthInt() > 0 {
if isNotEmptyObject(src) && isNotEmptyObject(value) {
return mergeObjects(src, value)
}
return value
Expand Down Expand Up @@ -110,11 +110,15 @@ func mergeObjects(a cty.Value, b cty.Value) cty.Value {
}
for key, val := range b.AsValueMap() {
old, exists := output[key]
if exists && val.IsKnown() && val.Type().IsObjectType() && !val.IsNull() && val.LengthInt() > 0 && old.IsKnown() && old.Type().IsObjectType() && !old.IsNull() && old.LengthInt() > 0 {
output[key] = mergeObjects(val, old)
if exists && isNotEmptyObject(old) && isNotEmptyObject(val) {
output[key] = mergeObjects(old, val)
} else {
output[key] = val
}
}
return cty.ObjectVal(output)
}

func isNotEmptyObject(val cty.Value) bool {
return !val.IsNull() && val.IsKnown() && val.Type().IsObjectType() && val.LengthInt() > 0
}
105 changes: 105 additions & 0 deletions pkg/scanners/terraform/context/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,108 @@ func Test_ContextSetThenImmediateGetWithChild(t *testing.T) {
val := ctx.Get("module", "modulename", "mod_result")
assert.Equal(t, "ok", val.AsString())
}

func Test_MergeObjects(t *testing.T) {

tests := []struct {
name string
oldVal cty.Value
newVal cty.Value
expected cty.Value
}{
{
name: "happy",
oldVal: cty.ObjectVal(map[string]cty.Value{
"this": cty.ObjectVal(map[string]cty.Value{
"id": cty.StringVal("some_id"),
"arn": cty.StringVal("some_arn"),
}),
}),
newVal: cty.ObjectVal(map[string]cty.Value{
"this": cty.ObjectVal(map[string]cty.Value{
"arn": cty.StringVal("some_new_arn"),
"bucket": cty.StringVal("test"),
}),
}),
expected: cty.ObjectVal(map[string]cty.Value{
"this": cty.ObjectVal(map[string]cty.Value{
"id": cty.StringVal("some_id"),
"arn": cty.StringVal("some_new_arn"),
"bucket": cty.StringVal("test"),
}),
}),
},
{
name: "old value is empty",
oldVal: cty.EmptyObjectVal,
newVal: cty.ObjectVal(map[string]cty.Value{
"this": cty.ObjectVal(map[string]cty.Value{
"bucket": cty.StringVal("test"),
}),
}),
expected: cty.ObjectVal(map[string]cty.Value{
"this": cty.ObjectVal(map[string]cty.Value{
"bucket": cty.StringVal("test"),
}),
}),
},
{
name: "new value is empty",
oldVal: cty.ObjectVal(map[string]cty.Value{
"this": cty.ObjectVal(map[string]cty.Value{
"bucket": cty.StringVal("test"),
}),
}),
newVal: cty.EmptyObjectVal,
expected: cty.ObjectVal(map[string]cty.Value{
"this": cty.ObjectVal(map[string]cty.Value{
"bucket": cty.StringVal("test"),
}),
}),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, mergeObjects(tt.oldVal, tt.newVal))
})
}

}

func Test_IsNotEmptyObject(t *testing.T) {
tests := []struct {
name string
val cty.Value
expected bool
}{
{
name: "happy",
val: cty.ObjectVal(map[string]cty.Value{
"field": cty.NilVal,
}),
expected: true,
},
{
name: "empty object",
val: cty.EmptyObjectVal,
expected: false,
},
{
name: "nil value",
val: cty.NilVal,
expected: false,
},
{
name: "dynamic value",
val: cty.DynamicVal,
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isNotEmptyObject(tt.val))
})
}
}
45 changes: 45 additions & 0 deletions pkg/scanners/terraform/parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -859,3 +859,48 @@ policy_rules = {
assert.Equal(t, "host() != 'google.com'", block.GetAttribute("session_matcher").AsStringValueOrDefault("", block).Value())
assert.Equal(t, 1001, block.GetAttribute("priority").AsIntValueOrDefault(0, block).Value())
}

func Test_InputVariableIsExpression(t *testing.T) {
fs := testutil.CreateFS(t, map[string]string{
"main.tf": `
resource "aws_iam_role" "this" {
name = "role-name"
}
resource "aws_s3_bucket" "this" {
bucket = aws_iam_role.this.name
}
module "this" {
source = "./modules/access"
bucket = aws_s3_bucket.this.bucket
}
`,
"modules/access/main.tf": `
variable "bucket" {
type = string
}
resource "aws_s3_bucket_public_access_block" "this" {
bucket = var.bucket
block_public_acls = true
block_public_policy = true
ignore_public_acls = true
restrict_public_buckets = true
}
`,
})
p := New(fs, "", OptionStopOnHCLError(true))
require.NoError(t, p.ParseFS(context.TODO(), "."))
modules, _, err := p.EvaluateAll(context.TODO())
require.NoError(t, err)
require.Len(t, modules, 2)

blocks := modules.GetResourcesByType("aws_s3_bucket_public_access_block")
require.Len(t, blocks, 1)

accessBlock := blocks[0]

bucket := accessBlock.GetAttribute("bucket").AsStringValueOrDefault("", accessBlock).Value()
assert.Equal(t, "role-name", bucket)
}

0 comments on commit 657bb31

Please sign in to comment.