From 82b0a0a84c2c6f1cfc6d8a7e950e1071193b423c Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 10 Apr 2019 15:40:57 -0700 Subject: [PATCH 01/13] regenerate --- boilerplate/lyft/golang_test_targets/Makefile | 2 +- cli/pflags/api/generator.go | 75 +++++++++++++++---- cli/pflags/api/generator_test.go | 32 +++++++- cli/pflags/api/sample.go | 5 +- cli/pflags/api/templates.go | 18 ++++- cli/pflags/api/testdata/testtype.go | 51 ++++++++----- cli/pflags/api/testdata/testtype_test.go | 32 ++++---- cli/pflags/api/utils.go | 28 +++++-- cli/pflags/cmd/root.go | 9 ++- logger/config.go | 19 +++-- logger/config_flags.go | 27 +++++-- logger/config_flags_test.go | 41 +++++++++- profutils/server_test.go | 4 +- storage/config.go | 20 +++-- storage/config_flags.go | 41 +++++++--- storage/config_flags_test.go | 22 +++--- 16 files changed, 320 insertions(+), 106 deletions(-) diff --git a/boilerplate/lyft/golang_test_targets/Makefile b/boilerplate/lyft/golang_test_targets/Makefile index 1c6f893..04b79ba 100644 --- a/boilerplate/lyft/golang_test_targets/Makefile +++ b/boilerplate/lyft/golang_test_targets/Makefile @@ -1,6 +1,6 @@ .PHONY: lint lint: #lints the package for common code smells - which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.10 + which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.11 golangci-lint run # If code is failing goimports linter, this will fix. diff --git a/cli/pflags/api/generator.go b/cli/pflags/api/generator.go index 2e4dd30..0ee2a61 100644 --- a/cli/pflags/api/generator.go +++ b/cli/pflags/api/generator.go @@ -19,8 +19,9 @@ const ( // PFlagProviderGenerator parses and generates GetPFlagSet implementation to add PFlags for a given struct's fields. type PFlagProviderGenerator struct { - pkg *types.Package - st *types.Named + pkg *types.Package + st *types.Named + defaultVar *types.Var } // This list is restricted because that's the only kinds viper parses out, otherwise it assumes strings. @@ -54,7 +55,7 @@ func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage emptyDefaultValue := `[]string{}` if b, ok := t.Elem().(*types.Basic); !ok { logger.Infof(ctx, "Elem of type [%v] is not a basic type. It must be json unmarshalable or generation will fail.", t.Elem()) - if !jsonUnmarshaler(t.Elem()) { + if !isJSONUnmarshaler(t.Elem()) { return FieldInfo{}, fmt.Errorf("slice of type [%v] is not supported. Only basic slices or slices of json-unmarshalable types are supported", t.Elem().String()) @@ -85,9 +86,17 @@ func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage }, nil } +func appendAccessorIfNotEmpty(baseAccessor, childAccessor string) string { + if len(baseAccessor) == 0 { + return baseAccessor + } + + return baseAccessor + "." + childAccessor +} + // Traverses fields in type and follows recursion tree to discover all fields. It stops when one of two conditions is // met; encountered a basic type (e.g. string, int... etc.) or the field type implements UnmarshalJSON. -func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo, error) { +func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor string) ([]FieldInfo, error) { logger.Printf(ctx, "Finding all fields in [%v.%v.%v]", typ.Obj().Pkg().Path(), typ.Obj().Pkg().Name(), typ.Obj().Name()) @@ -111,9 +120,11 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo tag.Name = v.Name() } + isPtr := false typ := v.Type() - if ptr, isPtr := typ.(*types.Pointer); isPtr { + if ptr, casted := typ.(*types.Pointer); casted { typ = ptr.Elem() + isPtr = true } switch t := typ.(type) { @@ -137,12 +148,21 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo t.String(), t.Kind(), allowedKinds) } + defaultValue := tag.DefaultValue + if accessor := appendAccessorIfNotEmpty(defaultValueAccessor, v.Name()); len(accessor) > 0 { + defaultValue = accessor + + if isPtr { + defaultValue = fmt.Sprintf("cfg.elemValueOrNil(%s).(%s)", defaultValue, t.Name()) + } + } + fields = append(fields, FieldInfo{ Name: tag.Name, GoName: v.Name(), Typ: t, FlagMethodName: camelCase(t.String()), - DefaultValue: tag.DefaultValue, + DefaultValue: defaultValue, UsageString: tag.Usage, TestValue: `"1"`, TestStrategy: JSON, @@ -155,7 +175,7 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo // If the type has json unmarshaler, then stop the recursion and assume the type is string. config package // will use json unmarshaler to fill in the final config object. - jsonUnmarshaler := jsonUnmarshaler(t) + jsonUnmarshaler := isJSONUnmarshaler(t) testValue := tag.DefaultValue if len(tag.DefaultValue) == 0 { @@ -163,6 +183,16 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo testValue = `"1"` } + defaultValue := tag.DefaultValue + if accessor := appendAccessorIfNotEmpty(defaultValueAccessor, v.Name()); len(accessor) > 0 { + defaultValue = accessor + if isStringer(t) { + defaultValue = defaultValue + ".String()" + } else { + defaultValue = fmt.Sprintf("fmt.Sprintf(\"%%v\",%s)", defaultValue) + } + } + logger.Infof(ctx, "[%v] is of a Named type (struct) with default value [%v].", tag.Name, tag.DefaultValue) if jsonUnmarshaler { @@ -173,7 +203,7 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo GoName: v.Name(), Typ: types.Typ[types.String], FlagMethodName: "String", - DefaultValue: tag.DefaultValue, + DefaultValue: defaultValue, UsageString: tag.Usage, TestValue: testValue, TestStrategy: JSON, @@ -181,7 +211,7 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo } else { logger.Infof(ctx, "Traversing fields in type.") - nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t) + nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, appendAccessorIfNotEmpty(defaultValueAccessor, v.Name())) if err != nil { return nil, err } @@ -228,7 +258,8 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo // NewGenerator initializes a PFlagProviderGenerator for pflags files for targetTypeName struct under pkg. If pkg is not filled in, // it's assumed to be current package (which is expected to be the common use case when invoking pflags from // go:generate comments) -func NewGenerator(pkg, targetTypeName string) (*PFlagProviderGenerator, error) { +func NewGenerator(pkg, targetTypeName, defaultVariableName string) (*PFlagProviderGenerator, error) { + ctx := context.Background() var err error // Resolve package path if pkg == "" || pkg[0] == '.' { @@ -257,9 +288,22 @@ func NewGenerator(pkg, targetTypeName string) (*PFlagProviderGenerator, error) { return nil, fmt.Errorf("%s should be an struct, was %s", targetTypeName, obj.Type().Underlying()) } + var defaultVar *types.Var + obj = targetPackage.Scope().Lookup(defaultVariableName) + if obj != nil { + defaultVar = obj.(*types.Var) + } + + if defaultVar != nil { + logger.Infof(ctx, "Using default variable with name [%v] to assign all default values.", defaultVariableName) + } else { + logger.Infof(ctx, "Using default values defined in tags if any.") + } + return &PFlagProviderGenerator{ - st: st, - pkg: targetPackage, + st: st, + pkg: targetPackage, + defaultVar: defaultVar, }, nil } @@ -268,7 +312,12 @@ func (g PFlagProviderGenerator) GetTargetPackage() *types.Package { } func (g PFlagProviderGenerator) Generate(ctx context.Context) (PFlagProvider, error) { - fields, err := discoverFieldsRecursive(ctx, g.st) + defaultValueAccessor := "" + if g.defaultVar != nil { + defaultValueAccessor = g.defaultVar.Name() + } + + fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor) if err != nil { return PFlagProvider{}, err } diff --git a/cli/pflags/api/generator_test.go b/cli/pflags/api/generator_test.go index edfab6c..b1d7c61 100644 --- a/cli/pflags/api/generator_test.go +++ b/cli/pflags/api/generator_test.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "os" "path/filepath" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -14,8 +15,37 @@ import ( // Make sure existing config file(s) parse correctly before overriding them with this flag! var update = flag.Bool("update", false, "Updates testdata") +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func TestElemValueOrNil(t *testing.T) { + var iPtr *int + assert.Equal(t, 0, elemValueOrNil(iPtr)) + var sPtr *string + assert.Equal(t, "", elemValueOrNil(sPtr)) + var i int + assert.Equal(t, 0, elemValueOrNil(i)) + var s string + assert.Equal(t, "", elemValueOrNil(s)) + var arr []string + assert.Equal(t, arr, elemValueOrNil(arr)) +} + func TestNewGenerator(t *testing.T) { - g, err := NewGenerator(".", "TestType") + g, err := NewGenerator(".", "TestType", "DefaultTestType") assert.NoError(t, err) ctx := context.Background() diff --git a/cli/pflags/api/sample.go b/cli/pflags/api/sample.go index b1ebb50..43dfce7 100644 --- a/cli/pflags/api/sample.go +++ b/cli/pflags/api/sample.go @@ -3,10 +3,13 @@ package api import ( "encoding/json" "errors" - "github.com/lyft/flytestdlib/storage" ) +var DefaultTestType = &TestType{ + StringValue: "Welcome to defaults", +} + type TestType struct { StringValue string `json:"str" pflag:"\"hello world\",\"life is short\""` BoolValue bool `json:"bl" pflag:"true"` diff --git a/cli/pflags/api/templates.go b/cli/pflags/api/templates.go index a7adba8..545d0ce 100644 --- a/cli/pflags/api/templates.go +++ b/cli/pflags/api/templates.go @@ -26,9 +26,25 @@ import ( {{$name}} "{{$path}}"{{end}} ) +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func ({{ .Name }}) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + // GetPFlagSet will return strongly types pflags for all fields in {{ .Name }} and its nested types. The format of the // flags is json-name.json-sub-name... etc. -func ({{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet { +func (cfg {{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("{{ .Name }}", pflag.ExitOnError) {{- range .Fields }} cmdFlags.{{ .FlagMethodName }}(fmt.Sprintf("%v%v", prefix, "{{ .Name }}"), {{ .DefaultValue }}, {{ .UsageString }}) diff --git a/cli/pflags/api/testdata/testtype.go b/cli/pflags/api/testdata/testtype.go index 87f5cb7..dde0f48 100755 --- a/cli/pflags/api/testdata/testtype.go +++ b/cli/pflags/api/testdata/testtype.go @@ -5,32 +5,49 @@ package api import ( "fmt" + "reflect" "github.com/spf13/pflag" ) +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (TestType) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + // GetPFlagSet will return strongly types pflags for all fields in TestType and its nested types. The format of the // flags is json-name.json-sub-name... etc. -func (TestType) GetPFlagSet(prefix string) *pflag.FlagSet { +func (cfg TestType) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("TestType", pflag.ExitOnError) - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), "hello world", "life is short") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "bl"), true, "") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "nested.i"), *new(int), "this is an important flag") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), DefaultTestType.StringValue, "life is short") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "bl"), DefaultTestType.BoolValue, "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "nested.i"), DefaultTestType.NestedType.IntValue, "this is an important flag") cmdFlags.IntSlice(fmt.Sprintf("%v%v", prefix, "ints"), []int{12, 1}, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "strs"), []string{"12", "1"}, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "complexArr"), []string{}, "") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "c"), "", "I'm a complex type but can be converted from string.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.type"), "s3", "Sets the type of storage to configure [s3/minio/local/mem].") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.endpoint"), "", "URL for storage client to connect to.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.auth-type"), "iam", "Auth Type to use [iam, accesskey].") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.access-key"), *new(string), "Access key to use. Only required when authtype is set to accesskey.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.secret-key"), *new(string), "Secret to use when accesskey is set.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.region"), "us-east-1", "Region to connect to.") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "storage.connection.disable-ssl"), *new(bool), "Disables SSL connection. Should only be used for development.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.container"), *new(string), "Initial container to create -if it doesn't exist-.'") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.max_size_mbs"), *new(int), "Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.target_gc_percent"), *new(int), "Sets the garbage collection target percentage.") - cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "storage.limits.maxDownloadMBs"), 2, "Maximum allowed download size (in MBs) per call.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "i"), *new(int), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "c"), fmt.Sprintf("%v", DefaultTestType.StringToJSON), "I'm a complex type but can be converted from string.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.type"), DefaultTestType.StorageConfig.Type, "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.endpoint"), DefaultTestType.StorageConfig.Connection.Endpoint.String(), "URL for storage client to connect to.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.auth-type"), DefaultTestType.StorageConfig.Connection.AuthType, "Auth Type to use [iam, accesskey].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.access-key"), DefaultTestType.StorageConfig.Connection.AccessKey, "Access key to use. Only required when authtype is set to accesskey.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.secret-key"), DefaultTestType.StorageConfig.Connection.SecretKey, "Secret to use when accesskey is set.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.region"), DefaultTestType.StorageConfig.Connection.Region, "Region to connect to.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "storage.connection.disable-ssl"), DefaultTestType.StorageConfig.Connection.DisableSSL, "Disables SSL connection. Should only be used for development.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.container"), DefaultTestType.StorageConfig.InitContainer, "Initial container to create -if it doesn't exist-.'") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.max_size_mbs"), DefaultTestType.StorageConfig.Cache.MaxSizeMegabytes, "Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.target_gc_percent"), DefaultTestType.StorageConfig.Cache.TargetGCPercent, "Sets the garbage collection target percentage.") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "storage.limits.maxDownloadMBs"), DefaultTestType.StorageConfig.Limits.GetLimitMegabytes, "Maximum allowed download size (in MBs) per call.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "i"), cfg.elemValueOrNil(DefaultTestType.IntValue).(int), "") return cmdFlags } diff --git a/cli/pflags/api/testdata/testtype_test.go b/cli/pflags/api/testdata/testtype_test.go index f8b81bb..03412f0 100755 --- a/cli/pflags/api/testdata/testtype_test.go +++ b/cli/pflags/api/testdata/testtype_test.go @@ -103,7 +103,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("str"); err == nil { - assert.Equal(t, string("hello world"), vString) + assert.Equal(t, string(DefaultTestType.StringValue), vString) } else { assert.FailNow(t, err.Error()) } @@ -125,7 +125,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vBool, err := cmdFlags.GetBool("bl"); err == nil { - assert.Equal(t, bool(true), vBool) + assert.Equal(t, bool(DefaultTestType.BoolValue), vBool) } else { assert.FailNow(t, err.Error()) } @@ -147,7 +147,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("nested.i"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(DefaultTestType.NestedType.IntValue), vInt) } else { assert.FailNow(t, err.Error()) } @@ -235,7 +235,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("c"); err == nil { - assert.Equal(t, string(""), vString) + assert.Equal(t, string(fmt.Sprintf("%v", DefaultTestType.StringToJSON)), vString) } else { assert.FailNow(t, err.Error()) } @@ -257,7 +257,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.type"); err == nil { - assert.Equal(t, string("s3"), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Type), vString) } else { assert.FailNow(t, err.Error()) } @@ -279,7 +279,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.connection.endpoint"); err == nil { - assert.Equal(t, string(""), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Connection.Endpoint.String()), vString) } else { assert.FailNow(t, err.Error()) } @@ -301,7 +301,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.connection.auth-type"); err == nil { - assert.Equal(t, string("iam"), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Connection.AuthType), vString) } else { assert.FailNow(t, err.Error()) } @@ -323,7 +323,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.connection.access-key"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Connection.AccessKey), vString) } else { assert.FailNow(t, err.Error()) } @@ -345,7 +345,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.connection.secret-key"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Connection.SecretKey), vString) } else { assert.FailNow(t, err.Error()) } @@ -367,7 +367,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.connection.region"); err == nil { - assert.Equal(t, string("us-east-1"), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Connection.Region), vString) } else { assert.FailNow(t, err.Error()) } @@ -389,7 +389,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vBool, err := cmdFlags.GetBool("storage.connection.disable-ssl"); err == nil { - assert.Equal(t, bool(*new(bool)), vBool) + assert.Equal(t, bool(DefaultTestType.StorageConfig.Connection.DisableSSL), vBool) } else { assert.FailNow(t, err.Error()) } @@ -411,7 +411,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.container"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.InitContainer), vString) } else { assert.FailNow(t, err.Error()) } @@ -433,7 +433,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("storage.cache.max_size_mbs"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(DefaultTestType.StorageConfig.Cache.MaxSizeMegabytes), vInt) } else { assert.FailNow(t, err.Error()) } @@ -455,7 +455,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("storage.cache.target_gc_percent"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(DefaultTestType.StorageConfig.Cache.TargetGCPercent), vInt) } else { assert.FailNow(t, err.Error()) } @@ -477,7 +477,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt64, err := cmdFlags.GetInt64("storage.limits.maxDownloadMBs"); err == nil { - assert.Equal(t, int64(2), vInt64) + assert.Equal(t, int64(DefaultTestType.StorageConfig.Limits.GetLimitMegabytes), vInt64) } else { assert.FailNow(t, err.Error()) } @@ -499,7 +499,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("i"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(cfg.elemValueOrNil(DefaultTestType.IntValue).(int)), vInt) } else { assert.FailNow(t, err.Error()) } diff --git a/cli/pflags/api/utils.go b/cli/pflags/api/utils.go index 4c71fbb..16edb50 100644 --- a/cli/pflags/api/utils.go +++ b/cli/pflags/api/utils.go @@ -20,13 +20,29 @@ func camelCase(str string) string { return str } -func jsonUnmarshaler(t types.Type) bool { +func isJSONUnmarshaler(t types.Type) bool { + found, _ := implementsAnyOfMethods(t, "UnmarshalJSON") + return found +} + +func isStringer(t types.Type) bool { + found, _ := implementsAnyOfMethods(t, "String") + return found +} + +func implementsAnyOfMethods(t types.Type, methodNames ...string) (found, implementedByPtr bool) { mset := types.NewMethodSet(t) - jsonUnmarshaler := mset.Lookup(nil, "UnmarshalJSON") - if jsonUnmarshaler == nil { - mset = types.NewMethodSet(types.NewPointer(t)) - jsonUnmarshaler = mset.Lookup(nil, "UnmarshalJSON") + for _, name := range methodNames { + if mset.Lookup(nil, name) != nil { + return true, false + } + } + mset = types.NewMethodSet(types.NewPointer(t)) + for _, name := range methodNames { + if mset.Lookup(nil, name) != nil { + return true, true + } } - return jsonUnmarshaler != nil + return false, false } diff --git a/cli/pflags/cmd/root.go b/cli/pflags/cmd/root.go index b6562d8..d78d4c4 100644 --- a/cli/pflags/cmd/root.go +++ b/cli/pflags/cmd/root.go @@ -3,7 +3,6 @@ package cmd import ( "bytes" "context" - "flag" "fmt" "strings" @@ -13,7 +12,8 @@ import ( ) var ( - pkg = flag.String("pkg", ".", "what package to get the interface from") + pkg string + defaultValuesVariable string ) var root = cobra.Command{ @@ -31,7 +31,8 @@ type MyStruct struct { } func init() { - root.Flags().StringP("package", "p", ".", "Determines the source/destination package.") + root.Flags().StringVarP(&pkg, "package", "p", ".", "Determines the source/destination package.") + root.Flags().StringVar(&defaultValuesVariable, "default-var", "defaultConfig", "Points to a variable to use to load default configs. If specified & found, it'll be used instead of the values specified in the tag.") } func Execute() error { @@ -45,7 +46,7 @@ func generatePflagsProvider(cmd *cobra.Command, args []string) error { } ctx := context.Background() - gen, err := api.NewGenerator(*pkg, structName) + gen, err := api.NewGenerator(pkg, structName, defaultValuesVariable) if err != nil { return err } diff --git a/logger/config.go b/logger/config.go index faa2da6..1c665ac 100644 --- a/logger/config.go +++ b/logger/config.go @@ -6,7 +6,7 @@ import ( "github.com/lyft/flytestdlib/config" ) -//go:generate pflags Config +//go:generate pflags Config --default-var defaultConfig const configSectionKey = "Logger" @@ -21,6 +21,13 @@ const ( jsonDataKey string = "json" ) +var defaultConfig = &Config{ + Formatter: FormatterConfig{ + Type: FormatterJSON, + }, + Level: InfoLevel, +} + // Global logger config. type Config struct { // Determines whether to include source code location in logs. This might incurs a performance hit and is only @@ -31,13 +38,13 @@ type Config struct { Mute bool `json:"mute" pflag:",Mutes all logs regardless of severity. Intended for benchmarks/tests only."` // Determines the minimum log level to log. - Level Level `json:"level" pflag:"4,Sets the minimum logging level."` + Level Level `json:"level" pflag:",Sets the minimum logging level."` Formatter FormatterConfig `json:"formatter" pflag:",Sets logging format."` } type FormatterConfig struct { - Type FormatterType `json:"type" pflag:"\"json\",Sets logging format type."` + Type FormatterType `json:"type" pflag:",Sets logging format type."` } var globalConfig = Config{} @@ -73,9 +80,7 @@ const ( ) func init() { - if _, err := config.RegisterSectionWithUpdates(configSectionKey, &Config{}, func(ctx context.Context, newValue config.Config) { + config.MustRegisterSectionWithUpdates(configSectionKey, defaultConfig, func(ctx context.Context, newValue config.Config) { SetConfig(*newValue.(*Config)) - }); err != nil { - panic(err) - } + }) } diff --git a/logger/config_flags.go b/logger/config_flags.go index cf8950b..27be2e3 100755 --- a/logger/config_flags.go +++ b/logger/config_flags.go @@ -5,17 +5,34 @@ package logger import ( "fmt" + "reflect" "github.com/spf13/pflag" ) +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + // GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the // flags is json-name.json-sub-name... etc. -func (Config) GetPFlagSet(prefix string) *pflag.FlagSet { +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "show-source"), *new(bool), "Includes source code location in logs.") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "mute"), *new(bool), "Mutes all logs regardless of severity. Intended for benchmarks/tests only.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "level"), 4, "Sets the minimum logging level.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "formatter.type"), "json", "Sets logging format type.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "show-source"), defaultConfig.IncludeSourceCode, "Includes source code location in logs.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "mute"), defaultConfig.Mute, "Mutes all logs regardless of severity. Intended for benchmarks/tests only.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "level"), defaultConfig.Level, "Sets the minimum logging level.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "formatter.type"), defaultConfig.Formatter.Type, "Sets logging format type.") return cmdFlags } diff --git a/logger/config_flags_test.go b/logger/config_flags_test.go index 401d58d..853aeac 100755 --- a/logger/config_flags_test.go +++ b/logger/config_flags_test.go @@ -103,7 +103,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vBool, err := cmdFlags.GetBool("show-source"); err == nil { - assert.Equal(t, bool(*new(bool)), vBool) + assert.Equal(t, bool(defaultConfig.IncludeSourceCode), vBool) } else { assert.FailNow(t, err.Error()) } @@ -125,7 +125,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vBool, err := cmdFlags.GetBool("mute"); err == nil { - assert.Equal(t, bool(*new(bool)), vBool) + assert.Equal(t, bool(defaultConfig.Mute), vBool) } else { assert.FailNow(t, err.Error()) } @@ -147,7 +147,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("level"); err == nil { - assert.Equal(t, int(4), vInt) + assert.Equal(t, int(defaultConfig.Level), vInt) } else { assert.FailNow(t, err.Error()) } @@ -169,7 +169,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("formatter.type"); err == nil { - assert.Equal(t, string("json"), vString) + assert.Equal(t, string(defaultConfig.Formatter.Type), vString) } else { assert.FailNow(t, err.Error()) } @@ -188,3 +188,36 @@ func TestConfig_SetFlags(t *testing.T) { }) }) } + +func TestConfig_elemValueOrNil(t *testing.T) { + type fields struct { + IncludeSourceCode bool + Mute bool + Level Level + Formatter FormatterConfig + } + type args struct { + v interface{} + } + tests := []struct { + name string + fields fields + args args + want interface{} + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := Config{ + IncludeSourceCode: tt.fields.IncludeSourceCode, + Mute: tt.fields.Mute, + Level: tt.fields.Level, + Formatter: tt.fields.Formatter, + } + if got := c.elemValueOrNil(tt.args.v); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Config.elemValueOrNil() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/profutils/server_test.go b/profutils/server_test.go index e2eb709..a7ef350 100644 --- a/profutils/server_test.go +++ b/profutils/server_test.go @@ -70,9 +70,9 @@ func TestConfigHandler(t *testing.T) { "logger": map[string]interface{}{ "show-source": false, "mute": false, - "level": float64(0), + "level": float64(4), "formatter": map[string]interface{}{ - "type": "", + "type": "json", }, }, }, m) diff --git a/storage/config.go b/storage/config.go index 59db019..780ef12 100644 --- a/storage/config.go +++ b/storage/config.go @@ -28,12 +28,22 @@ const ( ) var ( - ConfigSection = config.MustRegisterSection(configSectionKey, &Config{}) + ConfigSection = config.MustRegisterSection(configSectionKey, defaultConfig) + defaultConfig = &Config{ + Type: TypeS3, + Limits: LimitsConfig{ + GetLimitMegabytes: 2, + }, + Connection: ConnectionConfig{ + Region: "us-east-1", + AuthType: "iam", + }, + } ) // A common storage config. type Config struct { - Type Type `json:"type" pflag:"\"s3\",Sets the type of storage to configure [s3/minio/local/mem]."` + Type Type `json:"type" pflag:",Sets the type of storage to configure [s3/minio/local/mem]."` Connection ConnectionConfig `json:"connection"` InitContainer string `json:"container" pflag:",Initial container to create -if it doesn't exist-.'"` // Caching is recommended to improve the performance of underlying systems. It caches the metadata and resolving @@ -47,10 +57,10 @@ type Config struct { // Defines connection configurations. type ConnectionConfig struct { Endpoint config.URL `json:"endpoint" pflag:",URL for storage client to connect to."` - AuthType string `json:"auth-type" pflag:"\"iam\",Auth Type to use [iam,accesskey]."` + AuthType string `json:"auth-type" pflag:",Auth Type to use [iam,accesskey]."` AccessKey string `json:"access-key" pflag:",Access key to use. Only required when authtype is set to accesskey."` SecretKey string `json:"secret-key" pflag:",Secret to use when accesskey is set."` - Region string `json:"region" pflag:"\"us-east-1\",Region to connect to."` + Region string `json:"region" pflag:",Region to connect to."` DisableSSL bool `json:"disable-ssl" pflag:",Disables SSL connection. Should only be used for development."` } @@ -71,7 +81,7 @@ type CachingConfig struct { // Specifies limits for storage package. type LimitsConfig struct { - GetLimitMegabytes int64 `json:"maxDownloadMBs" pflag:"2,Maximum allowed download size (in MBs) per call."` + GetLimitMegabytes int64 `json:"maxDownloadMBs" pflag:",Maximum allowed download size (in MBs) per call."` } // Retrieve current global config for storage. diff --git a/storage/config_flags.go b/storage/config_flags.go index 9a74efd..4cde49b 100755 --- a/storage/config_flags.go +++ b/storage/config_flags.go @@ -5,24 +5,41 @@ package storage import ( "fmt" + "reflect" "github.com/spf13/pflag" ) +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + // GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the // flags is json-name.json-sub-name... etc. -func (Config) GetPFlagSet(prefix string) *pflag.FlagSet { +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), "s3", "Sets the type of storage to configure [s3/minio/local/mem].") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.endpoint"), "", "URL for storage client to connect to.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.auth-type"), "iam", "Auth Type to use [iam, accesskey].") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.access-key"), *new(string), "Access key to use. Only required when authtype is set to accesskey.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.secret-key"), *new(string), "Secret to use when accesskey is set.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.region"), "us-east-1", "Region to connect to.") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "connection.disable-ssl"), *new(bool), "Disables SSL connection. Should only be used for development.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "container"), *new(string), "Initial container to create -if it doesn't exist-.'") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.max_size_mbs"), *new(int), "Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.target_gc_percent"), *new(int), "Sets the garbage collection target percentage.") - cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "limits.maxDownloadMBs"), 2, "Maximum allowed download size (in MBs) per call.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), defaultConfig.Type, "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.endpoint"), defaultConfig.Connection.Endpoint.String(), "URL for storage client to connect to.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.auth-type"), defaultConfig.Connection.AuthType, "Auth Type to use [iam, accesskey].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.access-key"), defaultConfig.Connection.AccessKey, "Access key to use. Only required when authtype is set to accesskey.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.secret-key"), defaultConfig.Connection.SecretKey, "Secret to use when accesskey is set.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.region"), defaultConfig.Connection.Region, "Region to connect to.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "connection.disable-ssl"), defaultConfig.Connection.DisableSSL, "Disables SSL connection. Should only be used for development.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "container"), defaultConfig.InitContainer, "Initial container to create -if it doesn't exist-.'") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.max_size_mbs"), defaultConfig.Cache.MaxSizeMegabytes, "Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.target_gc_percent"), defaultConfig.Cache.TargetGCPercent, "Sets the garbage collection target percentage.") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "limits.maxDownloadMBs"), defaultConfig.Limits.GetLimitMegabytes, "Maximum allowed download size (in MBs) per call.") return cmdFlags } diff --git a/storage/config_flags_test.go b/storage/config_flags_test.go index 2f39f00..429af71 100755 --- a/storage/config_flags_test.go +++ b/storage/config_flags_test.go @@ -103,7 +103,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("type"); err == nil { - assert.Equal(t, string("s3"), vString) + assert.Equal(t, string(defaultConfig.Type), vString) } else { assert.FailNow(t, err.Error()) } @@ -125,7 +125,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("connection.endpoint"); err == nil { - assert.Equal(t, string(""), vString) + assert.Equal(t, string(defaultConfig.Connection.Endpoint.String()), vString) } else { assert.FailNow(t, err.Error()) } @@ -147,7 +147,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("connection.auth-type"); err == nil { - assert.Equal(t, string("iam"), vString) + assert.Equal(t, string(defaultConfig.Connection.AuthType), vString) } else { assert.FailNow(t, err.Error()) } @@ -169,7 +169,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("connection.access-key"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(defaultConfig.Connection.AccessKey), vString) } else { assert.FailNow(t, err.Error()) } @@ -191,7 +191,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("connection.secret-key"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(defaultConfig.Connection.SecretKey), vString) } else { assert.FailNow(t, err.Error()) } @@ -213,7 +213,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("connection.region"); err == nil { - assert.Equal(t, string("us-east-1"), vString) + assert.Equal(t, string(defaultConfig.Connection.Region), vString) } else { assert.FailNow(t, err.Error()) } @@ -235,7 +235,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vBool, err := cmdFlags.GetBool("connection.disable-ssl"); err == nil { - assert.Equal(t, bool(*new(bool)), vBool) + assert.Equal(t, bool(defaultConfig.Connection.DisableSSL), vBool) } else { assert.FailNow(t, err.Error()) } @@ -257,7 +257,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("container"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(defaultConfig.InitContainer), vString) } else { assert.FailNow(t, err.Error()) } @@ -279,7 +279,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("cache.max_size_mbs"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(defaultConfig.Cache.MaxSizeMegabytes), vInt) } else { assert.FailNow(t, err.Error()) } @@ -301,7 +301,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("cache.target_gc_percent"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(defaultConfig.Cache.TargetGCPercent), vInt) } else { assert.FailNow(t, err.Error()) } @@ -323,7 +323,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt64, err := cmdFlags.GetInt64("limits.maxDownloadMBs"); err == nil { - assert.Equal(t, int64(2), vInt64) + assert.Equal(t, int64(defaultConfig.Limits.GetLimitMegabytes), vInt64) } else { assert.FailNow(t, err.Error()) } From f5dbeccb9ced31f067c02586d7f57126a4a7a983 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 17 Apr 2019 11:37:50 -0700 Subject: [PATCH 02/13] Add lods to debug test concurrent issue --- config/tests/accessor_test.go | 42 ++++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/config/tests/accessor_test.go b/config/tests/accessor_test.go index 34d8623..702c9fc 100644 --- a/config/tests/accessor_test.go +++ b/config/tests/accessor_test.go @@ -30,6 +30,11 @@ import ( type accessorCreatorFn func(registry config.Section, configPath string) config.Accessor +type testLogger interface { + Logf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) +} + func getRandInt() uint64 { c := 10 b := make([]byte, c) @@ -323,7 +328,7 @@ func TestAccessor_UpdateConfig(t *testing.T) { r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) firstValue := r.StringValue - fileUpdated, err := beginWaitForFileChange(configFile) + fileUpdated, err := beginWaitForFileChange(t, configFile) assert.NoError(t, err) _, err = populateConfigData(configFile) @@ -346,7 +351,7 @@ func TestAccessor_UpdateConfig(t *testing.T) { // Independently watch for when symlink underlying change happens to know when do we expect accessor to have picked up // the changes - fileUpdated, err := beginWaitForFileChange(configFile) + fileUpdated, err := beginWaitForFileChange(t, configFile) assert.NoError(t, err) // 2. Start accessor with the symlink as config location @@ -383,7 +388,7 @@ func TestAccessor_UpdateConfig(t *testing.T) { // Wait for filewatcher event assert.NoError(t, waitForFileChangeOrTimeout(fileUpdated)) - time.Sleep(2 * time.Second) + time.Sleep(5 * time.Second) r = section.GetConfig().(*MyComponentConfig) secondValue := r.StringValue @@ -462,7 +467,7 @@ func waitForFileChangeOrTimeout(done chan error) error { } } -func beginWaitForFileChange(filename string) (done chan error, terminalErr error) { +func beginWaitForFileChange(logger testLogger, filename string) (done chan error, terminalErr error) { watcher, err := fsnotify.NewWatcher() if err != nil { return nil, err @@ -480,12 +485,21 @@ func beginWaitForFileChange(filename string) (done chan error, terminalErr error go func() { for { select { - case event := <-watcher.Events: + case event, channelOpen := <-watcher.Events: + if !channelOpen { + logger.Logf("Events Channel has been closed") + done <- nil + return + } + + logger.Logf("Received watcher event [%v], %v", event) // we only care about the config file currentConfigFile, err := filepath.EvalSymlinks(filename) if err != nil { + logger.Errorf("Failed to EvalSymLinks. Will attempt to close watcher now. Error: %v", err) closeErr := watcher.Close() if closeErr != nil { + logger.Errorf("Failed to close watcher. Error: %v", closeErr) done <- closeErr } else { done <- err @@ -501,10 +515,12 @@ func beginWaitForFileChange(filename string) (done chan error, terminalErr error if (filepath.Clean(event.Name) == configFile && event.Op&writeOrCreateMask != 0) || (currentConfigFile != "" && currentConfigFile != realConfigFile) { + + logger.Logf("CurrentConfigFile [%v], RealConfigFile [%v]", currentConfigFile, realConfigFile) realConfigFile = currentConfigFile closeErr := watcher.Close() if closeErr != nil { - fmt.Printf("Close Watcher error: %v\n", closeErr) + logger.Errorf("Failed to close watcher. Error: %v", closeErr) } else { done <- nil } @@ -512,21 +528,25 @@ func beginWaitForFileChange(filename string) (done chan error, terminalErr error return } else if filepath.Clean(event.Name) == configFile && event.Op&fsnotify.Remove&fsnotify.Remove != 0 { + + logger.Logf("ConfigFile [%v] Removed.", configFile) closeErr := watcher.Close() if closeErr != nil { - fmt.Printf("Close Watcher error: %v\n", closeErr) + logger.Logf("Close Watcher error: %v", closeErr) } else { done <- nil } return } - case err, ok := <-watcher.Errors: - if ok { - fmt.Printf("Watcher error: %v\n", err) + case err, channelOpen := <-watcher.Errors: + if !channelOpen { + logger.Logf("Error Channel has been closed.") + } else { + logger.Logf("Watcher error: %v", err) closeErr := watcher.Close() if closeErr != nil { - fmt.Printf("Close Watcher error: %v\n", closeErr) + logger.Logf("Close Watcher error: %v\n", closeErr) } } From b65e1f1b31608dd0d06965e8d00209c00fcb6d37 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 17 Apr 2019 16:29:41 -0700 Subject: [PATCH 03/13] atomic symlink --- config/tests/accessor_test.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/config/tests/accessor_test.go b/config/tests/accessor_test.go index 702c9fc..5bffe86 100644 --- a/config/tests/accessor_test.go +++ b/config/tests/accessor_test.go @@ -399,8 +399,8 @@ func TestAccessor_UpdateConfig(t *testing.T) { } func changeSymLink(targetPath, symLink string) error { + tmpLink := tempFileName("temp-sym-link-*") if runtime.GOOS == "windows" { - tmpLink := tempFileName("temp-sym-link-*") err := exec.Command("mklink", filepath.Clean(tmpLink), filepath.Clean(targetPath)).Run() if err != nil { return err @@ -414,7 +414,15 @@ func changeSymLink(targetPath, symLink string) error { return exec.Command("del", filepath.Clean(tmpLink)).Run() } - return exec.Command("ln", "-sfn", filepath.Clean(targetPath), filepath.Clean(symLink)).Run() + // ln -sfn is not an atomic operation. Under the hood, it first calls the system unlink then symlink calls. During + // that, there will be a brief moment when there is no symlink at all. mv operation is, however, atomic. That's + // why we make this command instead + err := exec.Command("ln", "-s", filepath.Clean(targetPath), filepath.Clean(tmpLink)).Run() + if err != nil { + return err + } + + return exec.Command("mv", "-Tf", filepath.Clean(tmpLink), filepath.Clean(symLink)).Run() } // 1. Create Dir structure: @@ -492,7 +500,7 @@ func beginWaitForFileChange(logger testLogger, filename string) (done chan error return } - logger.Logf("Received watcher event [%v], %v", event) + logger.Logf("Received watcher event [%v]", event) // we only care about the config file currentConfigFile, err := filepath.EvalSymlinks(filename) if err != nil { From d64b5d268b427005c8306cb39a8a3592c6b5098d Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 17 Apr 2019 23:29:42 -0700 Subject: [PATCH 04/13] simplify tests --- cli/pflags/api/generator.go | 43 ++++++-- cli/pflags/api/generator_test.go | 4 +- cli/pflags/api/sample.go | 1 + config/tests/accessor_test.go | 167 +++---------------------------- 4 files changed, 51 insertions(+), 164 deletions(-) diff --git a/cli/pflags/api/generator.go b/cli/pflags/api/generator.go index 0ee2a61..c6ea06f 100644 --- a/cli/pflags/api/generator.go +++ b/cli/pflags/api/generator.go @@ -5,6 +5,7 @@ import ( "fmt" "go/types" "path/filepath" + "strings" "github.com/lyft/flytestdlib/logger" @@ -86,17 +87,35 @@ func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage }, nil } -func appendAccessorIfNotEmpty(baseAccessor, childAccessor string) string { - if len(baseAccessor) == 0 { - return baseAccessor +// Appends field accessors using "." as the delimiter. +// e.g. appendAccessors("var1", "field1", "subField") will output "var1.field1.subField" +func appendAccessors(accessors ...string) string { + sb := strings.Builder{} + switch len(accessors) { + case 0: + return "" + case 1: + return accessors[0] } - return baseAccessor + "." + childAccessor + for _, s := range accessors { + if len(s) > 0 { + if sb.Len() > 0 { + sb.WriteString(".") + } + + sb.WriteString(s) + } + } + + return sb.String() } // Traverses fields in type and follows recursion tree to discover all fields. It stops when one of two conditions is // met; encountered a basic type (e.g. string, int... etc.) or the field type implements UnmarshalJSON. -func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor string) ([]FieldInfo, error) { +// If passed a non-empty defaultValueAccessor, it'll be used to fill in default values instead of any default value +// specified in pflag tag. +func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor, fieldPath string) ([]FieldInfo, error) { logger.Printf(ctx, "Finding all fields in [%v.%v.%v]", typ.Obj().Pkg().Path(), typ.Obj().Pkg().Name(), typ.Obj().Name()) @@ -149,8 +168,8 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue } defaultValue := tag.DefaultValue - if accessor := appendAccessorIfNotEmpty(defaultValueAccessor, v.Name()); len(accessor) > 0 { - defaultValue = accessor + if len(defaultValueAccessor) > 0 { + defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name()) if isPtr { defaultValue = fmt.Sprintf("cfg.elemValueOrNil(%s).(%s)", defaultValue, t.Name()) @@ -184,11 +203,13 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue } defaultValue := tag.DefaultValue - if accessor := appendAccessorIfNotEmpty(defaultValueAccessor, v.Name()); len(accessor) > 0 { - defaultValue = accessor + if len(defaultValueAccessor) > 0 { + defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name()) if isStringer(t) { defaultValue = defaultValue + ".String()" } else { + logger.Infof(ctx, "Field [%v] of type [%v] does not implement Stringer interface."+ + " Will use fmt.Sprintf() to get its default value.", v.Name(), t.String()) defaultValue = fmt.Sprintf("fmt.Sprintf(\"%%v\",%s)", defaultValue) } } @@ -211,7 +232,7 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue } else { logger.Infof(ctx, "Traversing fields in type.") - nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, appendAccessorIfNotEmpty(defaultValueAccessor, v.Name())) + nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, defaultValueAccessor, appendAccessors(fieldPath, v.Name())) if err != nil { return nil, err } @@ -317,7 +338,7 @@ func (g PFlagProviderGenerator) Generate(ctx context.Context) (PFlagProvider, er defaultValueAccessor = g.defaultVar.Name() } - fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor) + fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor, "") if err != nil { return PFlagProvider{}, err } diff --git a/cli/pflags/api/generator_test.go b/cli/pflags/api/generator_test.go index b1d7c61..26f77a6 100644 --- a/cli/pflags/api/generator_test.go +++ b/cli/pflags/api/generator_test.go @@ -21,9 +21,9 @@ func elemValueOrNil(v interface{}) interface{} { if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { if reflect.ValueOf(v).IsNil() { return reflect.Zero(t.Elem()).Interface() - } else { - return reflect.ValueOf(v).Interface() } + + return reflect.ValueOf(v).Interface() } else if v == nil { return reflect.Zero(t).Interface() } diff --git a/cli/pflags/api/sample.go b/cli/pflags/api/sample.go index 43dfce7..f1b40e5 100644 --- a/cli/pflags/api/sample.go +++ b/cli/pflags/api/sample.go @@ -3,6 +3,7 @@ package api import ( "encoding/json" "errors" + "github.com/lyft/flytestdlib/storage" ) diff --git a/config/tests/accessor_test.go b/config/tests/accessor_test.go index 5bffe86..d8c9e2a 100644 --- a/config/tests/accessor_test.go +++ b/config/tests/accessor_test.go @@ -16,8 +16,6 @@ import ( "testing" "time" - "github.com/fsnotify/fsnotify" - k8sRand "k8s.io/apimachinery/pkg/util/rand" "github.com/lyft/flytestdlib/config" @@ -30,11 +28,6 @@ import ( type accessorCreatorFn func(registry config.Section, configPath string) config.Accessor -type testLogger interface { - Logf(format string, args ...interface{}) - Errorf(format string, args ...interface{}) -} - func getRandInt() uint64 { c := 10 b := make([]byte, c) @@ -328,16 +321,11 @@ func TestAccessor_UpdateConfig(t *testing.T) { r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) firstValue := r.StringValue - fileUpdated, err := beginWaitForFileChange(t, configFile) - assert.NoError(t, err) - _, err = populateConfigData(configFile) assert.NoError(t, err) - // Simulate filewatcher event - assert.NoError(t, waitForFileChangeOrTimeout(fileUpdated)) - - time.Sleep(2 * time.Second) + // Wait enough for the file change notification to propagate. + time.Sleep(5 * time.Second) r = reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) secondValue := r.StringValue @@ -345,20 +333,17 @@ func TestAccessor_UpdateConfig(t *testing.T) { }) t.Run(fmt.Sprintf("[%v] Change handler k8s configmaps", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + section, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + var firstValue string + // 1. Create Dir structure watchDir, configFile, cleanup := newSymlinkedConfigFile(t) defer cleanup() - // Independently watch for when symlink underlying change happens to know when do we expect accessor to have picked up - // the changes - fileUpdated, err := beginWaitForFileChange(t, configFile) - assert.NoError(t, err) - // 2. Start accessor with the symlink as config location - reg := config.NewRootSection() - section, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) - assert.NoError(t, err) - opts := config.Options{ SearchPaths: []string{configFile}, RootSection: reg, @@ -368,7 +353,8 @@ func TestAccessor_UpdateConfig(t *testing.T) { assert.NoError(t, err) r := section.GetConfig().(*MyComponentConfig) - firstValue := r.StringValue + firstValue = r.StringValue + t.Logf("First value: %v", firstValue) // 3. Now update /data symlink to point to data2 dataDir2 := path.Join(watchDir, "data2") @@ -376,8 +362,9 @@ func TestAccessor_UpdateConfig(t *testing.T) { assert.NoError(t, err) configFile2 := path.Join(dataDir2, "config.yaml") - _, err = populateConfigData(configFile2) + newData, err := populateConfigData(configFile2) assert.NoError(t, err) + t.Logf("New value written to file: %v", newData.MyComponentConfig.StringValue) // change the symlink using the `ln -sfn` command err = changeSymLink(dataDir2, path.Join(watchDir, "data")) @@ -385,9 +372,6 @@ func TestAccessor_UpdateConfig(t *testing.T) { t.Logf("New config Location: %v", configFile2) - // Wait for filewatcher event - assert.NoError(t, waitForFileChangeOrTimeout(fileUpdated)) - time.Sleep(5 * time.Second) r = section.GetConfig().(*MyComponentConfig) @@ -414,15 +398,9 @@ func changeSymLink(targetPath, symLink string) error { return exec.Command("del", filepath.Clean(tmpLink)).Run() } - // ln -sfn is not an atomic operation. Under the hood, it first calls the system unlink then symlink calls. During - // that, there will be a brief moment when there is no symlink at all. mv operation is, however, atomic. That's - // why we make this command instead - err := exec.Command("ln", "-s", filepath.Clean(targetPath), filepath.Clean(tmpLink)).Run() - if err != nil { - return err - } - - return exec.Command("mv", "-Tf", filepath.Clean(tmpLink), filepath.Clean(symLink)).Run() + //// ln -sfn is not an atomic operation. Under the hood, it first calls the system unlink then symlink calls. During + //// that, there will be a brief moment when there is no symlink at all. + return exec.Command("ln", "-sfn", filepath.Clean(targetPath), filepath.Clean(symLink)).Run() } // 1. Create Dir structure: @@ -444,6 +422,7 @@ func newSymlinkedConfigFile(t *testing.T) (watchDir, configFile string, cleanup assert.NoError(t, err) cleanup = func() { + t.Logf("Removing watchDir [%v]", watchDir) assert.NoError(t, os.RemoveAll(watchDir)) } @@ -458,120 +437,6 @@ func newSymlinkedConfigFile(t *testing.T) (watchDir, configFile string, cleanup return watchDir, configFile, cleanup } -func waitForFileChangeOrTimeout(done chan error) error { - timeout := make(chan bool, 1) - go func() { - time.Sleep(5 * time.Second) - timeout <- true - }() - - for { - select { - case <-timeout: - return fmt.Errorf("timed out") - case err := <-done: - return err - } - } -} - -func beginWaitForFileChange(logger testLogger, filename string) (done chan error, terminalErr error) { - watcher, err := fsnotify.NewWatcher() - if err != nil { - return nil, err - } - - configFile := filepath.Clean(filename) - realConfigFile, err := filepath.EvalSymlinks(configFile) - if err != nil { - return nil, err - } - - configDir, _ := filepath.Split(configFile) - - done = make(chan error) - go func() { - for { - select { - case event, channelOpen := <-watcher.Events: - if !channelOpen { - logger.Logf("Events Channel has been closed") - done <- nil - return - } - - logger.Logf("Received watcher event [%v]", event) - // we only care about the config file - currentConfigFile, err := filepath.EvalSymlinks(filename) - if err != nil { - logger.Errorf("Failed to EvalSymLinks. Will attempt to close watcher now. Error: %v", err) - closeErr := watcher.Close() - if closeErr != nil { - logger.Errorf("Failed to close watcher. Error: %v", closeErr) - done <- closeErr - } else { - done <- err - } - - return - } - - // We only care about the config file with the following cases: - // 1 - if the config file was modified or created - // 2 - if the real path to the config file changed (eg: k8s ConfigMap replacement) - const writeOrCreateMask = fsnotify.Write | fsnotify.Create - if (filepath.Clean(event.Name) == configFile && - event.Op&writeOrCreateMask != 0) || - (currentConfigFile != "" && currentConfigFile != realConfigFile) { - - logger.Logf("CurrentConfigFile [%v], RealConfigFile [%v]", currentConfigFile, realConfigFile) - realConfigFile = currentConfigFile - closeErr := watcher.Close() - if closeErr != nil { - logger.Errorf("Failed to close watcher. Error: %v", closeErr) - } else { - done <- nil - } - - return - } else if filepath.Clean(event.Name) == configFile && - event.Op&fsnotify.Remove&fsnotify.Remove != 0 { - - logger.Logf("ConfigFile [%v] Removed.", configFile) - closeErr := watcher.Close() - if closeErr != nil { - logger.Logf("Close Watcher error: %v", closeErr) - } else { - done <- nil - } - - return - } - case err, channelOpen := <-watcher.Errors: - if !channelOpen { - logger.Logf("Error Channel has been closed.") - } else { - logger.Logf("Watcher error: %v", err) - closeErr := watcher.Close() - if closeErr != nil { - logger.Logf("Close Watcher error: %v\n", closeErr) - } - } - - done <- nil - return - } - } - }() - - err = watcher.Add(configDir) - if err != nil { - return nil, err - } - - return done, err -} - func testTypes(accessor accessorCreatorFn) func(t *testing.T) { return func(t *testing.T) { t.Run("ArrayConfigType", func(t *testing.T) { From 32abd4decd765e049d7ba604f7b6f9e1a738edc5 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 00:07:53 -0700 Subject: [PATCH 05/13] Add marshal utils --- utils/marshal_utils.go | 66 +++++++++++++++ utils/marshal_utils_test.go | 165 ++++++++++++++++++++++++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 utils/marshal_utils.go create mode 100644 utils/marshal_utils_test.go diff --git a/utils/marshal_utils.go b/utils/marshal_utils.go new file mode 100644 index 0000000..4a43410 --- /dev/null +++ b/utils/marshal_utils.go @@ -0,0 +1,66 @@ +package utils + +import ( + "encoding/json" + "fmt" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + structpb "github.com/golang/protobuf/ptypes/struct" +) + +var jsonPbMarshaler = jsonpb.Marshaler{} + +func UnmarshalStructToPb(structObj *structpb.Struct, msg proto.Message) error { + if structObj == nil { + return fmt.Errorf("nil Struct Object passed") + } + + jsonObj, err := jsonPbMarshaler.MarshalToString(structObj) + if err != nil { + return err + } + + if err = jsonpb.UnmarshalString(jsonObj, msg); err != nil { + return err + } + + return nil +} + +func MarshalPbToStruct(in proto.Message, out *structpb.Struct) error { + if out == nil { + return fmt.Errorf("nil Struct Object passed") + } + + jsonObj, err := jsonPbMarshaler.MarshalToString(in) + if err != nil { + return err + } + + if err = jsonpb.UnmarshalString(jsonObj, out); err != nil { + return err + } + + return nil +} + +func MarshalPbToString(msg proto.Message) (string, error) { + return jsonPbMarshaler.MarshalToString(msg) +} + +// TODO: Use the stdlib version in the future, or move there if not there. +// Don't use this if input is a proto Message. +func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { + b, err := json.Marshal(input) + if err != nil { + return nil, err + } + + // Turn JSON into a protobuf struct + structObj := &structpb.Struct{} + if err := jsonpb.UnmarshalString(string(b), structObj); err != nil { + return nil, err + } + return structObj, nil +} diff --git a/utils/marshal_utils_test.go b/utils/marshal_utils_test.go new file mode 100644 index 0000000..33d185f --- /dev/null +++ b/utils/marshal_utils_test.go @@ -0,0 +1,165 @@ +package utils + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/golang/protobuf/proto" + structpb "github.com/golang/protobuf/ptypes/struct" +) + +// Simple proto +type TestProto struct { + StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *TestProto) Reset() { *m = TestProto{} } +func (m *TestProto) String() string { return proto.CompactTextString(m) } +func (*TestProto) ProtoMessage() {} +func (*TestProto) Descriptor() ([]byte, []int) { + return []byte{}, []int{0} +} +func (m *TestProto) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_TestProto.Unmarshal(m, b) +} +func (m *TestProto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_TestProto.Marshal(b, m, deterministic) +} +func (dst *TestProto) XXX_Merge(src proto.Message) { + xxx_messageInfo_TestProto.Merge(dst, src) +} +func (m *TestProto) XXX_Size() int { + return xxx_messageInfo_TestProto.Size(m) +} +func (m *TestProto) XXX_DiscardUnknown() { + xxx_messageInfo_TestProto.DiscardUnknown(m) +} + +var xxx_messageInfo_TestProto proto.InternalMessageInfo + +func (m *TestProto) GetWorkflowId() string { + if m != nil { + return m.StringValue + } + return "" +} + +func init() { + proto.RegisterType((*TestProto)(nil), "test.package.TestProto") +} + +func TestMarshalPbToString(t *testing.T) { + type args struct { + msg proto.Message + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + {"empty", args{msg: &TestProto{}}, "{}", false}, + {"has value", args{msg: &TestProto{StringValue: "hello"}}, `{"stringValue":"hello"}`, false}, + {"nil input", args{msg: nil}, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MarshalPbToString(tt.args.msg) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalToString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("MarshalToString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMarshalObjToStruct(t *testing.T) { + type args struct { + input interface{} + } + tests := []struct { + name string + args args + want *structpb.Struct + wantErr bool + }{ + {"has value", args{input: &TestProto{StringValue: "hello"}}, &structpb.Struct{Fields: map[string]*structpb.Value{ + "string_value": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MarshalObjToStruct(tt.args.input) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalObjToStruct() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalObjToStruct() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnmarshalStructToPb(t *testing.T) { + type args struct { + structObj *structpb.Struct + msg proto.Message + } + tests := []struct { + name string + args args + expected proto.Message + wantErr bool + }{ + {"empty", args{structObj: &structpb.Struct{Fields: map[string]*structpb.Value{}}, msg: &TestProto{}}, &TestProto{}, false}, + {"has value", args{structObj: &structpb.Struct{Fields: map[string]*structpb.Value{ + "stringValue": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, + }}, msg: &TestProto{}}, &TestProto{StringValue: "hello"}, false}, + {"nil input", args{structObj: nil}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := UnmarshalStructToPb(tt.args.structObj, tt.args.msg); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalStructToPb() error = %v, wantErr %v", err, tt.wantErr) + } else { + assert.Equal(t, tt.expected, tt.args.msg) + } + }) + } +} + +func TestMarshalPbToStruct(t *testing.T) { + type args struct { + in proto.Message + out *structpb.Struct + } + tests := []struct { + name string + args args + expected *structpb.Struct + wantErr bool + }{ + {"empty", args{in: &TestProto{}, out: &structpb.Struct{}}, &structpb.Struct{Fields: map[string]*structpb.Value{}}, false}, + {"has value", args{in: &TestProto{StringValue: "hello"}, out: &structpb.Struct{}}, &structpb.Struct{Fields: map[string]*structpb.Value{ + "stringValue": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := MarshalPbToStruct(tt.args.in, tt.args.out); (err != nil) != tt.wantErr { + t.Errorf("MarshalPbToStruct() error = %v, wantErr %v", err, tt.wantErr) + } else { + assert.Equal(t, tt.expected.Fields, tt.args.out.Fields) + } + }) + } +} From 33d13beb7ce0c30154b5070807c0032fc405930d Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 10:26:17 -0700 Subject: [PATCH 06/13] lint --- utils/marshal_utils_test.go | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/utils/marshal_utils_test.go b/utils/marshal_utils_test.go index 33d185f..1cece1e 100644 --- a/utils/marshal_utils_test.go +++ b/utils/marshal_utils_test.go @@ -7,15 +7,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/golang/protobuf/proto" - structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/golang/protobuf/ptypes/struct" ) // Simple proto type TestProto struct { - StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` } func (m *TestProto) Reset() { *m = TestProto{} } @@ -24,24 +21,6 @@ func (*TestProto) ProtoMessage() {} func (*TestProto) Descriptor() ([]byte, []int) { return []byte{}, []int{0} } -func (m *TestProto) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_TestProto.Unmarshal(m, b) -} -func (m *TestProto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_TestProto.Marshal(b, m, deterministic) -} -func (dst *TestProto) XXX_Merge(src proto.Message) { - xxx_messageInfo_TestProto.Merge(dst, src) -} -func (m *TestProto) XXX_Size() int { - return xxx_messageInfo_TestProto.Size(m) -} -func (m *TestProto) XXX_DiscardUnknown() { - xxx_messageInfo_TestProto.DiscardUnknown(m) -} - -var xxx_messageInfo_TestProto proto.InternalMessageInfo - func (m *TestProto) GetWorkflowId() string { if m != nil { return m.StringValue From 6ab16423d2eb81c0d900415456cd2ef6f5334391 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 10:31:20 -0700 Subject: [PATCH 07/13] lint --- utils/marshal_utils_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/marshal_utils_test.go b/utils/marshal_utils_test.go index 1cece1e..a5f1512 100644 --- a/utils/marshal_utils_test.go +++ b/utils/marshal_utils_test.go @@ -21,7 +21,7 @@ func (*TestProto) ProtoMessage() {} func (*TestProto) Descriptor() ([]byte, []int) { return []byte{}, []int{0} } -func (m *TestProto) GetWorkflowId() string { +func (m *TestProto) GetWorkflowID() string { if m != nil { return m.StringValue } From 465e1abc24c48b5a6a09d001ce80780b2e70727c Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 12:01:16 -0700 Subject: [PATCH 08/13] docs & refactor --- utils/marshal_utils.go | 52 +++++++++++++++++++++++-------------- utils/marshal_utils_test.go | 21 ++++++++++----- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/utils/marshal_utils.go b/utils/marshal_utils.go index 4a43410..4d9cc14 100644 --- a/utils/marshal_utils.go +++ b/utils/marshal_utils.go @@ -1,66 +1,80 @@ package utils import ( + "bytes" "encoding/json" "fmt" - "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" - structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/golang/protobuf/ptypes/struct" + "github.com/pkg/errors" ) var jsonPbMarshaler = jsonpb.Marshaler{} +// Unmarshals a proto struct into a proto message using jsonPb marshaler. func UnmarshalStructToPb(structObj *structpb.Struct, msg proto.Message) error { if structObj == nil { - return fmt.Errorf("nil Struct Object passed") + return fmt.Errorf("nil Struct object passed") + } + + if msg == nil { + return fmt.Errorf("nil proto.Message object passed") } jsonObj, err := jsonPbMarshaler.MarshalToString(structObj) if err != nil { - return err + return errors.WithMessage(err, "Failed to marshal strcutObj input") } if err = jsonpb.UnmarshalString(jsonObj, msg); err != nil { - return err + return errors.WithMessage(err, "Failed to unmarshal json obj into proto") } return nil } -func MarshalPbToStruct(in proto.Message, out *structpb.Struct) error { - if out == nil { - return fmt.Errorf("nil Struct Object passed") +// Marshals a proto message into proto Struct using jsonPb marshaler. +func MarshalPbToStruct(in proto.Message) (out *structpb.Struct, err error) { + if in == nil { + return nil, fmt.Errorf("nil proto message passed") } - jsonObj, err := jsonPbMarshaler.MarshalToString(in) - if err != nil { - return err + var buf bytes.Buffer + if err := jsonPbMarshaler.Marshal(&buf, in); err != nil { + return nil, errors.WithMessage(err, "Failed to marshal input proto message") } - if err = jsonpb.UnmarshalString(jsonObj, out); err != nil { - return err + out = &structpb.Struct{} + if err = jsonpb.Unmarshal(bytes.NewReader(buf.Bytes()), out); err != nil { + return nil, errors.WithMessage(err, "Failed to unmarshal json object into struct") } - return nil + return out, nil } +// Marshals a proto message using jsonPb marshaler to string. func MarshalPbToString(msg proto.Message) (string, error) { return jsonPbMarshaler.MarshalToString(msg) } -// TODO: Use the stdlib version in the future, or move there if not there. -// Don't use this if input is a proto Message. +// Marshals obj into a struct. Will use jsonPb if input is a proto message, otherwise, it'll use json +// marshaler. func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { + if p, casted := input.(proto.Message); casted { + return MarshalPbToStruct(p) + } + b, err := json.Marshal(input) if err != nil { - return nil, err + return nil, errors.WithMessage(err, "Failed to marshal input proto message") } // Turn JSON into a protobuf struct structObj := &structpb.Struct{} - if err := jsonpb.UnmarshalString(string(b), structObj); err != nil { - return nil, err + if err := jsonpb.Unmarshal(bytes.NewReader(b), structObj); err != nil { + return nil, errors.WithMessage(err, "Failed to unmarshal json object into struct") } + return structObj, nil } diff --git a/utils/marshal_utils_test.go b/utils/marshal_utils_test.go index a5f1512..4ac0fc1 100644 --- a/utils/marshal_utils_test.go +++ b/utils/marshal_utils_test.go @@ -10,6 +10,10 @@ import ( "github.com/golang/protobuf/ptypes/struct" ) +type SimpleType struct { + StringValue string `json:"string_value,omitempty"` +} + // Simple proto type TestProto struct { StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` @@ -70,9 +74,13 @@ func TestMarshalObjToStruct(t *testing.T) { want *structpb.Struct wantErr bool }{ - {"has value", args{input: &TestProto{StringValue: "hello"}}, &structpb.Struct{Fields: map[string]*structpb.Value{ + {"has proto value", args{input: &TestProto{StringValue: "hello"}}, &structpb.Struct{Fields: map[string]*structpb.Value{ + "stringValue": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, + }}, false}, + {"has struct value", args{input: SimpleType{StringValue: "hello"}}, &structpb.Struct{Fields: map[string]*structpb.Value{ "string_value": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, }}, false}, + {"has string value", args{input: "hello"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -118,8 +126,7 @@ func TestUnmarshalStructToPb(t *testing.T) { func TestMarshalPbToStruct(t *testing.T) { type args struct { - in proto.Message - out *structpb.Struct + in proto.Message } tests := []struct { name string @@ -127,17 +134,17 @@ func TestMarshalPbToStruct(t *testing.T) { expected *structpb.Struct wantErr bool }{ - {"empty", args{in: &TestProto{}, out: &structpb.Struct{}}, &structpb.Struct{Fields: map[string]*structpb.Value{}}, false}, - {"has value", args{in: &TestProto{StringValue: "hello"}, out: &structpb.Struct{}}, &structpb.Struct{Fields: map[string]*structpb.Value{ + {"empty", args{in: &TestProto{}}, &structpb.Struct{Fields: map[string]*structpb.Value{}}, false}, + {"has value", args{in: &TestProto{StringValue: "hello"}}, &structpb.Struct{Fields: map[string]*structpb.Value{ "stringValue": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, }}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := MarshalPbToStruct(tt.args.in, tt.args.out); (err != nil) != tt.wantErr { + if got, err := MarshalPbToStruct(tt.args.in); (err != nil) != tt.wantErr { t.Errorf("MarshalPbToStruct() error = %v, wantErr %v", err, tt.wantErr) } else { - assert.Equal(t, tt.expected.Fields, tt.args.out.Fields) + assert.Equal(t, tt.expected.Fields, got.Fields) } }) } From 1573ed41bd776ea8f2a0b39082435d17eb64ac50 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 12:58:47 -0700 Subject: [PATCH 09/13] minor --- cli/pflags/api/generator.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cli/pflags/api/generator.go b/cli/pflags/api/generator.go index c6ea06f..2b7faa4 100644 --- a/cli/pflags/api/generator.go +++ b/cli/pflags/api/generator.go @@ -139,11 +139,10 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue tag.Name = v.Name() } - isPtr := false typ := v.Type() - if ptr, casted := typ.(*types.Pointer); casted { + ptr, isPtr := typ.(*types.Pointer) + if isPtr { typ = ptr.Elem() - isPtr = true } switch t := typ.(type) { From 19afc26a59a207f769de2a5ca8fc1fbeb638e126 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 13:33:26 -0700 Subject: [PATCH 10/13] lint --- utils/marshal_utils.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/marshal_utils.go b/utils/marshal_utils.go index 4d9cc14..129de4f 100644 --- a/utils/marshal_utils.go +++ b/utils/marshal_utils.go @@ -4,9 +4,10 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes/struct" + structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" ) From 8f052f040605dbaa7f851a3576968bf94cf0ef29 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 14:11:48 -0700 Subject: [PATCH 11/13] solidify unit tests --- cli/pflags/api/generator.go | 10 ++++++++-- config/tests/accessor_test.go | 21 +++++++++++++++++++++ utils/auto_refresh_cache_test.go | 16 +++++++++++++--- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/cli/pflags/api/generator.go b/cli/pflags/api/generator.go index 2b7faa4..2f41faa 100644 --- a/cli/pflags/api/generator.go +++ b/cli/pflags/api/generator.go @@ -101,10 +101,16 @@ func appendAccessors(accessors ...string) string { for _, s := range accessors { if len(s) > 0 { if sb.Len() > 0 { - sb.WriteString(".") + if _, err := sb.WriteString("."); err != nil { + fmt.Printf("Failed to writeString, error: %v", err) + return "" + } } - sb.WriteString(s) + if _, err := sb.WriteString(s); err != nil { + fmt.Printf("Failed to writeString, error: %v", err) + return "" + } } } diff --git a/config/tests/accessor_test.go b/config/tests/accessor_test.go index d8c9e2a..e6f9309 100644 --- a/config/tests/accessor_test.go +++ b/config/tests/accessor_test.go @@ -379,6 +379,27 @@ func TestAccessor_UpdateConfig(t *testing.T) { // Make sure values have changed assert.NotEqual(t, firstValue, secondValue) }) + + t.Run(fmt.Sprintf("[%v] Default variables", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{ + StringValue: "default value 1", + StringValue2: "default value 2", + }) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + key := strings.ToUpper("my-component.str") + assert.NoError(t, os.Setenv(key, "Set From Env")) + defer func() { assert.NoError(t, os.Unsetenv(key)) }() + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Set From Env", r.StringValue) + assert.Equal(t, "default value 2", r.StringValue2) + }) } } diff --git a/utils/auto_refresh_cache_test.go b/utils/auto_refresh_cache_test.go index 85a09ce..05d80ed 100644 --- a/utils/auto_refresh_cache_test.go +++ b/utils/auto_refresh_cache_test.go @@ -2,6 +2,7 @@ package utils import ( "context" + "sync" "testing" "time" @@ -15,6 +16,7 @@ type testCacheItem struct { val int deleted atomic.Bool resyncPeriod time.Duration + wg sync.WaitGroup } func (m *testCacheItem) ID() string { @@ -28,6 +30,8 @@ func (m *testCacheItem) moveNext() { } func (m *testCacheItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) { + defer func() { m.wg.Done() }() + if m.deleted.Load() { return nil, nil } @@ -51,10 +55,15 @@ func TestCache(t *testing.T) { testResyncPeriod := time.Millisecond rateLimiter := NewRateLimiter("mockLimiter", 100, 1) - item := &testCacheItem{val: 0, resyncPeriod: testResyncPeriod, deleted: atomic.NewBool(false)} + wg := sync.WaitGroup{} + wg.Add(1) + item := &testCacheItem{ + val: 0, + resyncPeriod: testResyncPeriod, + deleted: atomic.NewBool(false), + wg: wg,} cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) - //ctx := context.Background() ctx, cancel := context.WithCancel(context.Background()) cache.Start(ctx) @@ -75,7 +84,8 @@ func TestCache(t *testing.T) { // removed? item.moveNext() item.deleted.Store(true) - time.Sleep(testResyncPeriod * 10) // spare enough time to process remove! + wg.Wait() + time.Sleep(testResyncPeriod * 2) // spare enough time to process remove! val := cache.Get(item.ID()) assert.Nil(t, val) From 99ac62004e7da823f74ae48d05466ee8dcb798f1 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 14:34:14 -0700 Subject: [PATCH 12/13] Refactor test --- utils/auto_refresh_cache_test.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/utils/auto_refresh_cache_test.go b/utils/auto_refresh_cache_test.go index 05d80ed..b7007a7 100644 --- a/utils/auto_refresh_cache_test.go +++ b/utils/auto_refresh_cache_test.go @@ -2,7 +2,6 @@ package utils import ( "context" - "sync" "testing" "time" @@ -16,7 +15,7 @@ type testCacheItem struct { val int deleted atomic.Bool resyncPeriod time.Duration - wg sync.WaitGroup + synced atomic.Int32 } func (m *testCacheItem) ID() string { @@ -30,11 +29,12 @@ func (m *testCacheItem) moveNext() { } func (m *testCacheItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) { - defer func() { m.wg.Done() }() + defer func() { m.synced.Inc() }() if m.deleted.Load() { return nil, nil } + return m, nil } @@ -55,13 +55,11 @@ func TestCache(t *testing.T) { testResyncPeriod := time.Millisecond rateLimiter := NewRateLimiter("mockLimiter", 100, 1) - wg := sync.WaitGroup{} - wg.Add(1) item := &testCacheItem{ val: 0, resyncPeriod: testResyncPeriod, deleted: atomic.NewBool(false), - wg: wg,} + synced: atomic.NewInt32(0),} cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) ctx, cancel := context.WithCancel(context.Background()) @@ -83,9 +81,14 @@ func TestCache(t *testing.T) { // removed? item.moveNext() + currentSyncCount := item.synced.Load() item.deleted.Store(true) - wg.Wait() - time.Sleep(testResyncPeriod * 2) // spare enough time to process remove! + for currentSyncCount == item.synced.Load() { + time.Sleep(testResyncPeriod * 5) // spare enough time to process remove! + } + + time.Sleep(testResyncPeriod * 10) // spare enough time to process remove! + val := cache.Get(item.ID()) assert.Nil(t, val) From 55b56f3b0028515157b8bd0793802dda7549beba Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 14:52:49 -0700 Subject: [PATCH 13/13] lint --- utils/auto_refresh_cache_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/auto_refresh_cache_test.go b/utils/auto_refresh_cache_test.go index b7007a7..dab8c9d 100644 --- a/utils/auto_refresh_cache_test.go +++ b/utils/auto_refresh_cache_test.go @@ -59,7 +59,7 @@ func TestCache(t *testing.T) { val: 0, resyncPeriod: testResyncPeriod, deleted: atomic.NewBool(false), - synced: atomic.NewInt32(0),} + synced: atomic.NewInt32(0)} cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) ctx, cancel := context.WithCancel(context.Background())