Skip to content

Commit

Permalink
limit unit to m
Browse files Browse the repository at this point in the history
  • Loading branch information
dogancanbakir committed Jul 21, 2023
1 parent 9831d71 commit 23f48de
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
56 changes: 43 additions & 13 deletions ratelimit_var.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,20 @@ import (
"strings"
"time"

sliceutil "github.com/projectdiscovery/utils/slice"
stringsutil "github.com/projectdiscovery/utils/strings"
timeutil "github.com/projectdiscovery/utils/time"
)

var (
AllowedUnits = []string{"ns", "us", "ms", "s", "m"}
rateLimitOptionMap map[*RateLimitMap]Options
)

func init() {
rateLimitOptionMap = make(map[*RateLimitMap]Options)
}

type RateLimit struct {
MaxCount uint
Duration time.Duration
Expand All @@ -25,21 +35,34 @@ func (rateLimitMap *RateLimitMap) Set(value string) error {
if rateLimitMap.kv == nil {
rateLimitMap.kv = make(map[string]RateLimit)
}
var k, v string
if idxSep := strings.Index(value, kvSep); idxSep > 0 {
k = value[:idxSep]
v = value[idxSep+1:]

option, ok := rateLimitOptionMap[rateLimitMap]
if !ok {
option = StringSliceOptions
}
rateLimits, err := ToStringSlice(value, option)
if err != nil {
return err
}
// note:
// - inserting multiple times the same key will override the previous value
// - empty string is legitimate value

if k != "" {
rateLimit, err := parseRateLimit(v)
if err != nil {
return err

for _, rateLimit := range rateLimits {

var k, v string
if idxSep := strings.Index(rateLimit, kvSep); idxSep > 0 {
k = rateLimit[:idxSep]
v = rateLimit[idxSep+1:]
}
// note:
// - inserting multiple times the same key will override the previous v
// - empty string is legitimate rateLimit

if k != "" {
rateLimit, err := parseRateLimit(v)
if err != nil {
return err
}
rateLimitMap.kv[k] = rateLimit
}
rateLimitMap.kv[k] = rateLimit
}
return nil
}
Expand Down Expand Up @@ -88,6 +111,7 @@ func (flagSet *FlagSet) RateLimitMapVarP(field *RateLimitMap, long, short string
panic(fmt.Errorf("field cannot be nil for flag -%v", long))
}

rateLimitOptionMap[field] = options
for _, defaultItem := range defaultValue {
values, _ := ToStringSlice(defaultItem, options)
for _, value := range values {
Expand Down Expand Up @@ -124,9 +148,15 @@ func parseRateLimit(s string) (RateLimit, error) {
if err != nil {
return RateLimit{}, errors.New("parse error: " + err.Error())
}

duration, err := timeutil.ParseDuration("1" + sArr[1])
if err != nil {
return RateLimit{}, errors.New("parse error: " + err.Error())
}

if !sliceutil.Contains(AllowedUnits, sArr[1]) {
return RateLimit{}, errors.New("unit " + sArr[1] + " is not allowed")
}

return RateLimit{MaxCount: uint(maxCount), Duration: duration}, nil
}
10 changes: 6 additions & 4 deletions ratelimit_var_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ func TestRateLimitMapVar(t *testing.T) {
)
os.Args = []string{
os.Args[0],
"-rls", "hackertarget=10/d",
"-rls", "hackertarget=10/m",
}
err := flagSet.Parse()
assert.Nil(t, err)
assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Hour * 24}, rateLimitMap.AsMap()["hackertarget"])
assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Minute}, rateLimitMap.AsMap()["hackertarget"])

tearDown(t.Name())
})

Expand All @@ -65,11 +66,12 @@ func TestRateLimitMapVar(t *testing.T) {
)
os.Args = []string{
os.Args[0],
"-rls", "hackertarget=10/d",
"-rls", "hackertarget=1/s,github=1/ms",
}
err := flagSet.Parse()
assert.Nil(t, err)
assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Hour * 24}, rateLimitMap.AsMap()["hackertarget"])
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Second}, rateLimitMap.AsMap()["hackertarget"])
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["github"])
tearDown(t.Name())
})

Expand Down

0 comments on commit 23f48de

Please sign in to comment.