diff --git a/configo.go b/configo.go index f66c2bd..aa6148b 100644 --- a/configo.go +++ b/configo.go @@ -5,6 +5,7 @@ import ( "github.com/shafreeck/toml" "github.com/shafreeck/toml/ast" + "flag" "fmt" goast "go/ast" "reflect" @@ -272,6 +273,9 @@ func Unmarshal(data []byte, v interface{}) error { if err := applyDefault(reflect.ValueOf(v), false); err != nil { return err } + + // apply flag param + ApplyFlags(flag.CommandLine, v) return nil } diff --git a/example/conf/example.toml b/example/conf/example.toml index 0bf05f5..b8b2057 100644 --- a/example/conf/example.toml +++ b/example/conf/example.toml @@ -10,10 +10,10 @@ #default: 10000 #max-connection = 10000 -[redis] +#[redis] #type: []string #rules: dialstring #description: The addresses of redis cluster #required -cluster = [] +#cluster = [] diff --git a/example/flags.go b/example/flags.go new file mode 100644 index 0000000..c3eed55 --- /dev/null +++ b/example/flags.go @@ -0,0 +1,29 @@ +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "log" + + "github.com/distributedio/configo" + "github.com/distributedio/configo/example/conf" +) + +func main() { + c := &conf.Config{} + + configo.AddFlags(flag.CommandLine, c, "listen", "redis", "redis.0.cluster") + flag.Parse() + + data, err := ioutil.ReadFile("conf/example.toml") + if err != nil { + log.Fatalln(err) + } + + err = configo.Unmarshal(data, c) + if err != nil { + log.Fatalln(err) + } + fmt.Printf("%#v\n", c) +} diff --git a/flags.go b/flags.go new file mode 100644 index 0000000..65c8044 --- /dev/null +++ b/flags.go @@ -0,0 +1,200 @@ +package configo + +import ( + "flag" + "log" + "reflect" + "strconv" + "strings" + "time" + + "github.com/shafreeck/toml" +) + +/* + `flags` 实现将对象中的变量添加到`flag`中,从而实现通过命令行设置变量的功能。 + + import ( + "log" + + "github.com/distributedio/configo" + ) + + type Config struct { + Key string `cfg:"key; default;; simple type example"` + Child *Child `cfg:"child; ;; class type "` + Array []string `cfg:"array;;; simple array type"` + CompArray []*Child `cfg:"comp;;; complex array type"` + } + + type Child struct { + Name string `cfg:"name; noname;; child class item` + } + + func main() { + conf := &Config{} + configo.AddFlags(conf) + flag.Parse() + + if err := configo.Load("conf/example.toml", conf); err != nil { + log.Fatalln(err) + } + } + + 首先,需要在`flag.Parse()`之前调用`AddFlags()`将对象中的变量添加到`falg`中。 + `configo.Load()`会在内部调用`ApplyFlags()`方法,将`flag`中设置的变量应用到 + 对象中。 + + 对象中的变量按照如下规则对应`flag`中的`key`: + + * 简单数据类型,直接使用`cfg`中的`name`作为`flag`中的`key`。 + 如`Conf.Key`,对应`flag`中的`key`。 + * 对象数据类型,需要添加上一层对象的名称。 + 如 `Conf.Child.Name` 对应`flag`中的`child.name` + * 数组或slice类型,要增加下标作为一个层级。 + 如 `Conf.CompArray[0].Name`,对应`flag`中的`comp.0.name` + * 对于简单数据类型的数组或slice也可以使用名称作为`flag`中的`key`, + 使用字符串表示一个数组。 + 例如:`Conf.Array`,对应`flag`中的`array`。同时在执行中,使用如下的 + 方式设置`array`: + ./cmd -array="[\"a1\", \"a2\"]" +*/ + +const ( + ConfigoFlagSuffix = "[configo]" +) + +// AddFlags 将对象中的变量加入到flag中,从而可以通过命令行设置对应的变量。 +// +// * `obj` 为待加入到`flag`中的对象的实例 +// * `keys` 限定加入`flag`中变量的范围,**不设置**的时候表示将所有变量都加入到`flag`中。 +func AddFlags(fs *flag.FlagSet, obj interface{}, keys ...string) { + flagMap := make(map[string]struct{}, len(keys)) + for i := range keys { + flagMap[keys[i]] = struct{}{} + } + t := NewTravel(func(path string, tag *toml.CfgTag, fv reflect.Value) { + if _, ok := flagMap[path]; len(flagMap) > 0 && !ok { + return + } + var err error + switch fv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, + reflect.Int32, reflect.Int64: + var v int64 + if v, err = strconv.ParseInt(tag.Value, 10, 64); err != nil { + if fv.Kind() == reflect.Int64 { + //try to parse a time.Duration + if d, err := time.ParseDuration(tag.Value); err == nil { + fs.Duration(path, time.Duration(d), ConfigoFlagSuffix+tag.Description) + return + } + } + log.Fatalln(err) + return + } + fs.Int64(path, v, ConfigoFlagSuffix+tag.Description) + case reflect.Uint, reflect.Uint8, reflect.Uint16, + reflect.Uint32, reflect.Uint64: + var v uint64 + if v, err = strconv.ParseUint(tag.Value, 10, 64); err != nil { + log.Fatalln(err) + return + } + fs.Uint64(path, v, ConfigoFlagSuffix+tag.Description) + case reflect.Float32, reflect.Float64: + var v float64 + if v, err = strconv.ParseFloat(tag.Value, 64); err != nil { + log.Fatalln(err) + return + } + fs.Float64(path, v, ConfigoFlagSuffix+tag.Description) + case reflect.Bool: + var v bool + if v, err = strconv.ParseBool(tag.Value); err != nil { + log.Fatalln(err) + return + } + fs.Bool(path, v, ConfigoFlagSuffix+tag.Description) + case reflect.String: + fs.String(path, tag.Value, ConfigoFlagSuffix+tag.Description) + case reflect.Slice, reflect.Array: + // TODO 使用flag.Var设置变量 + fs.String(path, tag.Value, ConfigoFlagSuffix+tag.Description) + default: + log.Printf("unknow type %s for set flag", fv.Type()) + } + }) + t.Travel(obj) +} + +// ApplyFlags 将命令行中设置的变量值应用到`obj`中。 +// +// **注意:** configo中的函数默认会调用这个函数设置配置文件,所以不需要显示调用。 +func ApplyFlags(fs *flag.FlagSet, obj interface{}) { + actualFlags := make(map[string]*flag.Flag) + fs.Visit(func(f *flag.Flag) { + if strings.Contains(f.Usage, ConfigoFlagSuffix) { + actualFlags[f.Name] = f + } + }) + t := NewTravel(func(path string, tag *toml.CfgTag, fv reflect.Value) { + f, ok := actualFlags[path] + if !ok { + return + } + var err error + switch fv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, + reflect.Int32, reflect.Int64: + var v int64 + if v, err = strconv.ParseInt(f.Value.String(), 10, 64); err != nil { + if fv.Kind() == reflect.Int64 { + //try to parse a time.Duration + if d, err := time.ParseDuration(f.Value.String()); err == nil { + fv.SetInt(int64(d)) + return + } + } + log.Fatalln(err) + return + } + fv.SetInt(v) + case reflect.Uint, reflect.Uint8, reflect.Uint16, + reflect.Uint32, reflect.Uint64: + var v uint64 + if v, err = strconv.ParseUint(f.Value.String(), 10, 64); err != nil { + log.Fatalln(err) + return + } + fv.SetUint(v) + case reflect.Float32, reflect.Float64: + var v float64 + if v, err = strconv.ParseFloat(f.Value.String(), 64); err != nil { + log.Fatalln(err) + return + } + fv.SetFloat(v) + case reflect.Bool: + var v bool + if v, err = strconv.ParseBool(f.Value.String()); err != nil { + log.Fatalln(err) + return + } + fv.SetBool(v) + case reflect.String: + fv.SetString(f.Value.String()) + case reflect.Slice, reflect.Array: + // TODO NOT support + // if err := unmarshalArray("name", f.Value.String(), &s); err != nil { + // log.Fatalln(err) + // return + // } + // fv.Set(reflect.ValueOf(s.Name)) + // log.Printf("get list =%#v\n", s) + default: + log.Printf("unknow type %s for set flag", fv.Type()) + } + }) + t.Travel(obj) +} diff --git a/go.mod b/go.mod index 075fb0f..5b12bd6 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.12 require ( github.com/shafreeck/toml v0.0.0-20190326060449-44ad86712acc + github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a // indirect golang.org/x/net v0.0.0-20190320064053-1272bf9dcd53 // indirect golang.org/x/sys v0.0.0-20190318195719-6c81ef8f67ca // indirect diff --git a/travel.go b/travel.go new file mode 100644 index 0000000..c5045d2 --- /dev/null +++ b/travel.go @@ -0,0 +1,74 @@ +package configo + +import ( + "fmt" + "reflect" + + "github.com/shafreeck/toml" +) + +type TravelHandle func(path string, tag *toml.CfgTag, v reflect.Value) + +type Travel struct { + handle TravelHandle +} + +func NewTravel(h TravelHandle) *Travel { + return &Travel{handle: h} +} + +func (t *Travel) Travel(obj interface{}) { + t.travel("", nil, reflect.ValueOf(obj)) +} + +func (t *Travel) travel(path string, tag *toml.CfgTag, v reflect.Value) { + switch v.Kind() { + case reflect.Ptr: + vValue := v.Elem() + if !vValue.IsValid() { + return + } + t.travel(path, tag, vValue) + case reflect.Interface: + vValue := v.Elem() + if !vValue.IsValid() { + return + } + t.travel(path, tag, vValue) + case reflect.Struct: + for i := 0; i < v.NumField(); i += 1 { + if !v.Field(i).IsValid() { + continue + } + tag := extractTag(v.Type().Field(i).Tag.Get(fieldTagName)) + p := tag.Name + if len(path) > 0 { + p = path + "." + tag.Name + } + t.travel(p, tag, v.Field(i)) + } + case reflect.Slice, reflect.Array: + // handle slice & array as a whole + t.handle(path, tag, v) + for i := 0; i < v.Len(); i++ { + p := fmt.Sprintf("%d", i) + if len(path) > 0 { + p = path + "." + p + } + // handle every element + t.travel(p, tag, v.Index(i)) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fallthrough + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + fallthrough + case reflect.Float32, reflect.Float64: + fallthrough + case reflect.Bool: + fallthrough + case reflect.String: + t.handle(path, tag, v) + default: + panic(fmt.Sprintf("config file use unsupport type. %v", v.Type())) + } +} diff --git a/travel_test.go b/travel_test.go new file mode 100644 index 0000000..17c6abe --- /dev/null +++ b/travel_test.go @@ -0,0 +1,149 @@ +package configo + +import ( + "reflect" + "testing" + "time" + + "github.com/shafreeck/toml" + "github.com/stretchr/testify/require" +) + +func TestTravel_Travel(t *testing.T) { + tests := []struct { + name string + obj interface{} + want map[string]interface{} + }{ + { + name: "test string", + obj: struct { + name string `cfg:"name;;;user name"` + }{ + name: "name-value", + }, + want: map[string]interface{}{ + "name": "name-value", + }, + }, + { + name: "test int,float,bool", + obj: struct { + name string `cfg:"name;;;user name"` + age int `cfg:"age;;;user age"` + score float64 `cfg:"score;;;user score"` + sex bool `cfg:"sex;;;user sex"` + }{ + name: "user-name", + age: 18, + score: 60.1, + sex: true, + }, + want: map[string]interface{}{ + "name": "user-name", + "age": 18, + "score": 60.1, + "sex": true, + }, + }, + { + name: "test slice", + obj: struct { + Cluster []string `cfg:"cluster;;;the address of redis cluster"` + }{ + Cluster: []string{"127.0.0.1:6379", "127.0.0.1:7379"}, + }, + want: map[string]interface{}{ + "cluster": []string{"127.0.0.1:6379", "127.0.0.1:7379"}, + "cluster.0": "127.0.0.1:6379", + "cluster.1": "127.0.0.1:7379", + }, + }, + { + name: "test array", + obj: struct { + Cluster [2]string `cfg:"cluster;;;the address of redis cluster"` + }{ + Cluster: [2]string{"127.0.0.1:6379", "127.0.0.1:7379"}, + }, + want: map[string]interface{}{ + "cluster": [2]string{"127.0.0.1:6379", "127.0.0.1:7379"}, + "cluster.0": "127.0.0.1:6379", + "cluster.1": "127.0.0.1:7379", + }, + }, + { + name: "test time duration", + obj: struct { + Timeout time.Duration `cfg:"timeout;10s;;time out"` + }{ + Timeout: 3 * time.Second, + }, + want: map[string]interface{}{ + "timeout": 3 * time.Second, + }, + }, + { + name: "test struct in struct", + obj: struct { + nest struct { + Timeout time.Duration `cfg:"timeout;10s;;time out"` + } `cfg:"nest;;;nest struct"` + }{ + nest: struct { + Timeout time.Duration `cfg:"timeout;10s;;time out"` + }{ + Timeout: 3 * time.Second, + }, + }, + want: map[string]interface{}{ + "nest.timeout": 3 * time.Second, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handle := func(path string, tag *toml.CfgTag, v reflect.Value) { + if w, ok := tt.want[path]; ok { + wt := reflect.ValueOf(w) + testEqual(t, wt, v) + delete(tt.want, path) + } else { + require.FailNowf(t, "unknow path", "path=%s, v=%v", path, v) + } + } + tr := NewTravel(handle) + tr.Travel(tt.obj) + + if len(tt.want) > 0 { + require.FailNowf(t, "not get some path", "%#v", tt.want) + } + }) + } +} + +func testEqual(t *testing.T, want, get reflect.Value) { + require.Equal(t, want.Kind(), get.Kind(), "data kind not same") + switch get.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + require.Equal(t, want.Int(), get.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + require.Equal(t, want.Uint(), get.Uint()) + case reflect.Uintptr: + case reflect.Float32, reflect.Float64: + require.Equal(t, want.Float(), get.Float()) + case reflect.Bool: + require.Equal(t, want.Bool(), get.Bool()) + case reflect.String: + require.Equal(t, want.String(), get.String()) + + case reflect.Slice, reflect.Array: + require.Equal(t, want.Len(), get.Len()) + for i := 0; i < want.Len(); i++ { + testEqual(t, want.Index(i), get.Index(i)) + } + default: + t.Fatalf("uset type %v", get.Kind()) + } +}