diff --git a/provisioner/template.go b/provisioner/template.go index fd7540bc..0750a9eb 100644 --- a/provisioner/template.go +++ b/provisioner/template.go @@ -144,6 +144,7 @@ func renderTemplate(context *templateContext, file string) (string, error) { "sumQuantities": sumQuantities, "awsValidID": awsValidID, "indent": sprig.GenericFuncMap()["indent"], + "dict": dict, } content, ok := context.fileData[file] @@ -285,6 +286,27 @@ func split(s string, d string) []string { return strings.Split(s, d) } +// dict is a template function that constructs a map out of its arguments. +// Argument list is treated as a sequence of key-value pairs and must have even length. +// Key arguments must have string type. +func dict(args ...interface{}) (map[string]interface{}, error) { + if len(args)%2 != 0 { + return nil, fmt.Errorf("dict: invalid number of arguments: %d", len(args)) + } + dict := make(map[string]interface{}, len(args)/2) + for i := 0; i < len(args); i += 2 { + key, ok := args[i].(string) + if !ok { + return nil, fmt.Errorf("dict: key argument %d must be string", i) + } + if _, ok := dict[key]; ok { + return nil, fmt.Errorf("dict: duplicate key %s", key) + } + dict[key] = args[i+1] + } + return dict, nil +} + // accountID returns just the ID part of an account func accountID(account string) (string, error) { items := strings.Split(account, ":") diff --git a/provisioner/template_test.go b/provisioner/template_test.go index ddefd7fb..ba1d7ecd 100644 --- a/provisioner/template_test.go +++ b/provisioner/template_test.go @@ -1164,3 +1164,39 @@ func TestNodePoolGroupsProfile(t *testing.T) { }) } } + +func TestDict(t *testing.T) { + result, err := renderSingle( + t, + `{{ define "a-template" -}} +name: {{ .name }} +version: {{ .version }} +{{ end }} + +{{ template "a-template" dict "name" "foo" "version" .Values.data }} +`, + "1") + + require.NoError(t, err) + require.EqualValues(t, ` + +name: foo +version: 1 + +`, result) +} + +func TestDictInvalidArgs(t *testing.T) { + for i, tc := range []struct { + args []interface{} + }{ + {args: []interface{}{"foo"}}, + {args: []interface{}{1, "foo"}}, + {args: []interface{}{"foo", "bar", "foo", "baz"}}, + } { + t.Run(fmt.Sprintf("%d: %v", i, tc.args), func(t *testing.T) { + _, err := dict(tc.args...) + require.Error(t, err) + }) + } +}