-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add ratelimit flag * use StringSlice * limit unit to m * add example + Max threshold * fix build test * fmt --------- Co-authored-by: Tarun Koyalwar <[email protected]> Co-authored-by: mzack <[email protected]>
- Loading branch information
1 parent
1ab6157
commit dbf9da1
Showing
3 changed files
with
268 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
package goflags | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"strconv" | ||
"strings" | ||
"time" | ||
"unicode" | ||
|
||
stringsutil "github.com/projectdiscovery/utils/strings" | ||
timeutil "github.com/projectdiscovery/utils/time" | ||
) | ||
|
||
var ( | ||
MaxRateLimitTime = time.Minute // anything above time.Minute is not practical (for our use case) | ||
rateLimitOptionMap map[*RateLimitMap]Options | ||
) | ||
|
||
func init() { | ||
rateLimitOptionMap = make(map[*RateLimitMap]Options) | ||
} | ||
|
||
type RateLimit struct { | ||
MaxCount uint | ||
Duration time.Duration | ||
} | ||
|
||
type RateLimitMap struct { | ||
kv map[string]RateLimit | ||
} | ||
|
||
// Set inserts a value to the map. Format: key=value | ||
func (rateLimitMap *RateLimitMap) Set(value string) error { | ||
if rateLimitMap.kv == nil { | ||
rateLimitMap.kv = make(map[string]RateLimit) | ||
} | ||
|
||
option, ok := rateLimitOptionMap[rateLimitMap] | ||
if !ok { | ||
option = StringSliceOptions | ||
} | ||
rateLimits, err := ToStringSlice(value, option) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
for _, rateLimit := range rateLimits { | ||
var k, v string | ||
if idxSep := strings.Index(rateLimit, kvSep); idxSep > 0 { | ||
k = rateLimit[:idxSep] | ||
v = rateLimit[idxSep+1:] | ||
} | ||
|
||
// note: | ||
// - inserting multiple times the same key will override the previous v | ||
// - empty string is legitimate rateLimit | ||
if k != "" { | ||
rateLimit, err := parseRateLimit(v) | ||
if err != nil { | ||
return err | ||
} | ||
rateLimitMap.kv[k] = rateLimit | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
// Del removes the specified key | ||
func (rateLimitMap *RateLimitMap) Del(key string) error { | ||
if rateLimitMap.kv == nil { | ||
return errors.New("empty runtime map") | ||
} | ||
delete(rateLimitMap.kv, key) | ||
return nil | ||
} | ||
|
||
// IsEmpty specifies if the underlying map is empty | ||
func (rateLimitMap *RateLimitMap) IsEmpty() bool { | ||
return rateLimitMap.kv == nil || len(rateLimitMap.kv) == 0 | ||
} | ||
|
||
// AsMap returns the internal map as reference - changes are allowed | ||
func (rateLimitMap *RateLimitMap) AsMap() map[string]RateLimit { | ||
return rateLimitMap.kv | ||
} | ||
|
||
func (rateLimitMap RateLimitMap) String() string { | ||
defaultBuilder := &strings.Builder{} | ||
defaultBuilder.WriteString("{") | ||
|
||
var items string | ||
for k, v := range rateLimitMap.kv { | ||
items += fmt.Sprintf("\"%s\":\"%d/%s\",", k, v.MaxCount, v.Duration.String()) | ||
} | ||
defaultBuilder.WriteString(stringsutil.TrimSuffixAny(items, ",", ":")) | ||
defaultBuilder.WriteString("}") | ||
return defaultBuilder.String() | ||
} | ||
|
||
// RateLimitMapVar adds a ratelimit flag with a longname | ||
func (flagSet *FlagSet) RateLimitMapVar(field *RateLimitMap, long string, defaultValue []string, usage string, options Options) *FlagData { | ||
return flagSet.RateLimitMapVarP(field, long, "", defaultValue, usage, options) | ||
} | ||
|
||
// RateLimitMapVarP adds a ratelimit flag with a short name and long name. | ||
// It is equivalent to RateLimitMapVar, and also allows specifying ratelimits in days (e.g., "hackertarget=2/d" 2 requests per day, which is equivalent to 24h). | ||
func (flagSet *FlagSet) RateLimitMapVarP(field *RateLimitMap, long, short string, defaultValue StringSlice, usage string, options Options) *FlagData { | ||
if field == nil { | ||
panic(fmt.Errorf("field cannot be nil for flag -%v", long)) | ||
} | ||
|
||
rateLimitOptionMap[field] = options | ||
for _, defaultItem := range defaultValue { | ||
values, _ := ToStringSlice(defaultItem, options) | ||
for _, value := range values { | ||
if err := field.Set(value); err != nil { | ||
panic(fmt.Errorf("failed to set default value for flag -%v: %v", long, err)) | ||
} | ||
} | ||
} | ||
|
||
flagData := &FlagData{ | ||
usage: usage, | ||
long: long, | ||
defaultValue: defaultValue, | ||
skipMarshal: true, | ||
} | ||
if short != "" { | ||
flagData.short = short | ||
flagSet.CommandLine.Var(field, short, usage) | ||
flagSet.flagKeys.Set(short, flagData) | ||
} | ||
flagSet.CommandLine.Var(field, long, usage) | ||
flagSet.flagKeys.Set(long, flagData) | ||
return flagData | ||
} | ||
|
||
func parseRateLimit(s string) (RateLimit, error) { | ||
sArr := strings.Split(s, "/") | ||
|
||
if len(sArr) < 2 { | ||
return RateLimit{}, errors.New("parse error: expected format k=v/d (e.g., scanme.sh=10/s got " + s) | ||
} | ||
|
||
maxCount, err := strconv.ParseUint(sArr[0], 10, 64) | ||
if err != nil { | ||
return RateLimit{}, errors.New("parse error: " + err.Error()) | ||
} | ||
timeValue := sArr[1] | ||
if len(timeValue) > 0 { | ||
// check if time is given ex: 1s | ||
// if given value is just s (add prefix 1) | ||
firstChar := timeValue[0] | ||
if !unicode.IsDigit(rune(firstChar)) { | ||
timeValue = "1" + timeValue | ||
} | ||
} | ||
|
||
duration, err := timeutil.ParseDuration(timeValue) | ||
if err != nil { | ||
return RateLimit{}, errors.New("parse error: " + err.Error()) | ||
} | ||
|
||
if MaxRateLimitTime < duration { | ||
return RateLimit{}, fmt.Errorf("duration cannot be more than %v but got %v", MaxRateLimitTime, duration) | ||
} | ||
|
||
return RateLimit{MaxCount: uint(maxCount), Duration: duration}, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package goflags | ||
|
||
import ( | ||
"os" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestRateLimitMapVar(t *testing.T) { | ||
|
||
t.Run("default-value", func(t *testing.T) { | ||
var rateLimitMap RateLimitMap | ||
flagSet := NewFlagSet() | ||
flagSet.CreateGroup("Config", "Config", | ||
flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", []string{"hackertarget=1/ms"}, "rate limits", CommaSeparatedStringSliceOptions), | ||
) | ||
os.Args = []string{ | ||
os.Args[0], | ||
} | ||
err := flagSet.Parse() | ||
assert.Nil(t, err) | ||
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["hackertarget"]) | ||
tearDown(t.Name()) | ||
}) | ||
|
||
t.Run("multiple-default-value", func(t *testing.T) { | ||
var rateLimitMap RateLimitMap | ||
flagSet := NewFlagSet() | ||
flagSet.CreateGroup("Config", "Config", | ||
flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", []string{"hackertarget=1/s,github=1/ms"}, "rate limits", CommaSeparatedStringSliceOptions), | ||
) | ||
os.Args = []string{ | ||
os.Args[0], | ||
} | ||
err := flagSet.Parse() | ||
assert.Nil(t, err) | ||
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Second}, rateLimitMap.AsMap()["hackertarget"]) | ||
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["github"]) | ||
tearDown(t.Name()) | ||
}) | ||
|
||
t.Run("valid-rate-limit", func(t *testing.T) { | ||
var rateLimitMap RateLimitMap | ||
flagSet := NewFlagSet() | ||
flagSet.CreateGroup("Config", "Config", | ||
flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits", CommaSeparatedStringSliceOptions), | ||
) | ||
os.Args = []string{ | ||
os.Args[0], | ||
"-rls", "hackertarget=10/m", | ||
} | ||
err := flagSet.Parse() | ||
assert.Nil(t, err) | ||
assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Minute}, rateLimitMap.AsMap()["hackertarget"]) | ||
|
||
tearDown(t.Name()) | ||
}) | ||
|
||
t.Run("valid-rate-limits", func(t *testing.T) { | ||
var rateLimitMap RateLimitMap | ||
flagSet := NewFlagSet() | ||
flagSet.CreateGroup("Config", "Config", | ||
flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits", CommaSeparatedStringSliceOptions), | ||
) | ||
os.Args = []string{ | ||
os.Args[0], | ||
"-rls", "hackertarget=1/s,github=1/ms", | ||
} | ||
err := flagSet.Parse() | ||
assert.Nil(t, err) | ||
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Second}, rateLimitMap.AsMap()["hackertarget"]) | ||
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["github"]) | ||
tearDown(t.Name()) | ||
}) | ||
|
||
t.Run("without-unit", func(t *testing.T) { | ||
var rateLimitMap RateLimitMap | ||
err := rateLimitMap.Set("hackertarget=1") | ||
assert.NotNil(t, err) | ||
assert.ErrorContains(t, err, "parse error") | ||
tearDown(t.Name()) | ||
}) | ||
|
||
t.Run("invalid-unit", func(t *testing.T) { | ||
var rateLimitMap RateLimitMap | ||
err := rateLimitMap.Set("hackertarget=1/x") | ||
assert.NotNil(t, err) | ||
assert.ErrorContains(t, err, "parse error") | ||
tearDown(t.Name()) | ||
}) | ||
} |