diff --git a/ratelimit_var.go b/ratelimit_var.go index 0c310ed..d28f072 100644 --- a/ratelimit_var.go +++ b/ratelimit_var.go @@ -7,10 +7,20 @@ import ( "strings" "time" + sliceutil "github.com/projectdiscovery/utils/slice" stringsutil "github.com/projectdiscovery/utils/strings" timeutil "github.com/projectdiscovery/utils/time" ) +var ( + AllowedUnits = []string{"ns", "us", "ms", "s", "m"} + rateLimitOptionMap map[*RateLimitMap]Options +) + +func init() { + rateLimitOptionMap = make(map[*RateLimitMap]Options) +} + type RateLimit struct { MaxCount uint Duration time.Duration @@ -25,21 +35,34 @@ func (rateLimitMap *RateLimitMap) Set(value string) error { if rateLimitMap.kv == nil { rateLimitMap.kv = make(map[string]RateLimit) } - var k, v string - if idxSep := strings.Index(value, kvSep); idxSep > 0 { - k = value[:idxSep] - v = value[idxSep+1:] + + option, ok := rateLimitOptionMap[rateLimitMap] + if !ok { + option = StringSliceOptions + } + rateLimits, err := ToStringSlice(value, option) + if err != nil { + return err } - // note: - // - inserting multiple times the same key will override the previous value - // - empty string is legitimate value - - if k != "" { - rateLimit, err := parseRateLimit(v) - if err != nil { - return err + + for _, rateLimit := range rateLimits { + + var k, v string + if idxSep := strings.Index(rateLimit, kvSep); idxSep > 0 { + k = rateLimit[:idxSep] + v = rateLimit[idxSep+1:] + } + // note: + // - inserting multiple times the same key will override the previous v + // - empty string is legitimate rateLimit + + if k != "" { + rateLimit, err := parseRateLimit(v) + if err != nil { + return err + } + rateLimitMap.kv[k] = rateLimit } - rateLimitMap.kv[k] = rateLimit } return nil } @@ -88,6 +111,7 @@ func (flagSet *FlagSet) RateLimitMapVarP(field *RateLimitMap, long, short string panic(fmt.Errorf("field cannot be nil for flag -%v", long)) } + rateLimitOptionMap[field] = options for _, defaultItem := range defaultValue { values, _ := ToStringSlice(defaultItem, options) for _, value := range values { @@ -124,9 +148,15 @@ func parseRateLimit(s string) (RateLimit, error) { if err != nil { return RateLimit{}, errors.New("parse error: " + err.Error()) } + duration, err := timeutil.ParseDuration("1" + sArr[1]) if err != nil { return RateLimit{}, errors.New("parse error: " + err.Error()) } + + if !sliceutil.Contains(AllowedUnits, sArr[1]) { + return RateLimit{}, errors.New("unit " + sArr[1] + " is not allowed") + } + return RateLimit{MaxCount: uint(maxCount), Duration: duration}, nil } diff --git a/ratelimit_var_test.go b/ratelimit_var_test.go index a68913e..41730cc 100644 --- a/ratelimit_var_test.go +++ b/ratelimit_var_test.go @@ -49,11 +49,12 @@ func TestRateLimitMapVar(t *testing.T) { ) os.Args = []string{ os.Args[0], - "-rls", "hackertarget=10/d", + "-rls", "hackertarget=10/m", } err := flagSet.Parse() assert.Nil(t, err) - assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Hour * 24}, rateLimitMap.AsMap()["hackertarget"]) + assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Minute}, rateLimitMap.AsMap()["hackertarget"]) + tearDown(t.Name()) }) @@ -65,11 +66,12 @@ func TestRateLimitMapVar(t *testing.T) { ) os.Args = []string{ os.Args[0], - "-rls", "hackertarget=10/d", + "-rls", "hackertarget=1/s,github=1/ms", } err := flagSet.Parse() assert.Nil(t, err) - assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Hour * 24}, rateLimitMap.AsMap()["hackertarget"]) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Second}, rateLimitMap.AsMap()["hackertarget"]) + assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["github"]) tearDown(t.Name()) })