Skip to content

Commit

Permalink
Add FlagValue interface to support other flag systems.
Browse files Browse the repository at this point in the history
Using an interface allows people to use their favourite flag system
with viper without being restricted to the semantics of pflag or the
standard library.

This change introduce two new functions `BindFlagValues` and
`BindFlagValue` that behave like `BindFlags` and `BindFlag` but using
the new interface as values.

This change also introduces two internal structures to transform
`*pflag.FlagSet` and `*pflag.Flag` into the new interface. This way,
viper keeps working as expected for people that are currently using the
pflag package without breaking backwards compatibility.

Signed-off-by: David Calavera <[email protected]>
  • Loading branch information
calavera authored and spf13 committed Dec 24, 2015
1 parent 105e3d0 commit 66249a6
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 18 deletions.
57 changes: 57 additions & 0 deletions flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package viper

import "github.com/spf13/pflag"

// FlagValueSet is an interface that users can implement
// to bind a set of flags to viper.
type FlagValueSet interface {
VisitAll(fn func(FlagValue))
}

// FlagValue is an interface that users can implement
// to bind different flags to viper.
type FlagValue interface {
HasChanged() bool
Name() string
ValueString() string
ValueType() string
}

// pflagValueSet is a wrapper around *pflag.ValueSet
// that implements FlagValueSet.
type pflagValueSet struct {
flags *pflag.FlagSet
}

// VisitAll iterates over all *pflag.Flag inside the *pflag.FlagSet.
func (p pflagValueSet) VisitAll(fn func(flag FlagValue)) {
p.flags.VisitAll(func(flag *pflag.Flag) {
fn(pflagValue{flag})
})
}

// pflagValue is a wrapper aroung *pflag.flag
// that implements FlagValue
type pflagValue struct {
flag *pflag.Flag
}

// HasChanges returns whether the flag has changes or not.
func (p pflagValue) HasChanged() bool {
return p.flag.Changed
}

// Name returns the name of the flag.
func (p pflagValue) Name() string {
return p.flag.Name
}

// ValueString returns the value of the flag as a string.
func (p pflagValue) ValueString() string {
return p.flag.Value.String()
}

// ValueType returns the type of the flag as a string.
func (p pflagValue) ValueType() string {
return p.flag.Value.Type()
}
66 changes: 66 additions & 0 deletions flags_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package viper

import (
"testing"

"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
)

func TestBindFlagValueSet(t *testing.T) {
flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)

var testValues = map[string]*string{
"host": nil,
"port": nil,
"endpoint": nil,
}

var mutatedTestValues = map[string]string{
"host": "localhost",
"port": "6060",
"endpoint": "/public",
}

for name, _ := range testValues {
testValues[name] = flagSet.String(name, "", "test")
}

flagValueSet := pflagValueSet{flagSet}

err := BindFlagValues(flagValueSet)
if err != nil {
t.Fatalf("error binding flag set, %v", err)
}

flagSet.VisitAll(func(flag *pflag.Flag) {
flag.Value.Set(mutatedTestValues[flag.Name])
flag.Changed = true
})

for name, expected := range mutatedTestValues {
assert.Equal(t, Get(name), expected)
}
}

func TestBindFlagValue(t *testing.T) {
var testString = "testing"
var testValue = newStringValue(testString, &testString)

flag := &pflag.Flag{
Name: "testflag",
Value: testValue,
Changed: false,
}

flagValue := pflagValue{flag}
BindFlagValue("testvalue", flagValue)

assert.Equal(t, testString, Get("testvalue"))

flag.Value.Set("testing_mutate")
flag.Changed = true //hack for pflag usage

assert.Equal(t, "testing_mutate", Get("testvalue"))

}
54 changes: 36 additions & 18 deletions viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ type Viper struct {
override map[string]interface{}
defaults map[string]interface{}
kvstore map[string]interface{}
pflags map[string]*pflag.Flag
pflags map[string]FlagValue
env map[string]string
aliases map[string]string
typeByDefValue bool
Expand All @@ -166,7 +166,7 @@ func New() *Viper {
v.override = make(map[string]interface{})
v.defaults = make(map[string]interface{})
v.kvstore = make(map[string]interface{})
v.pflags = make(map[string]*pflag.Flag)
v.pflags = make(map[string]FlagValue)
v.env = make(map[string]string)
v.aliases = make(map[string]string)
v.typeByDefValue = false
Expand Down Expand Up @@ -467,13 +467,13 @@ func (v *Viper) Get(key string) interface{} {
if val == nil {
if flag, exists := v.pflags[lcaseKey]; exists {
jww.TRACE.Println(key, "get pflag default", val)
switch flag.Value.Type() {
switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64":
val = cast.ToInt(flag.Value.String())
val = cast.ToInt(flag.ValueString())
case "bool":
val = cast.ToBool(flag.Value.String())
val = cast.ToBool(flag.ValueString())
default:
val = flag.Value.String()
val = flag.ValueString()
}
}
}
Expand Down Expand Up @@ -618,22 +618,40 @@ func (v *Viper) Unmarshal(rawVal interface{}) error {
// name as the config key.
func BindPFlags(flags *pflag.FlagSet) (err error) { return v.BindPFlags(flags) }
func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) {
flags.VisitAll(func(flag *pflag.Flag) {
if err = v.BindPFlag(flag.Name, flag); err != nil {
return v.BindFlagValues(pflagValueSet{flags})
}

// Bind a specific key to a pflag (as used by cobra)
// Example(where serverCmd is a Cobra instance):
//
// serverCmd.Flags().Int("port", 1138, "Port to run Application server on")
// Viper.BindPFlag("port", serverCmd.Flags().Lookup("port"))
//
func BindPFlag(key string, flag *pflag.Flag) (err error) { return v.BindPFlag(key, flag) }
func (v *Viper) BindPFlag(key string, flag *pflag.Flag) (err error) {
return v.BindFlagValue(key, pflagValue{flag})
}

// Bind a full FlagValue set to the configuration, using each flag's long
// name as the config key.
func BindFlagValues(flags FlagValueSet) (err error) { return v.BindFlagValues(flags) }
func (v *Viper) BindFlagValues(flags FlagValueSet) (err error) {
flags.VisitAll(func(flag FlagValue) {
if err = v.BindFlagValue(flag.Name(), flag); err != nil {
return
}
})
return nil
}

// Bind a specific key to a flag (as used by cobra)
// Bind a specific key to a FlagValue.
// Example(where serverCmd is a Cobra instance):
//
// serverCmd.Flags().Int("port", 1138, "Port to run Application server on")
// Viper.BindPFlag("port", serverCmd.Flags().Lookup("port"))
// Viper.BindFlagValue("port", serverCmd.Flags().Lookup("port"))
//
func BindPFlag(key string, flag *pflag.Flag) (err error) { return v.BindPFlag(key, flag) }
func (v *Viper) BindPFlag(key string, flag *pflag.Flag) (err error) {
func BindFlagValue(key string, flag FlagValue) (err error) { return v.BindFlagValue(key, flag) }
func (v *Viper) BindFlagValue(key string, flag FlagValue) (err error) {
if flag == nil {
return fmt.Errorf("flag for %q is nil", key)
}
Expand Down Expand Up @@ -678,15 +696,15 @@ func (v *Viper) find(key string) interface{} {

// PFlag Override first
flag, exists := v.pflags[key]
if exists && flag.Changed {
jww.TRACE.Println(key, "found in override (via pflag):", flag.Value)
switch flag.Value.Type() {
if exists && flag.HasChanged() {
jww.TRACE.Println(key, "found in override (via pflag):", flag.ValueString())
switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64":
return cast.ToInt(flag.Value.String())
return cast.ToInt(flag.ValueString())
case "bool":
return cast.ToBool(flag.Value.String())
return cast.ToBool(flag.ValueString())
default:
return flag.Value.String()
return flag.ValueString()
}
}

Expand Down

0 comments on commit 66249a6

Please sign in to comment.