From 38754ade98ef626ad8ea07078ef19f3b9460f66b Mon Sep 17 00:00:00 2001 From: Brian Nuszkowski Date: Mon, 24 Oct 2022 12:19:44 -0400 Subject: [PATCH] Introducing secret Unmarshal function (#73) * First draft of unmarshal function * Add unit test * Stricter type checking * Always return type check errors while unmarshaling * More comprehensive tests and pointer support * Use options pattern for setting apex * Trim apex during config time instead of runtime * Pass options when recursing structs * Latest iteration of Secret Unmarshaler --- .gitignore | 6 +- pkg/daytona/options.go | 45 +++++ pkg/daytona/options_test.go | 70 ++++++++ pkg/daytona/unmarshal.go | 282 ++++++++++++++++++++++++++++++++ pkg/daytona/unmarshal_test.go | 299 ++++++++++++++++++++++++++++++++++ 5 files changed, 699 insertions(+), 3 deletions(-) create mode 100644 pkg/daytona/options.go create mode 100644 pkg/daytona/options_test.go create mode 100644 pkg/daytona/unmarshal.go create mode 100644 pkg/daytona/unmarshal_test.go diff --git a/.gitignore b/.gitignore index d24887ef..b9268bf2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ -cmd/daytona/daytona -daytona -coverage.out \ No newline at end of file +/cmd/daytona/daytona +/daytona +/coverage.out diff --git a/pkg/daytona/options.go b/pkg/daytona/options.go new file mode 100644 index 00000000..7292fc5f --- /dev/null +++ b/pkg/daytona/options.go @@ -0,0 +1,45 @@ +package daytona + +import ( + "github.com/hashicorp/vault/api" +) + +type Option interface { + Apply(s *SecretUnmarshler) +} + +// WithClient allows callers to provice a custom +// vault client +func WithClient(client *api.Client) Option { + return withClient{client} +} + +type withClient struct{ c *api.Client } + +func (w withClient) Apply(s *SecretUnmarshler) { + s.client = w.c +} + +// WithTokenString allows callers to provide a token +// in the form of a string +func WithTokenString(token string) Option { + return withTokenString{token} +} + +type withTokenString struct{ token string } + +func (w withTokenString) Apply(s *SecretUnmarshler) { + s.tokenString = w.token +} + +// WithTokenFile allows callers to provide a path +// to a file where a vault token is stored +func WithTokenFile(path string) Option { + return withTokenFile{path} +} + +type withTokenFile struct{ path string } + +func (w withTokenFile) Apply(s *SecretUnmarshler) { + s.tokenFile = w.path +} diff --git a/pkg/daytona/options_test.go b/pkg/daytona/options_test.go new file mode 100644 index 00000000..b8e393b7 --- /dev/null +++ b/pkg/daytona/options_test.go @@ -0,0 +1,70 @@ +package daytona + +import ( + "io/ioutil" + "log" + "os" + "testing" + + "github.com/hashicorp/vault/api" +) + +var testToken = "THIS IS MY TOKEN" + +func TestOptionsWithClient(t *testing.T) { + client, err := api.NewClient(api.DefaultConfig()) + if err != nil { + t.Fatal(err) + } + + client.SetToken(testToken) + + u, err := NewSecretUnmarshler(WithClient(client)) + if err != nil { + t.Fatal(err) + } + + if u.client.Token() != testToken { + // we purposely don't log api.Client.Token() in the + // unlikely event we pickup a production token + t.Fatalf("WithClient options is not working. exptected token %s, got something else...", testToken) + } +} + +func TestOptionsWithTokenString(t *testing.T) { + u, err := NewSecretUnmarshler(WithTokenString(testToken)) + if err != nil { + t.Fatal(err) + } + + if u.client.Token() != testToken { + // we purposely don't log api.Client.Token() in the + // unlikely event we pickup a production token + t.Fatalf("WithTokenString options is not working. exptected token %s, got something else...", testToken) + } +} + +func TestOptionsWithTokenFile(t *testing.T) { + fileTokenContents := "THIS IS MY FILE TOKEN" + file, err := ioutil.TempFile("", "test-vault-token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + + _, err = file.Write([]byte(fileTokenContents)) + if err != nil { + t.Fatal(err) + } + + u, err := NewSecretUnmarshler(WithTokenFile(file.Name())) + if err != nil { + t.Fatal(err) + } + + if u.client.Token() != fileTokenContents { + // we purposely don't log api.Client.Token() in the + // unlikely event we pickup a production token + t.Fatalf("WithTokenFile options is not working. exptected token %s, got something else...", testToken) + } +} diff --git a/pkg/daytona/unmarshal.go b/pkg/daytona/unmarshal.go new file mode 100644 index 00000000..5eee5d00 --- /dev/null +++ b/pkg/daytona/unmarshal.go @@ -0,0 +1,282 @@ +package daytona + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "reflect" + "strconv" + "time" + + "github.com/hashicorp/vault/api" +) + +const ( + tagVaultPathKeyName = "vault_path_key" + tagVaultPathDataKeyName = "vault_path_data_key" + + tagVaultDataKeyName = "vault_data_key" + + defaultDataKeyFieldName = "value" +) + +var ( + // ErrValueInput indicates the provided value is not a struct pointer + ErrValueInput = errors.New("the provided value must be a struct pointer") +) + +// SecretUnmarshler reads data from Vault and stores the result(s) in the +// a provided struct. This can be useful to inject sensitive configuration +// items directly into config structs +type SecretUnmarshler struct { + client *api.Client + tokenString string + tokenFile string +} + +// NewSecretUnmarshler returns a new SecretUnmarshler, applying any options +// that are supplied. +func NewSecretUnmarshler(opts ...Option) (*SecretUnmarshler, error) { + var s SecretUnmarshler + for _, opt := range opts { + opt.Apply(&s) + } + + if s.client == nil { + client, err := api.NewClient(api.DefaultConfig()) + if err != nil { + return nil, fmt.Errorf("failed to create new vault client: %w", err) + } + s.client = client + } + + if s.tokenString != "" && s.tokenFile != "" { + return nil, errors.New("cannot use dual token sources, pick one") + } + + if s.tokenString != "" { + s.client.SetToken(s.tokenString) + } + + if s.tokenFile != "" { + b, err := ioutil.ReadFile(s.tokenFile) + if err != nil { + return nil, fmt.Errorf("failed to read token from %s: %w", s.tokenFile, err) + } + + s.client.SetToken(string(b)) + } + return &s, nil +} + +// Unmarshal makes a read request to vault using the supplied vault apex path +// and stores the result(s) in the value pointed to by v. Unmarshal traverses the value v +// recursively looking for tagged fields that can be populated with secret data. +// +// (DATA EXAMPLE #1) Consider the design of the following secret path: secret/application, that contains +// several sub-keys: +// +// API_KEY - the data being stored in the data key 'value' +// DB_PASSWORD - the data being stored in the data key 'value' +// +// (DATA EXAMPLE #2) Consider the design of the following secret path: secret/application/configs, that contains +// several data keys +// +// api_key +// db_password +// +// A field tagged with 'vault_path_key' implies that the apex is a top-level secret path, +// and the value provided by 'vault_path_key' is the suffix key in the path. The full final path will +// be a combination of the apex and the path key. e.g. Using the example #1 above, an apex of secret/application +// with a 'vault_path_key' of DB_PASSWORD, will attempt to read the data stored in secret/application/DB_PASSSWORD. +// By default a data key of 'value' is used. The data key can be customized via the tag `vault_path_data_key` +// +// Field string `vault_path_key:"DB_PASSWORD"` +// Field string `vault_path_key:"DB_PASSWORD" vault_path_data_key:"password"` // data key override +// +// A field tagged with 'vault_data_key' implies that the apex is a full, final secret path +// and the value provided by 'vault_data_key' is the name of the data key. e.g. an apex of secret/application/configs +// with a 'vault_data_key' of db_password, will attempt to read the data stored in secret/application/configs, referncing +// the db_password data key. +// +// Field string `vault_data_key:"db_password"` +func (su SecretUnmarshler) Unmarshal(ctx context.Context, apex string, v interface{}) error { + val := reflect.ValueOf(v) + if val.Kind() != reflect.Ptr { + return ErrValueInput + } + + val = val.Elem() + if val.Kind() != reflect.Struct { + return ErrValueInput + } + + for i := 0; i < val.NumField(); i++ { + fName := val.Type().Field(i).Name + f := val.Field(i) + if f.Kind() == reflect.Ptr { + if f.IsNil() { + f.Set(reflect.New(f.Type().Elem())) + } + f = f.Elem() + } + + qualified, path, valueIndex := introspect(apex, val.Type().Field(i).Tag) + if !qualified && f.Kind() != reflect.Struct { + continue + } + + switch f.Kind() { + case reflect.Struct: + err := su.Unmarshal(ctx, path, f.Addr().Interface()) + if err != nil { + return err + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var iv int64 + + v, err := fetchValue(ctx, su.client, path, valueIndex) + if err != nil { + return err + } + + if f.Kind() == reflect.Int64 && f.Type().String() == "time.Duration" { + var d time.Duration + dur, ok := v.(string) + if ok { + d, err = time.ParseDuration(dur) + if err != nil { + return err + } + iv = int64(d) + } else { + return fmt.Errorf("expected a string but was given type %T for field %s", v, fName) + } + } else { + switch v := v.(type) { + case json.Number: + if value, err := v.Int64(); err == nil { + iv = value + } else { + return err + } + case string: + vv, err := strconv.ParseInt(v, 0, f.Type().Bits()) + if err != nil { + return err + } + iv = vv + default: + return fmt.Errorf("expected a number or string but was given type %T for field %s", v, fName) + } + } + f.SetInt(iv) + case reflect.Float32, reflect.Float64: + var iv float64 + + v, err := fetchValue(ctx, su.client, path, valueIndex) + if err != nil { + return err + } + + switch v := v.(type) { + case json.Number: + if value, err := v.Float64(); err == nil { + iv = value + } else { + return err + } + case string: + vv, err := strconv.ParseFloat(v, f.Type().Bits()) + if err != nil { + return err + } + iv = vv + default: + return fmt.Errorf("expected a float or string but was given type %T for field %s", v, fName) + } + f.SetFloat(iv) + case reflect.String: + v, err := fetchValue(ctx, su.client, path, valueIndex) + if err != nil { + return err + } + if vv, ok := v.(string); ok { + f.SetString(vv) + } else { + return fmt.Errorf("expected a string but was given type %T for field %s", v, fName) + } + case reflect.Bool: + v, err := fetchValue(ctx, su.client, path, valueIndex) + if err != nil { + return err + } + + var b bool + switch v := v.(type) { + case bool: + b = v + case string: + pb, err := strconv.ParseBool(v) + if err != nil { + return err + } + b = pb + + default: + return fmt.Errorf("expected a bool or string but was given type %T for field %s", v, fName) + } + f.SetBool(b) + default: + continue + } + } + return nil +} + +func introspect(apex string, tag reflect.StructTag) (qualified bool, path string, key string) { + pathKey, isPathKey := tag.Lookup(tagVaultPathKeyName) + dataKey, isDataKey := tag.Lookup(tagVaultDataKeyName) + + if isPathKey && isDataKey { + // disqualified, unsolveable + return + } + + path = apex + key = defaultDataKeyFieldName + + if isPathKey { + qualified = true + path = fmt.Sprintf("%s/%s", apex, pathKey) + if dk, ok := tag.Lookup(tagVaultPathDataKeyName); ok { + key = dk + } + } + + if isDataKey { + qualified = true + key = dataKey + } + return +} + +func fetchValue(ctx context.Context, client *api.Client, path, valueIndex string) (interface{}, error) { + secret, err := client.Logical().ReadWithContext(ctx, path) + if err != nil { + return nil, fmt.Errorf("failed to read secret %s: %w", path, err) + } + if secret == nil || secret.Data == nil { + return nil, errors.New("path did not return any data") + } + + value := secret.Data[valueIndex] + + if value == nil { + return nil, fmt.Errorf("could not extract value from data %s %s", path, valueIndex) + } + + return value, nil +} diff --git a/pkg/daytona/unmarshal_test.go b/pkg/daytona/unmarshal_test.go new file mode 100644 index 00000000..ce0aface --- /dev/null +++ b/pkg/daytona/unmarshal_test.go @@ -0,0 +1,299 @@ +package daytona + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/cruise-automation/daytona/pkg/helpers/testhelpers" + "github.com/stretchr/testify/assert" +) + +var testPayload = map[string]interface{}{ + "auth": nil, + "data": map[string]interface{}{ + "value": "standard", + "password": "nonstandard", + "private_key": "BEGIN PRIVATE KEY", + "a_bad_string": 9, + "an_int": 12, + "a_string_int": "12", + "a_bad_int": "xxx", + "a_float": 6.66, + "a_string_float": "6.66", + "a_bad_float": "xxx", + "a_bool": true, + "a_string_bool": "true", + "a_bad_bool": "yee", + "a_duration": "7h", + "a_bad_duration": "hello", + "a_mismatch": "7xL", + }, + "lease_duration": 3600, + "lease_id": "", + "renewable": false, +} + +func generateMuliKeyPayload() (string, error) { + b, err := json.Marshal(testPayload) + if err != nil { + return "", err + } + return string(b), nil +} + +func TestUnmarshalSecretDataKeys(t *testing.T) { + tp, err := generateMuliKeyPayload() + if err != nil { + t.Fatal(err) + } + + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/secret/application/password": + fmt.Fprintln(w, tp) + case "/v1/secret/top-level/API_KEY": + fmt.Fprintln(w, ` + { + "auth": null, + "data": { + "value": "THIS_IS_MY_API_KEY" + }, + "lease_duration": 3600, + "lease_id": "", + "renewable": false + } + `) + default: + w.WriteHeader(404) + } + })) + defer ts.Close() + client, err := testhelpers.GetTestClient(ts.URL) + if err != nil { + t.Fatal(err) + } + + secret, err := NewSecretUnmarshler(WithClient(client)) + if err != nil { + t.Fatal(err) + } + + testData := testPayload["data"].(map[string]interface{}) + + // generic type input validation + empty := struct{}{} + err = secret.Unmarshal(context.TODO(), "secret/applicaiton", empty) + assert.Equal(t, ErrValueInput, err) + + var tst string + err = secret.Unmarshal(context.TODO(), "secret/applicaiton", tst) + assert.Equal(t, ErrValueInput, err) + + // unaffected fields + normal := struct { + Hello string + Empty string + }{Hello: "hi!"} + err = secret.Unmarshal(context.TODO(), "secret/applicaiton", &normal) + assert.Equal(t, nil, err) + assert.Equal(t, "hi!", normal.Hello) + assert.Equal(t, "", normal.Empty) + + conflictingTags := struct { + Value string `vault_path_key:"password" vault_data_key:"value"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application", &conflictingTags) + assert.Equal(t, nil, err) + assert.Equal(t, "", conflictingTags.Value) + + aValue := struct { + Value string `vault_data_key:"value"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aValue) + assert.Equal(t, nil, err) + assert.Equal(t, testData["value"], aValue.Value) + + aFullPath := struct { + Password string `vault_data_key:"password"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aFullPath) + assert.Equal(t, nil, err) + assert.Equal(t, testData["password"], aFullPath.Password) + + var embedded struct { + Nested struct { + PrivateKey string `vault_data_key:"private_key"` + } + } + err = secret.Unmarshal(context.TODO(), "secret/application/password", &embedded) + assert.Equal(t, nil, err) + assert.Equal(t, testData["private_key"], embedded.Nested.PrivateKey) + + aBadString := struct { + aBadString string `vault_data_key:"a_bad_string"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aBadString) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "expected a string") + + goodDuration := struct { + ADuration time.Duration `vault_data_key:"a_duration"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &goodDuration) + assert.Equal(t, nil, err) + assert.Equal(t, time.Hour*7, goodDuration.ADuration) + + invalidDuration := struct { + ADuration time.Duration `vault_data_key:"a_bad_duration"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &invalidDuration) + assert.Equal(t, `time: invalid duration "hello"`, err.Error()) + + aFloat := struct { + AFloat float64 `vault_data_key:"a_float"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aFloat) + assert.Equal(t, nil, err) + assert.Equal(t, testData["a_float"], aFloat.AFloat) + + aStringFloat := struct { + AFloat float64 `vault_data_key:"a_string_float"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aStringFloat) + assert.Equal(t, nil, err) + assert.Equal(t, testData["a_float"], aStringFloat.AFloat) + + aBadFloat := struct { + aBadFloat float64 `vault_data_key:"a_bad_float"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aBadFloat) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("parsing %q", testData["a_bad_float"])) + + anInteger := struct { + AnInteger int `vault_data_key:"an_int"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &anInteger) + assert.Equal(t, nil, err) + assert.Equal(t, testData["an_int"], anInteger.AnInteger) + + aStringInteger := struct { + AStringInteger int `vault_data_key:"a_string_int"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aStringInteger) + assert.Equal(t, nil, err) + assert.Equal(t, testData["an_int"], aStringInteger.AStringInteger) + + aBadInt := struct { + aBadInt int64 `vault_data_key:"a_bad_int"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aBadInt) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("parsing %q", testData["a_bad_int"])) + + aBool := struct { + ABool bool `vault_data_key:"a_bool"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aBool) + assert.Equal(t, nil, err) + assert.Equal(t, testData["a_bool"], aBool.ABool) + + aStringBool := struct { + AStringBool bool `vault_data_key:"a_string_bool"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aStringBool) + assert.Equal(t, nil, err) + assert.Equal(t, testData["a_bool"], aStringBool.AStringBool) + + aBadBool := struct { + aBadBool bool `vault_data_key:"a_bad_bool"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aBadBool) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("parsing %q", testData["a_bad_bool"])) + + aMisMatch := struct { + AMismatch float64 `vault_data_key:"a_mismatch"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aMisMatch) + assert.Equal(t, fmt.Sprintf(`strconv.ParseFloat: parsing %q: invalid syntax`, testData["a_mismatch"]), err.Error()) + + type x struct { + Value string `vault_data_key:"password"` + } + aStructWithPointer := struct { + Thingy *x + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aStructWithPointer) + assert.Equal(t, nil, err) + assert.Equal(t, testData["password"], aStructWithPointer.Thingy.Value) + + aPtrField := struct { + Password *string `vault_data_key:"password"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/application/password", &aPtrField) + assert.Equal(t, nil, err) + assert.Equal(t, testData["password"], *aPtrField.Password) +} + +func TestUnmarshalSecretPathKeys(t *testing.T) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/secret/top-level/API_KEY": + fmt.Fprintln(w, ` + { + "auth": null, + "data": { + "value": "THIS_IS_MY_API_KEY" + }, + "lease_duration": 3600, + "lease_id": "", + "renewable": false + } + `) + case "/v1/secret/top-level/SECRET_KEY": + fmt.Fprintln(w, ` + { + "auth": null, + "data": { + "secret": "shhhh" + }, + "lease_duration": 3600, + "lease_id": "", + "renewable": false + } + `) + default: + w.WriteHeader(404) + } + })) + defer ts.Close() + client, err := testhelpers.GetTestClient(ts.URL) + if err != nil { + t.Fatal(err) + } + + secret, err := NewSecretUnmarshler(WithClient(client)) + if err != nil { + t.Fatal(err) + } + + pathKey := struct { + Password string `vault_path_key:"API_KEY"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/top-level", &pathKey) + assert.Equal(t, nil, err) + assert.Equal(t, "THIS_IS_MY_API_KEY", pathKey.Password) + + pathKeyAltDatKey := struct { + Secret string `vault_path_key:"SECRET_KEY" vault_path_data_key:"secret"` + }{} + err = secret.Unmarshal(context.TODO(), "secret/top-level", &pathKeyAltDatKey) + assert.Equal(t, nil, err) + assert.Equal(t, "shhhh", pathKeyAltDatKey.Secret) +}