Skip to content

Commit

Permalink
Refactor into smaller files
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamSLevy committed Mar 2, 2020
1 parent 389966e commit 44ea29f
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 134 deletions.
92 changes: 13 additions & 79 deletions build.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,14 @@ package flagbuilder

import (
"flag"
"fmt"
"reflect"
"strconv"
"strings"
"time"
"unicode"
)

var ErrorInvalidType = fmt.Errorf("v must be a pointer to a struct")

type ErrorDefaultValue struct {
FieldName string
Value string
Err error
}

func (err ErrorDefaultValue) Error() string {
return fmt.Sprintf("%v: cannot assign default value %q: %v",
err.FieldName, err.Value, err.Err)
}
func (err ErrorDefaultValue) Unwrap() error {
return err.Err
}
"github.com/spf13/pflag"
)

func Build(flg FlagSet, v interface{}) error {
func Bind(flg FlagSet, v interface{}) error {
ptr := reflect.ValueOf(v)
if ptr.Kind() != reflect.Ptr {
return ErrorInvalidType
Expand Down Expand Up @@ -146,67 +129,18 @@ func Build(flg FlagSet, v interface{}) error {
fieldT.Name, tag.Value, err}
}
}
flg.Var(p, tag.Name, tag.Usage)
switch flg := flg.(type) {
case STDFlagSet:
flg.Var(p, tag.Name, tag.Usage)
case PFlagSet:
pp, ok := p.(pflag.Value)
if !ok {
pp = pflagValue{p, fieldT.Type.Name()}
}
flg.Var(pp, tag.Name, tag.Usage)
}
}
}

return nil
}

type FlagSet interface {
BoolVar(p *bool, name string, value bool, usage string)
DurationVar(p *time.Duration, name string, value time.Duration, usage string)
Float64Var(p *float64, name string, value float64, usage string)
Int64Var(p *int64, name string, value int64, usage string)
IntVar(p *int, name string, value int, usage string)
StringVar(p *string, name string, value string, usage string)
Uint64Var(p *uint64, name string, value uint64, usage string)
UintVar(p *uint, name string, value uint, usage string)
Var(value flag.Value, name string, usage string)
}

type _ struct {
X int `flag:"-X;"`
}
type flagTag struct {
Name string
Value string
Usage string

Ignored bool
}

func newFlagTag(tag string) (fTag flagTag) {
if len(tag) == 0 {
return
}
args := strings.Split(tag, ";")
fTag.Ignored = args[0] == "-" // Ignore this field
if fTag.Ignored {
return
}
fTag.Name = strings.TrimLeft(args[0], "-")
if len(args) == 1 {
return
}
fTag.Value = args[1]
if len(args) == 2 {
return
}
fTag.Usage = args[2]
return
}

func kebabCase(name string) string {
var kebab string
for _, r := range name {
if unicode.IsUpper(r) {
if len(kebab) > 0 {
kebab += "-"
}
r = unicode.ToLower(r)
}
kebab += string(r)
}
return kebab
}
127 changes: 73 additions & 54 deletions build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,88 @@ import (
"bytes"
"flag"
"fmt"
"strings"
"io"
"testing"
"time"

"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type FlagTest struct {
// BindTest stores all data for a test of Bind.
type BindTest struct {
Name string
F interface{}
ErrBuild string
Usage string
Args []string
ExpF interface{}
ErrParse string
UsePFlag bool
// This is the *struct{} to bind flags to.
F interface{}
ErrBind string
Usage string
Args []string
ExpF interface{}
ErrParse string
ErrPFlagParse string
}

func (test *FlagTest) Run(t *testing.T) {
func (test *BindTest) Run(t *testing.T) {
t.Run(test.Name, test.test)
test.UsePFlag = true
t.Run(test.Name+" pflag", test.test)
}

func (test *FlagTest) test(t *testing.T) {
func (test *BindTest) test(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
flg := flag.NewFlagSet("", flag.ContinueOnError)
var flg interface {
FlagSet
SetOutput(io.Writer)
Usage()
Parse([]string) error
}
args := test.Args
if test.UsePFlag {
flg = pflagSetUsage{pflag.NewFlagSet("", pflag.ContinueOnError)}
args = append([]string{}, args...)
for i, arg := range args {
if arg[0:1] != "-" ||
len(arg) == 2 {
continue
}
args[i] = "-" + arg
}
} else {
flg = flagSetUsage{flag.NewFlagSet("", flag.ContinueOnError)}
}
usageOutput := bytes.NewBuffer(nil)
flg.SetOutput(usageOutput)

err := Build(flg, test.F)
err := Bind(flg, test.F)

if test.ErrBuild != "" {
assert.EqualError(err, test.ErrBuild, "flagbuilder.Build()")
if test.ErrBind != "" {
assert.EqualError(err, test.ErrBind, "flagbuilder.Bind()")
return
}
require.NoError(err, "flagbuilder.Build()")
require.NoError(err, "flagbuilder.Bind()")

if test.Usage != "" {
flg.Usage()
assert.Contains(string(usageOutput.Bytes()), test.Usage,
"flag.FlagSet.Usage()")
}

err = flg.Parse(test.Args)
err = flg.Parse(args)

if test.ErrParse != "" {
assert.EqualError(err, test.ErrParse, "flag.FlagSet.Parse()")
return
if test.UsePFlag {
if test.ErrPFlagParse != "" {
assert.EqualError(err, test.ErrPFlagParse, "flag.FlagSet.Parse()")
return
}

} else {
if test.ErrParse != "" {
assert.EqualError(err, test.ErrParse, "flag.FlagSet.Parse()")
return
}
}
require.NoError(err, "flag.FlagSet.Parse()")

Expand Down Expand Up @@ -85,34 +119,23 @@ type ValidTestFlags struct {
ValueDefault TestValue `flag:";true;"`
}

type TestValue bool

func (v *TestValue) Set(text string) error {
switch strings.ToLower(text) {
case "true":
*v = true
case "false":
*v = false
default:
return fmt.Errorf("could not parse %q as TestValue", text)
func TestBind(t *testing.T) {
for _, test := range tests {
test.Run(t)
}
return nil
}
func (v TestValue) String() string {
return fmt.Sprint(bool(v))
}

var tests = []FlagTest{
var tests = []BindTest{
{
Name: "invalid type",
F: struct {
Bool bool
}{},
ErrBuild: ErrorInvalidType.Error(),
ErrBind: ErrorInvalidType.Error(),
}, {
Name: "invalid type",
F: new(int),
ErrBuild: ErrorInvalidType.Error(),
Name: "invalid type",
F: new(int),
ErrBind: ErrorInvalidType.Error(),
}, {
Name: "valid",
F: &ValidTestFlags{
Expand Down Expand Up @@ -168,7 +191,8 @@ var tests = []FlagTest{
PtrDefaultInherit: func() *bool { b := true; return &b }(),
PtrDefault: func() *bool { b := true; return &b }(),
},
ErrParse: "flag provided but not defined: -ignored",
ErrParse: "flag provided but not defined: -ignored",
ErrPFlagParse: "unknown flag: --ignored",
}, {
Name: "skip unexported",
F: &ValidTestFlags{},
Expand All @@ -181,75 +205,70 @@ var tests = []FlagTest{
PtrDefaultInherit: func() *bool { b := true; return &b }(),
PtrDefault: func() *bool { b := true; return &b }(),
},
ErrParse: "flag provided but not defined: -skip",
ErrParse: "flag provided but not defined: -skip",
ErrPFlagParse: "unknown flag: --skip",
}, {
Name: "invalid default Value",
F: &struct {
Value TestValue `flag:";asdf;"`
}{},
ErrBuild: ErrorDefaultValue{"Value", "asdf",
ErrBind: ErrorDefaultValue{"Value", "asdf",
fmt.Errorf(`could not parse "asdf" as TestValue`)}.Error(),
}, {
Name: "invalid default bool",
F: &struct {
Bool bool `flag:";asdf;"`
}{},
ErrBuild: ErrorDefaultValue{"Bool", "asdf",
ErrBind: ErrorDefaultValue{"Bool", "asdf",
fmt.Errorf(`strconv.ParseBool: parsing "asdf": invalid syntax`),
}.Error(),
}, {
Name: "invalid default int",
F: &struct {
Int int `flag:";asdf;"`
}{},
ErrBuild: ErrorDefaultValue{"Int", "asdf",
ErrBind: ErrorDefaultValue{"Int", "asdf",
fmt.Errorf(`strconv.ParseInt: parsing "asdf": invalid syntax`),
}.Error(),
}, {
Name: "invalid default uint",
F: &struct {
Uint uint `flag:";-1;"`
}{},
ErrBuild: ErrorDefaultValue{"Uint", "-1",
ErrBind: ErrorDefaultValue{"Uint", "-1",
fmt.Errorf(`strconv.ParseUint: parsing "-1": invalid syntax`),
}.Error(),
}, {
Name: "invalid default uint64",
F: &struct {
Uint64 uint64 `flag:";-1;"`
}{},
ErrBuild: ErrorDefaultValue{"Uint64", "-1",
ErrBind: ErrorDefaultValue{"Uint64", "-1",
fmt.Errorf(`strconv.ParseUint: parsing "-1": invalid syntax`),
}.Error(),
}, {
Name: "invalid default int64",
F: &struct {
Int64 int64 `flag:";asdf;"`
}{},
ErrBuild: ErrorDefaultValue{"Int64", "asdf",
ErrBind: ErrorDefaultValue{"Int64", "asdf",
fmt.Errorf(`strconv.ParseInt: parsing "asdf": invalid syntax`),
}.Error(),
}, {
Name: "invalid default float64",
F: &struct {
Float64 float64 `flag:";asdf;"`
}{},
ErrBuild: ErrorDefaultValue{"Float64", "asdf",
ErrBind: ErrorDefaultValue{"Float64", "asdf",
fmt.Errorf(`strconv.ParseFloat: parsing "asdf": invalid syntax`),
}.Error(),
}, {
Name: "invalid default time.Duration",
F: &struct {
Duration time.Duration `flag:";asdf;"`
}{},
ErrBuild: ErrorDefaultValue{"Duration", "asdf",
ErrBind: ErrorDefaultValue{"Duration", "asdf",
fmt.Errorf(`time: invalid duration asdf`),
}.Error(),
},
}

func TestBuild(t *testing.T) {
for _, test := range tests {
test.Run(t)
}
}
19 changes: 19 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package flagbuilder

import "fmt"

var ErrorInvalidType = fmt.Errorf("v must be a pointer to a struct")

type ErrorDefaultValue struct {
FieldName string
Value string
Err error
}

func (err ErrorDefaultValue) Error() string {
return fmt.Sprintf("%v: cannot assign default value %q: %v",
err.FieldName, err.Value, err.Err)
}
func (err ErrorDefaultValue) Unwrap() error {
return err.Err
}
Loading

0 comments on commit 44ea29f

Please sign in to comment.