Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge pull request #6 from lyft/config-default-values
Browse files Browse the repository at this point in the history
Add option to PFlags to use a variable to get default values
  • Loading branch information
EngHabu authored Apr 18, 2019
2 parents 3c7c8a8 + 55b56f3 commit 6541fbe
Show file tree
Hide file tree
Showing 19 changed files with 630 additions and 230 deletions.
101 changes: 88 additions & 13 deletions cli/pflags/api/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"go/types"
"path/filepath"
"strings"

"github.com/lyft/flytestdlib/logger"

Expand All @@ -19,8 +20,9 @@ const (

// PFlagProviderGenerator parses and generates GetPFlagSet implementation to add PFlags for a given struct's fields.
type PFlagProviderGenerator struct {
pkg *types.Package
st *types.Named
pkg *types.Package
st *types.Named
defaultVar *types.Var
}

// This list is restricted because that's the only kinds viper parses out, otherwise it assumes strings.
Expand Down Expand Up @@ -54,7 +56,7 @@ func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage
emptyDefaultValue := `[]string{}`
if b, ok := t.Elem().(*types.Basic); !ok {
logger.Infof(ctx, "Elem of type [%v] is not a basic type. It must be json unmarshalable or generation will fail.", t.Elem())
if !jsonUnmarshaler(t.Elem()) {
if !isJSONUnmarshaler(t.Elem()) {
return FieldInfo{},
fmt.Errorf("slice of type [%v] is not supported. Only basic slices or slices of json-unmarshalable types are supported",
t.Elem().String())
Expand Down Expand Up @@ -85,9 +87,41 @@ func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage
}, nil
}

// Appends field accessors using "." as the delimiter.
// e.g. appendAccessors("var1", "field1", "subField") will output "var1.field1.subField"
func appendAccessors(accessors ...string) string {
sb := strings.Builder{}
switch len(accessors) {
case 0:
return ""
case 1:
return accessors[0]
}

for _, s := range accessors {
if len(s) > 0 {
if sb.Len() > 0 {
if _, err := sb.WriteString("."); err != nil {
fmt.Printf("Failed to writeString, error: %v", err)
return ""
}
}

if _, err := sb.WriteString(s); err != nil {
fmt.Printf("Failed to writeString, error: %v", err)
return ""
}
}
}

return sb.String()
}

// Traverses fields in type and follows recursion tree to discover all fields. It stops when one of two conditions is
// met; encountered a basic type (e.g. string, int... etc.) or the field type implements UnmarshalJSON.
func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo, error) {
// If passed a non-empty defaultValueAccessor, it'll be used to fill in default values instead of any default value
// specified in pflag tag.
func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor, fieldPath string) ([]FieldInfo, error) {
logger.Printf(ctx, "Finding all fields in [%v.%v.%v]",
typ.Obj().Pkg().Path(), typ.Obj().Pkg().Name(), typ.Obj().Name())

Expand All @@ -112,7 +146,8 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo
}

typ := v.Type()
if ptr, isPtr := typ.(*types.Pointer); isPtr {
ptr, isPtr := typ.(*types.Pointer)
if isPtr {
typ = ptr.Elem()
}

Expand All @@ -137,12 +172,21 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo
t.String(), t.Kind(), allowedKinds)
}

defaultValue := tag.DefaultValue
if len(defaultValueAccessor) > 0 {
defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name())

if isPtr {
defaultValue = fmt.Sprintf("cfg.elemValueOrNil(%s).(%s)", defaultValue, t.Name())
}
}

fields = append(fields, FieldInfo{
Name: tag.Name,
GoName: v.Name(),
Typ: t,
FlagMethodName: camelCase(t.String()),
DefaultValue: tag.DefaultValue,
DefaultValue: defaultValue,
UsageString: tag.Usage,
TestValue: `"1"`,
TestStrategy: JSON,
Expand All @@ -155,14 +199,26 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo

// If the type has json unmarshaler, then stop the recursion and assume the type is string. config package
// will use json unmarshaler to fill in the final config object.
jsonUnmarshaler := jsonUnmarshaler(t)
jsonUnmarshaler := isJSONUnmarshaler(t)

testValue := tag.DefaultValue
if len(tag.DefaultValue) == 0 {
tag.DefaultValue = `""`
testValue = `"1"`
}

defaultValue := tag.DefaultValue
if len(defaultValueAccessor) > 0 {
defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name())
if isStringer(t) {
defaultValue = defaultValue + ".String()"
} else {
logger.Infof(ctx, "Field [%v] of type [%v] does not implement Stringer interface."+
" Will use fmt.Sprintf() to get its default value.", v.Name(), t.String())
defaultValue = fmt.Sprintf("fmt.Sprintf(\"%%v\",%s)", defaultValue)
}
}

logger.Infof(ctx, "[%v] is of a Named type (struct) with default value [%v].", tag.Name, tag.DefaultValue)

if jsonUnmarshaler {
Expand All @@ -173,15 +229,15 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo
GoName: v.Name(),
Typ: types.Typ[types.String],
FlagMethodName: "String",
DefaultValue: tag.DefaultValue,
DefaultValue: defaultValue,
UsageString: tag.Usage,
TestValue: testValue,
TestStrategy: JSON,
})
} else {
logger.Infof(ctx, "Traversing fields in type.")

nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t)
nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, defaultValueAccessor, appendAccessors(fieldPath, v.Name()))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -228,7 +284,8 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo
// NewGenerator initializes a PFlagProviderGenerator for pflags files for targetTypeName struct under pkg. If pkg is not filled in,
// it's assumed to be current package (which is expected to be the common use case when invoking pflags from
// go:generate comments)
func NewGenerator(pkg, targetTypeName string) (*PFlagProviderGenerator, error) {
func NewGenerator(pkg, targetTypeName, defaultVariableName string) (*PFlagProviderGenerator, error) {
ctx := context.Background()
var err error
// Resolve package path
if pkg == "" || pkg[0] == '.' {
Expand Down Expand Up @@ -257,9 +314,22 @@ func NewGenerator(pkg, targetTypeName string) (*PFlagProviderGenerator, error) {
return nil, fmt.Errorf("%s should be an struct, was %s", targetTypeName, obj.Type().Underlying())
}

var defaultVar *types.Var
obj = targetPackage.Scope().Lookup(defaultVariableName)
if obj != nil {
defaultVar = obj.(*types.Var)
}

if defaultVar != nil {
logger.Infof(ctx, "Using default variable with name [%v] to assign all default values.", defaultVariableName)
} else {
logger.Infof(ctx, "Using default values defined in tags if any.")
}

return &PFlagProviderGenerator{
st: st,
pkg: targetPackage,
st: st,
pkg: targetPackage,
defaultVar: defaultVar,
}, nil
}

Expand All @@ -268,7 +338,12 @@ func (g PFlagProviderGenerator) GetTargetPackage() *types.Package {
}

func (g PFlagProviderGenerator) Generate(ctx context.Context) (PFlagProvider, error) {
fields, err := discoverFieldsRecursive(ctx, g.st)
defaultValueAccessor := ""
if g.defaultVar != nil {
defaultValueAccessor = g.defaultVar.Name()
}

fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor, "")
if err != nil {
return PFlagProvider{}, err
}
Expand Down
32 changes: 31 additions & 1 deletion cli/pflags/api/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -14,8 +15,37 @@ import (
// Make sure existing config file(s) parse correctly before overriding them with this flag!
var update = flag.Bool("update", false, "Updates testdata")

// If v is a pointer, it will get its element value or the zero value of the element type.
// If v is not a pointer, it will return it as is.
func elemValueOrNil(v interface{}) interface{} {
if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr {
if reflect.ValueOf(v).IsNil() {
return reflect.Zero(t.Elem()).Interface()
}

return reflect.ValueOf(v).Interface()
} else if v == nil {
return reflect.Zero(t).Interface()
}

return v
}

func TestElemValueOrNil(t *testing.T) {
var iPtr *int
assert.Equal(t, 0, elemValueOrNil(iPtr))
var sPtr *string
assert.Equal(t, "", elemValueOrNil(sPtr))
var i int
assert.Equal(t, 0, elemValueOrNil(i))
var s string
assert.Equal(t, "", elemValueOrNil(s))
var arr []string
assert.Equal(t, arr, elemValueOrNil(arr))
}

func TestNewGenerator(t *testing.T) {
g, err := NewGenerator(".", "TestType")
g, err := NewGenerator(".", "TestType", "DefaultTestType")
assert.NoError(t, err)

ctx := context.Background()
Expand Down
4 changes: 4 additions & 0 deletions cli/pflags/api/sample.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import (
"github.com/lyft/flytestdlib/storage"
)

var DefaultTestType = &TestType{
StringValue: "Welcome to defaults",
}

type TestType struct {
StringValue string `json:"str" pflag:"\"hello world\",\"life is short\""`
BoolValue bool `json:"bl" pflag:"true"`
Expand Down
18 changes: 17 additions & 1 deletion cli/pflags/api/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,25 @@ import (
{{$name}} "{{$path}}"{{end}}
)
// If v is a pointer, it will get its element value or the zero value of the element type.
// If v is not a pointer, it will return it as is.
func ({{ .Name }}) elemValueOrNil(v interface{}) interface{} {
if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr {
if reflect.ValueOf(v).IsNil() {
return reflect.Zero(t.Elem()).Interface()
} else {
return reflect.ValueOf(v).Interface()
}
} else if v == nil {
return reflect.Zero(t).Interface()
}
return v
}
// GetPFlagSet will return strongly types pflags for all fields in {{ .Name }} and its nested types. The format of the
// flags is json-name.json-sub-name... etc.
func ({{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet {
func (cfg {{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet {
cmdFlags := pflag.NewFlagSet("{{ .Name }}", pflag.ExitOnError)
{{- range .Fields }}
cmdFlags.{{ .FlagMethodName }}(fmt.Sprintf("%v%v", prefix, "{{ .Name }}"), {{ .DefaultValue }}, {{ .UsageString }})
Expand Down
51 changes: 34 additions & 17 deletions cli/pflags/api/testdata/testtype.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 6541fbe

Please sign in to comment.