Skip to content

Commit

Permalink
Fix ToString for AWS KMS to include role, context, and profile.
Browse files Browse the repository at this point in the history
Signed-off-by: Felix Fontein <[email protected]>
  • Loading branch information
felixfontein committed Jan 17, 2025
1 parent c6bdd2f commit e79a9e9
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 5 deletions.
24 changes: 20 additions & 4 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ creation_rules:
- kms:
- arn: foo
aws_profile: bar
- arn: foo
context:
baz: bam
pgp:
- bar
gcp_kms:
Expand Down Expand Up @@ -421,6 +424,7 @@ func TestLoadConfigFile(t *testing.T) {
}

func TestLoadConfigFileWithGroups(t *testing.T) {
bam := "bam"
expected := configFile{
CreationRules: []creationRule{
{
Expand All @@ -432,7 +436,18 @@ func TestLoadConfigFileWithGroups(t *testing.T) {
PathRegex: "",
KeyGroups: []keyGroup{
{
KMS: []kmsKey{{Arn: "foo", AwsProfile: "bar"}},
KMS: []kmsKey{
{
Arn: "foo",
AwsProfile: "bar",
},
{
Arn: "foo",
Context: map[string]*string{
"baz": &bam,
},
},
},
PGP: []string{"bar"},
GCPKMS: []gcpKmsKey{{ResourceID: "foo"}},
AzureKV: []azureKVKey{{VaultURL: "https://foo.vault.azure.net", Key: "foo-key", Version: "fooversion"}},
Expand Down Expand Up @@ -464,7 +479,7 @@ func TestLoadConfigFileWithMerge(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, 2, len(conf.KeyGroups))
assert.Equal(t, 1, len(conf.KeyGroups[0]))
assert.Equal(t, 22, len(conf.KeyGroups[1]))
assert.Equal(t, 23, len(conf.KeyGroups[1]))
}

func TestLoadConfigFileWithNoMatchingRules(t *testing.T) {
Expand Down Expand Up @@ -538,9 +553,10 @@ func TestKeyGroupsForFileWithGroups(t *testing.T) {
conf, err := parseCreationRuleForFile(parseConfigFile(sampleConfigWithGroups, t), "/conf/path", "whatever", nil)
assert.Nil(t, err)
assert.Equal(t, "bar", conf.KeyGroups[0][0].ToString())
assert.Equal(t, "foo", conf.KeyGroups[0][1].ToString())
assert.Equal(t, "foo||bar", conf.KeyGroups[0][1].ToString())
assert.Equal(t, "foo|baz:bam", conf.KeyGroups[0][2].ToString())
assert.Equal(t, "qux", conf.KeyGroups[1][0].ToString())
assert.Equal(t, "baz", conf.KeyGroups[1][1].ToString())
assert.Equal(t, "baz||foo", conf.KeyGroups[1][1].ToString())
}

func TestLoadConfigFileWithUnencryptedSuffix(t *testing.T) {
Expand Down
46 changes: 45 additions & 1 deletion kms/keysource.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"fmt"
"os"
"regexp"
"sort"
"strings"
"time"

Expand Down Expand Up @@ -181,6 +182,38 @@ func ParseKMSContext(in interface{}) map[string]*string {
return out
}

// kmsContextToString converts a dictionary into a string that can be parsed
// again with ParseKMSContext().
func kmsContextToString(in map[string]*string) string {
if len(in) == 0 {
return ""
}

// Collect the keys in a slice and compute the expected length
keys := make([]string, 0, len(in))
length := 0
for key := range in {
keys = append(keys, key)
length += len(key) + len(*in[key]) + 2
}

// Sort the keys
sort.Strings(keys)

// Compose a comma-separated string of key-vale pairs
var builder strings.Builder
builder.Grow(length)
for index, key := range keys {
if index > 0 {
builder.WriteString(",")
}
builder.WriteString(key)
builder.WriteByte(':')
builder.WriteString(*in[key])
}
return builder.String()
}

// CredentialsProvider is a wrapper around aws.CredentialsProvider used for
// authentication towards AWS KMS.
type CredentialsProvider struct {
Expand Down Expand Up @@ -278,7 +311,18 @@ func (key *MasterKey) NeedsRotation() bool {

// ToString converts the key to a string representation.
func (key *MasterKey) ToString() string {
return key.Arn
arnRole := key.Arn
if key.Role != "" {
arnRole = fmt.Sprintf("%s+%s", key.Arn, key.Role)
}
context := kmsContextToString(key.EncryptionContext)
if key.AwsProfile != "" {
return fmt.Sprintf("%s|%s|%s", arnRole, context, key.AwsProfile)
}
if len(key.EncryptionContext) > 0 {
return fmt.Sprintf("%s|%s", arnRole, context)
}
return arnRole
}

// ToMap converts the MasterKey to a map for serialization purposes.
Expand Down
30 changes: 30 additions & 0 deletions kms/keysource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,38 @@ func TestMasterKey_NeedsRotation(t *testing.T) {
}

func TestMasterKey_ToString(t *testing.T) {
dummyARNWithRole := fmt.Sprintf("%s+arn:aws:iam::my-role", dummyARN)

bar := "bar"
bam := "bam"
context := map[string]*string{
"foo": &bar,
"baz": &bam,
}

key := NewMasterKeyFromArn(dummyARN, nil, "")
assert.Equal(t, dummyARN, key.ToString())

key = NewMasterKeyFromArn(dummyARNWithRole, nil, "")
assert.Equal(t, dummyARNWithRole, key.ToString())

key = NewMasterKeyFromArn(dummyARN, nil, "profile")
assert.Equal(t, fmt.Sprintf("%s||profile", dummyARN), key.ToString())

key = NewMasterKeyFromArn(dummyARNWithRole, nil, "profile")
assert.Equal(t, fmt.Sprintf("%s||profile", dummyARNWithRole), key.ToString())

key = NewMasterKeyFromArn(dummyARN, context, "")
assert.Equal(t, fmt.Sprintf("%s|baz:bam,foo:bar", dummyARN), key.ToString())

key = NewMasterKeyFromArn(dummyARNWithRole, context, "")
assert.Equal(t, fmt.Sprintf("%s|baz:bam,foo:bar", dummyARNWithRole), key.ToString())

key = NewMasterKeyFromArn(dummyARN, context, "profile")
assert.Equal(t, fmt.Sprintf("%s|baz:bam,foo:bar|profile", dummyARN), key.ToString())

key = NewMasterKeyFromArn(dummyARNWithRole, context, "profile")
assert.Equal(t, fmt.Sprintf("%s|baz:bam,foo:bar|profile", dummyARNWithRole), key.ToString())
}

func TestMasterKey_ToMap(t *testing.T) {
Expand Down

0 comments on commit e79a9e9

Please sign in to comment.