diff --git a/examples/basic/main.go b/examples/basic/main.go index f4fc21a..a73d8ae 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -15,6 +15,7 @@ type Options struct { Address goflags.StringSlice fileSize goflags.Size duration time.Duration + rls goflags.RateLimitMap } func main() { @@ -28,6 +29,7 @@ func main() { flagSet.CreateGroup("info", "Info", flagSet.StringVarP(&testOptions.name, "name", "n", "", "name of the user"), flagSet.StringSliceVarP(&testOptions.Email, "email", "e", nil, "email of the user", goflags.CommaSeparatedStringSliceOptions), + flagSet.RateLimitMapVarP(&testOptions.rls, "rate-limits", "rls", nil, "rate limits in format k=v/d i.e hackertarget=10/s", goflags.CommaSeparatedStringSliceOptions), ) flagSet.CreateGroup("additional", "Additional", flagSet.StringVarP(&testOptions.Phone, "phone", "ph", "", "phone of the user"), @@ -40,4 +42,7 @@ func main() { if err := flagSet.Parse(); err != nil { log.Fatal(err) } + + // ratelimits value is + fmt.Printf("Got RateLimits: %+v\n", testOptions.rls) } diff --git a/ratelimit_var.go b/ratelimit_var.go new file mode 100644 index 0000000..bcd9cdb --- /dev/null +++ b/ratelimit_var.go @@ -0,0 +1,170 @@ +package goflags + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" + "unicode" + + stringsutil "github.com/projectdiscovery/utils/strings" + timeutil "github.com/projectdiscovery/utils/time" +) + +var ( + MaxRateLimitTime = time.Minute // anything above time.Minute is not practical (for our use case) + rateLimitOptionMap map[*RateLimitMap]Options +) + +func init() { + rateLimitOptionMap = make(map[*RateLimitMap]Options) +} + +type RateLimit struct { + MaxCount uint + Duration time.Duration +} + +type RateLimitMap struct { + kv map[string]RateLimit +} + +// Set inserts a value to the map. Format: key=value +func (rateLimitMap *RateLimitMap) Set(value string) error { + if rateLimitMap.kv == nil { + rateLimitMap.kv = make(map[string]RateLimit) + } + + option, ok := rateLimitOptionMap[rateLimitMap] + if !ok { + option = StringSliceOptions + } + rateLimits, err := ToStringSlice(value, option) + if err != nil { + return err + } + + for _, rateLimit := range rateLimits { + var k, v string + if idxSep := strings.Index(rateLimit, kvSep); idxSep > 0 { + k = rateLimit[:idxSep] + v = rateLimit[idxSep+1:] + } + + // note: + // - inserting multiple times the same key will override the previous v + // - empty string is legitimate rateLimit + if k != "" { + rateLimit, err := parseRateLimit(v) + if err != nil { + return err + } + rateLimitMap.kv[k] = rateLimit + } + } + return nil +} + +// Del removes the specified key +func (rateLimitMap *RateLimitMap) Del(key string) error { + if rateLimitMap.kv == nil { + return errors.New("empty runtime map") + } + delete(rateLimitMap.kv, key) + return nil +} + +// IsEmpty specifies if the underlying map is empty +func (rateLimitMap *RateLimitMap) IsEmpty() bool { + return rateLimitMap.kv == nil || len(rateLimitMap.kv) == 0 +} + +// AsMap returns the internal map as reference - changes are allowed +func (rateLimitMap *RateLimitMap) AsMap() map[string]RateLimit { + return rateLimitMap.kv +} + +func (rateLimitMap RateLimitMap) String() string { + defaultBuilder := &strings.Builder{} + defaultBuilder.WriteString("{") + + var items string + for k, v := range rateLimitMap.kv { + items += fmt.Sprintf("\"%s\":\"%d/%s\",", k, v.MaxCount, v.Duration.String()) + } + defaultBuilder.WriteString(stringsutil.TrimSuffixAny(items, ",", ":")) + defaultBuilder.WriteString("}") + return defaultBuilder.String() +} + +// RateLimitMapVar adds a ratelimit flag with a longname +func (flagSet *FlagSet) RateLimitMapVar(field *RateLimitMap, long string, defaultValue []string, usage string, options Options) *FlagData { + return flagSet.RateLimitMapVarP(field, long, "", defaultValue, usage, options) +} + +// RateLimitMapVarP adds a ratelimit flag with a short name and long name. +// It is equivalent to RateLimitMapVar, and also allows specifying ratelimits in days (e.g., "hackertarget=2/d" 2 requests per day, which is equivalent to 24h). +func (flagSet *FlagSet) RateLimitMapVarP(field *RateLimitMap, long, short string, defaultValue StringSlice, usage string, options Options) *FlagData { + if field == nil { + panic(fmt.Errorf("field cannot be nil for flag -%v", long)) + } + + rateLimitOptionMap[field] = options + for _, defaultItem := range defaultValue { + values, _ := ToStringSlice(defaultItem, options) + for _, value := range values { + if err := field.Set(value); err != nil { + panic(fmt.Errorf("failed to set default value for flag -%v: %v", long, err)) + } + } + } + + flagData := &FlagData{ + usage: usage, + long: long, + defaultValue: defaultValue, + skipMarshal: true, + } + if short != "" { + flagData.short = short + flagSet.CommandLine.Var(field, short, usage) + flagSet.flagKeys.Set(short, flagData) + } + flagSet.CommandLine.Var(field, long, usage) + flagSet.flagKeys.Set(long, flagData) + return flagData +} + +func parseRateLimit(s string) (RateLimit, error) { + sArr := strings.Split(s, "/") + + if len(sArr) < 2 { + return RateLimit{}, errors.New("parse error: expected format k=v/d (e.g., scanme.sh=10/s got " + s) + } + + maxCount, err := strconv.ParseUint(sArr[0], 10, 64) + if err != nil { + return RateLimit{}, errors.New("parse error: " + err.Error()) + } + timeValue := sArr[1] + if len(timeValue) > 0 { + // check if time is given ex: 1s + // if given value is just s (add prefix 1) + firstChar := timeValue[0] + if !unicode.IsDigit(rune(firstChar)) { + timeValue = "1" + timeValue + } + } + + duration, err := timeutil.ParseDuration(timeValue) + if err != nil { + return RateLimit{}, errors.New("parse error: " + err.Error()) + } + + if MaxRateLimitTime < duration { + return RateLimit{}, fmt.Errorf("duration cannot be more than %v but got %v", MaxRateLimitTime, duration) + } + + return RateLimit{MaxCount: uint(maxCount), Duration: duration}, nil +} diff --git a/ratelimit_var_test.go b/ratelimit_var_test.go new file mode 100644 index 0000000..41730cc --- /dev/null +++ b/ratelimit_var_test.go @@ -0,0 +1,93 @@ +package goflags + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRateLimitMapVar(t *testing.T) { + + t.Run("default-value", func(t *testing.T) { + var rateLimitMap RateLimitMap + flagSet := NewFlagSet() + flagSet.CreateGroup("Config", "Config", + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", []string{"hackertarget=1/ms"}, "rate limits", CommaSeparatedStringSliceOptions), + ) + os.Args = []string{ + os.Args[0], + } + err := flagSet.Parse() + assert.Nil(t, err) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["hackertarget"]) + tearDown(t.Name()) + }) + + t.Run("multiple-default-value", func(t *testing.T) { + var rateLimitMap RateLimitMap + flagSet := NewFlagSet() + flagSet.CreateGroup("Config", "Config", + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", []string{"hackertarget=1/s,github=1/ms"}, "rate limits", CommaSeparatedStringSliceOptions), + ) + os.Args = []string{ + os.Args[0], + } + err := flagSet.Parse() + assert.Nil(t, err) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Second}, rateLimitMap.AsMap()["hackertarget"]) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["github"]) + tearDown(t.Name()) + }) + + t.Run("valid-rate-limit", func(t *testing.T) { + var rateLimitMap RateLimitMap + flagSet := NewFlagSet() + flagSet.CreateGroup("Config", "Config", + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits", CommaSeparatedStringSliceOptions), + ) + os.Args = []string{ + os.Args[0], + "-rls", "hackertarget=10/m", + } + err := flagSet.Parse() + assert.Nil(t, err) + assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Minute}, rateLimitMap.AsMap()["hackertarget"]) + + tearDown(t.Name()) + }) + + t.Run("valid-rate-limits", func(t *testing.T) { + var rateLimitMap RateLimitMap + flagSet := NewFlagSet() + flagSet.CreateGroup("Config", "Config", + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits", CommaSeparatedStringSliceOptions), + ) + os.Args = []string{ + os.Args[0], + "-rls", "hackertarget=1/s,github=1/ms", + } + err := flagSet.Parse() + assert.Nil(t, err) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Second}, rateLimitMap.AsMap()["hackertarget"]) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["github"]) + tearDown(t.Name()) + }) + + t.Run("without-unit", func(t *testing.T) { + var rateLimitMap RateLimitMap + err := rateLimitMap.Set("hackertarget=1") + assert.NotNil(t, err) + assert.ErrorContains(t, err, "parse error") + tearDown(t.Name()) + }) + + t.Run("invalid-unit", func(t *testing.T) { + var rateLimitMap RateLimitMap + err := rateLimitMap.Set("hackertarget=1/x") + assert.NotNil(t, err) + assert.ErrorContains(t, err, "parse error") + tearDown(t.Name()) + }) +}