From a96e9a9ce2c711d754dc4534f265decb119cdaef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fan=20Can=20Bak=C4=B1r?= Date: Wed, 19 Jul 2023 09:19:57 +0300 Subject: [PATCH 1/6] add ratelimit flag --- ratelimit_var.go | 129 ++++++++++++++++++++++++++++++++++++++++++ ratelimit_var_test.go | 75 ++++++++++++++++++++++++ 2 files changed, 204 insertions(+) create mode 100644 ratelimit_var.go create mode 100644 ratelimit_var_test.go diff --git a/ratelimit_var.go b/ratelimit_var.go new file mode 100644 index 0000000..8cb57e6 --- /dev/null +++ b/ratelimit_var.go @@ -0,0 +1,129 @@ +package goflags + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" + + stringsutil "github.com/projectdiscovery/utils/strings" + timeutil "github.com/projectdiscovery/utils/time" +) + +type RateLimit struct { + MaxCount uint + Duration time.Duration +} + +type RateLimitMap struct { + kv map[string]RateLimit +} + +// Set inserts a value to the map. Format: key=value +func (rateLimitMap *RateLimitMap) Set(value string) error { + if rateLimitMap.kv == nil { + rateLimitMap.kv = make(map[string]RateLimit) + } + var k, v string + if idxSep := strings.Index(value, kvSep); idxSep > 0 { + k = value[:idxSep] + v = value[idxSep+1:] + } + // note: + // - inserting multiple times the same key will override the previous value + // - empty string is legitimate value + + if k != "" { + rateLimit, err := parseRateLimit(v) + if err != nil { + return err + } + rateLimitMap.kv[k] = rateLimit + } + return nil +} + +// Del removes the specified key +func (rateLimitMap *RateLimitMap) Del(key string) error { + if rateLimitMap.kv == nil { + return errors.New("empty runtime map") + } + delete(rateLimitMap.kv, key) + return nil +} + +// IsEmpty specifies if the underlying map is empty +func (rateLimitMap *RateLimitMap) IsEmpty() bool { + return rateLimitMap.kv == nil || len(rateLimitMap.kv) == 0 +} + +// AsMap returns the internal map as reference - changes are allowed +func (rateLimitMap *RateLimitMap) AsMap() map[string]RateLimit { + return rateLimitMap.kv +} + +func (rateLimitMap RateLimitMap) String() string { + defaultBuilder := &strings.Builder{} + defaultBuilder.WriteString("{") + + var items string + for k, v := range rateLimitMap.kv { + items += fmt.Sprintf("\"%s\"=\"%s\"%s", k, v.Duration.String(), kvSep) + } + defaultBuilder.WriteString(stringsutil.TrimSuffixAny(items, ",", "=")) + defaultBuilder.WriteString("}") + return defaultBuilder.String() +} + +// RateLimitMapVar adds a ratelimit flag with a longname +func (flagSet *FlagSet) RateLimitMapVar(field *RateLimitMap, long string, defaultValue []string, usage string) *FlagData { + return flagSet.RateLimitMapVarP(field, long, "", defaultValue, usage) +} + +// RateLimitMapVarP adds a ratelimit flag with a short name and long name. +// It is equivalent to RateLimitMapVar, and also allows specifying ratelimits in days (e.g., "hackertarget=2/d" 2 requests per day, which is equivalent to 24h). +func (flagSet *FlagSet) RateLimitMapVarP(field *RateLimitMap, long, short string, defaultValue []string, usage string) *FlagData { + if field == nil { + panic(fmt.Errorf("field cannot be nil for flag -%v", long)) + } + + for _, item := range defaultValue { + if err := field.Set(item); err != nil { + panic(fmt.Errorf("failed to set default value for flag -%v: %v", long, err)) + } + } + + flagData := &FlagData{ + usage: usage, + long: long, + defaultValue: defaultValue, + skipMarshal: true, + } + if short != "" { + flagData.short = short + flagSet.CommandLine.Var(field, short, usage) + flagSet.flagKeys.Set(short, flagData) + } + flagSet.CommandLine.Var(field, long, usage) + flagSet.flagKeys.Set(long, flagData) + return flagData +} + +func parseRateLimit(s string) (RateLimit, error) { + sArr := strings.Split(s, "/") + + if len(sArr) < 2 { + return RateLimit{}, errors.New("parse error") + } + + maxCount, err := strconv.ParseUint(sArr[0], 10, 64) + if err != nil { + return RateLimit{}, errors.New("parse error: " + err.Error()) + } + duration, err := timeutil.ParseDuration("1" + sArr[1]) + if err != nil { + return RateLimit{}, errors.New("parse error: " + err.Error()) + } + return RateLimit{MaxCount: uint(maxCount), Duration: duration}, nil +} diff --git a/ratelimit_var_test.go b/ratelimit_var_test.go new file mode 100644 index 0000000..2ef9c2d --- /dev/null +++ b/ratelimit_var_test.go @@ -0,0 +1,75 @@ +package goflags + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRateLimitMapVar(t *testing.T) { + + t.Run("default-value", func(t *testing.T) { + var rateLimitMap RateLimitMap + flagSet := NewFlagSet() + flagSet.CreateGroup("Config", "Config", + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", []string{"hackertarget=1/ms"}, "rate limits"), + ) + os.Args = []string{ + os.Args[0], + } + err := flagSet.Parse() + assert.Nil(t, err) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["hackertarget"]) + tearDown(t.Name()) + }) + + t.Run("valid-rate-limit", func(t *testing.T) { + var rateLimitMap RateLimitMap + flagSet := NewFlagSet() + flagSet.CreateGroup("Config", "Config", + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits"), + ) + os.Args = []string{ + os.Args[0], + "-rls", "hackertarget=10/d", + } + err := flagSet.Parse() + assert.Nil(t, err) + assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Hour * 24}, rateLimitMap.AsMap()["hackertarget"]) + tearDown(t.Name()) + }) + + t.Run("valid-rate-limits", func(t *testing.T) { + var rateLimitMap RateLimitMap + flagSet := NewFlagSet() + flagSet.CreateGroup("Config", "Config", + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits"), + ) + os.Args = []string{ + os.Args[0], + "-rls", "hackertarget=10/d", + } + err := flagSet.Parse() + assert.Nil(t, err) + assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Hour * 24}, rateLimitMap.AsMap()["hackertarget"]) + tearDown(t.Name()) + }) + + t.Run("without-unit", func(t *testing.T) { + var rateLimitMap RateLimitMap + err := rateLimitMap.Set("hackertarget=1") + assert.NotNil(t, err) + assert.ErrorContains(t, err, "parse error") + tearDown(t.Name()) + }) + + t.Run("invalid-unit", func(t *testing.T) { + var rateLimitMap RateLimitMap + err := rateLimitMap.Set("hackertarget=1/x") + assert.NotNil(t, err) + assert.ErrorContains(t, err, "parse error") + tearDown(t.Name()) + }) +} From 9831d712b16dcfbb62d376aa45222fcde12f3dca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fan=20Can=20Bak=C4=B1r?= Date: Fri, 21 Jul 2023 07:13:03 +0000 Subject: [PATCH 2/6] use StringSlice --- ratelimit_var.go | 19 +++++++++++-------- ratelimit_var_test.go | 22 +++++++++++++++++++--- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/ratelimit_var.go b/ratelimit_var.go index 8cb57e6..0c310ed 100644 --- a/ratelimit_var.go +++ b/ratelimit_var.go @@ -69,28 +69,31 @@ func (rateLimitMap RateLimitMap) String() string { var items string for k, v := range rateLimitMap.kv { - items += fmt.Sprintf("\"%s\"=\"%s\"%s", k, v.Duration.String(), kvSep) + items += fmt.Sprintf("\"%s\":\"%d/%s\",", k, v.MaxCount, v.Duration.String()) } - defaultBuilder.WriteString(stringsutil.TrimSuffixAny(items, ",", "=")) + defaultBuilder.WriteString(stringsutil.TrimSuffixAny(items, ",", ":")) defaultBuilder.WriteString("}") return defaultBuilder.String() } // RateLimitMapVar adds a ratelimit flag with a longname -func (flagSet *FlagSet) RateLimitMapVar(field *RateLimitMap, long string, defaultValue []string, usage string) *FlagData { - return flagSet.RateLimitMapVarP(field, long, "", defaultValue, usage) +func (flagSet *FlagSet) RateLimitMapVar(field *RateLimitMap, long string, defaultValue []string, usage string, options Options) *FlagData { + return flagSet.RateLimitMapVarP(field, long, "", defaultValue, usage, options) } // RateLimitMapVarP adds a ratelimit flag with a short name and long name. // It is equivalent to RateLimitMapVar, and also allows specifying ratelimits in days (e.g., "hackertarget=2/d" 2 requests per day, which is equivalent to 24h). -func (flagSet *FlagSet) RateLimitMapVarP(field *RateLimitMap, long, short string, defaultValue []string, usage string) *FlagData { +func (flagSet *FlagSet) RateLimitMapVarP(field *RateLimitMap, long, short string, defaultValue StringSlice, usage string, options Options) *FlagData { if field == nil { panic(fmt.Errorf("field cannot be nil for flag -%v", long)) } - for _, item := range defaultValue { - if err := field.Set(item); err != nil { - panic(fmt.Errorf("failed to set default value for flag -%v: %v", long, err)) + for _, defaultItem := range defaultValue { + values, _ := ToStringSlice(defaultItem, options) + for _, value := range values { + if err := field.Set(value); err != nil { + panic(fmt.Errorf("failed to set default value for flag -%v: %v", long, err)) + } } } diff --git a/ratelimit_var_test.go b/ratelimit_var_test.go index 2ef9c2d..a68913e 100644 --- a/ratelimit_var_test.go +++ b/ratelimit_var_test.go @@ -14,7 +14,7 @@ func TestRateLimitMapVar(t *testing.T) { var rateLimitMap RateLimitMap flagSet := NewFlagSet() flagSet.CreateGroup("Config", "Config", - flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", []string{"hackertarget=1/ms"}, "rate limits"), + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", []string{"hackertarget=1/ms"}, "rate limits", CommaSeparatedStringSliceOptions), ) os.Args = []string{ os.Args[0], @@ -25,11 +25,27 @@ func TestRateLimitMapVar(t *testing.T) { tearDown(t.Name()) }) + t.Run("multiple-default-value", func(t *testing.T) { + var rateLimitMap RateLimitMap + flagSet := NewFlagSet() + flagSet.CreateGroup("Config", "Config", + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", []string{"hackertarget=1/s,github=1/ms"}, "rate limits", CommaSeparatedStringSliceOptions), + ) + os.Args = []string{ + os.Args[0], + } + err := flagSet.Parse() + assert.Nil(t, err) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Second}, rateLimitMap.AsMap()["hackertarget"]) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["github"]) + tearDown(t.Name()) + }) + t.Run("valid-rate-limit", func(t *testing.T) { var rateLimitMap RateLimitMap flagSet := NewFlagSet() flagSet.CreateGroup("Config", "Config", - flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits"), + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits", CommaSeparatedStringSliceOptions), ) os.Args = []string{ os.Args[0], @@ -45,7 +61,7 @@ func TestRateLimitMapVar(t *testing.T) { var rateLimitMap RateLimitMap flagSet := NewFlagSet() flagSet.CreateGroup("Config", "Config", - flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits"), + flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits", CommaSeparatedStringSliceOptions), ) os.Args = []string{ os.Args[0], From 23f48de56f74ce4a17f7ed025cfab61838fa584e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fan=20Can=20Bak=C4=B1r?= Date: Fri, 21 Jul 2023 10:24:34 +0000 Subject: [PATCH 3/6] limit unit to m --- ratelimit_var.go | 56 +++++++++++++++++++++++++++++++++---------- ratelimit_var_test.go | 10 ++++---- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/ratelimit_var.go b/ratelimit_var.go index 0c310ed..d28f072 100644 --- a/ratelimit_var.go +++ b/ratelimit_var.go @@ -7,10 +7,20 @@ import ( "strings" "time" + sliceutil "github.com/projectdiscovery/utils/slice" stringsutil "github.com/projectdiscovery/utils/strings" timeutil "github.com/projectdiscovery/utils/time" ) +var ( + AllowedUnits = []string{"ns", "us", "ms", "s", "m"} + rateLimitOptionMap map[*RateLimitMap]Options +) + +func init() { + rateLimitOptionMap = make(map[*RateLimitMap]Options) +} + type RateLimit struct { MaxCount uint Duration time.Duration @@ -25,21 +35,34 @@ func (rateLimitMap *RateLimitMap) Set(value string) error { if rateLimitMap.kv == nil { rateLimitMap.kv = make(map[string]RateLimit) } - var k, v string - if idxSep := strings.Index(value, kvSep); idxSep > 0 { - k = value[:idxSep] - v = value[idxSep+1:] + + option, ok := rateLimitOptionMap[rateLimitMap] + if !ok { + option = StringSliceOptions + } + rateLimits, err := ToStringSlice(value, option) + if err != nil { + return err } - // note: - // - inserting multiple times the same key will override the previous value - // - empty string is legitimate value - - if k != "" { - rateLimit, err := parseRateLimit(v) - if err != nil { - return err + + for _, rateLimit := range rateLimits { + + var k, v string + if idxSep := strings.Index(rateLimit, kvSep); idxSep > 0 { + k = rateLimit[:idxSep] + v = rateLimit[idxSep+1:] + } + // note: + // - inserting multiple times the same key will override the previous v + // - empty string is legitimate rateLimit + + if k != "" { + rateLimit, err := parseRateLimit(v) + if err != nil { + return err + } + rateLimitMap.kv[k] = rateLimit } - rateLimitMap.kv[k] = rateLimit } return nil } @@ -88,6 +111,7 @@ func (flagSet *FlagSet) RateLimitMapVarP(field *RateLimitMap, long, short string panic(fmt.Errorf("field cannot be nil for flag -%v", long)) } + rateLimitOptionMap[field] = options for _, defaultItem := range defaultValue { values, _ := ToStringSlice(defaultItem, options) for _, value := range values { @@ -124,9 +148,15 @@ func parseRateLimit(s string) (RateLimit, error) { if err != nil { return RateLimit{}, errors.New("parse error: " + err.Error()) } + duration, err := timeutil.ParseDuration("1" + sArr[1]) if err != nil { return RateLimit{}, errors.New("parse error: " + err.Error()) } + + if !sliceutil.Contains(AllowedUnits, sArr[1]) { + return RateLimit{}, errors.New("unit " + sArr[1] + " is not allowed") + } + return RateLimit{MaxCount: uint(maxCount), Duration: duration}, nil } diff --git a/ratelimit_var_test.go b/ratelimit_var_test.go index a68913e..41730cc 100644 --- a/ratelimit_var_test.go +++ b/ratelimit_var_test.go @@ -49,11 +49,12 @@ func TestRateLimitMapVar(t *testing.T) { ) os.Args = []string{ os.Args[0], - "-rls", "hackertarget=10/d", + "-rls", "hackertarget=10/m", } err := flagSet.Parse() assert.Nil(t, err) - assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Hour * 24}, rateLimitMap.AsMap()["hackertarget"]) + assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Minute}, rateLimitMap.AsMap()["hackertarget"]) + tearDown(t.Name()) }) @@ -65,11 +66,12 @@ func TestRateLimitMapVar(t *testing.T) { ) os.Args = []string{ os.Args[0], - "-rls", "hackertarget=10/d", + "-rls", "hackertarget=1/s,github=1/ms", } err := flagSet.Parse() assert.Nil(t, err) - assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Hour * 24}, rateLimitMap.AsMap()["hackertarget"]) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Second}, rateLimitMap.AsMap()["hackertarget"]) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["github"]) tearDown(t.Name()) }) From de209574c515ea68a99081fecbd1aa13cd96f96f Mon Sep 17 00:00:00 2001 From: Tarun Koyalwar Date: Fri, 21 Jul 2023 20:28:27 +0530 Subject: [PATCH 4/6] add example + Max threshold --- examples/basic/main.go | 5 +++++ ratelimit_var.go | 11 +++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/basic/main.go b/examples/basic/main.go index f4fc21a..a73d8ae 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -15,6 +15,7 @@ type Options struct { Address goflags.StringSlice fileSize goflags.Size duration time.Duration + rls goflags.RateLimitMap } func main() { @@ -28,6 +29,7 @@ func main() { flagSet.CreateGroup("info", "Info", flagSet.StringVarP(&testOptions.name, "name", "n", "", "name of the user"), flagSet.StringSliceVarP(&testOptions.Email, "email", "e", nil, "email of the user", goflags.CommaSeparatedStringSliceOptions), + flagSet.RateLimitMapVarP(&testOptions.rls, "rate-limits", "rls", nil, "rate limits in format k=v/d i.e hackertarget=10/s", goflags.CommaSeparatedStringSliceOptions), ) flagSet.CreateGroup("additional", "Additional", flagSet.StringVarP(&testOptions.Phone, "phone", "ph", "", "phone of the user"), @@ -40,4 +42,7 @@ func main() { if err := flagSet.Parse(); err != nil { log.Fatal(err) } + + // ratelimits value is + fmt.Printf("Got RateLimits: %+v\n", testOptions.rls) } diff --git a/ratelimit_var.go b/ratelimit_var.go index d28f072..246cd68 100644 --- a/ratelimit_var.go +++ b/ratelimit_var.go @@ -7,13 +7,12 @@ import ( "strings" "time" - sliceutil "github.com/projectdiscovery/utils/slice" stringsutil "github.com/projectdiscovery/utils/strings" timeutil "github.com/projectdiscovery/utils/time" ) var ( - AllowedUnits = []string{"ns", "us", "ms", "s", "m"} + MaxRateLimitTime = time.Minute // anything above time.Minute is not practical (for our use case) rateLimitOptionMap map[*RateLimitMap]Options ) @@ -141,7 +140,7 @@ func parseRateLimit(s string) (RateLimit, error) { sArr := strings.Split(s, "/") if len(sArr) < 2 { - return RateLimit{}, errors.New("parse error") + return RateLimit{}, errors.New("parse error: expected format k=v/d (e.g., scanme.sh=10/s got " + s) } maxCount, err := strconv.ParseUint(sArr[0], 10, 64) @@ -149,13 +148,13 @@ func parseRateLimit(s string) (RateLimit, error) { return RateLimit{}, errors.New("parse error: " + err.Error()) } - duration, err := timeutil.ParseDuration("1" + sArr[1]) + duration, err := timeutil.ParseDuration(sArr[1]) if err != nil { return RateLimit{}, errors.New("parse error: " + err.Error()) } - if !sliceutil.Contains(AllowedUnits, sArr[1]) { - return RateLimit{}, errors.New("unit " + sArr[1] + " is not allowed") + if MaxRateLimitTime < duration { + return RateLimit{}, fmt.Errorf("duration cannot be more than %v but got %v", MaxRateLimitTime, duration) } return RateLimit{MaxCount: uint(maxCount), Duration: duration}, nil From 1040830ead2d6a88044369cfc2e6f095a50bf536 Mon Sep 17 00:00:00 2001 From: Tarun Koyalwar Date: Fri, 21 Jul 2023 20:34:42 +0530 Subject: [PATCH 5/6] fix build test --- ratelimit_var.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ratelimit_var.go b/ratelimit_var.go index 246cd68..57edf60 100644 --- a/ratelimit_var.go +++ b/ratelimit_var.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" "time" + "unicode" stringsutil "github.com/projectdiscovery/utils/strings" timeutil "github.com/projectdiscovery/utils/time" @@ -147,8 +148,17 @@ func parseRateLimit(s string) (RateLimit, error) { if err != nil { return RateLimit{}, errors.New("parse error: " + err.Error()) } + timeValue := sArr[1] + if len(timeValue) > 0 { + // check if time is given ex: 1s + // if given value is just s (add prefix 1) + firstChar := timeValue[0] + if !unicode.IsDigit(rune(firstChar)) { + timeValue = "1" + timeValue + } + } - duration, err := timeutil.ParseDuration(sArr[1]) + duration, err := timeutil.ParseDuration(timeValue) if err != nil { return RateLimit{}, errors.New("parse error: " + err.Error()) } From 12ec7fb868575c9b30e0b4b13dcca2bff649d3f0 Mon Sep 17 00:00:00 2001 From: mzack Date: Mon, 24 Jul 2023 14:13:54 +0200 Subject: [PATCH 6/6] fmt --- ratelimit_var.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ratelimit_var.go b/ratelimit_var.go index 57edf60..bcd9cdb 100644 --- a/ratelimit_var.go +++ b/ratelimit_var.go @@ -46,16 +46,15 @@ func (rateLimitMap *RateLimitMap) Set(value string) error { } for _, rateLimit := range rateLimits { - var k, v string if idxSep := strings.Index(rateLimit, kvSep); idxSep > 0 { k = rateLimit[:idxSep] v = rateLimit[idxSep+1:] } + // note: // - inserting multiple times the same key will override the previous v // - empty string is legitimate rateLimit - if k != "" { rateLimit, err := parseRateLimit(v) if err != nil {