Skip to content

Commit

Permalink
add ratelimit flag (#123)
Browse files Browse the repository at this point in the history
* add ratelimit flag

* use StringSlice

* limit unit to m

* add example + Max threshold

* fix build test

* fmt

---------

Co-authored-by: Tarun Koyalwar <[email protected]>
Co-authored-by: mzack <[email protected]>
  • Loading branch information
3 people authored Jul 24, 2023
1 parent 1ab6157 commit dbf9da1
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/basic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type Options struct {
Address goflags.StringSlice
fileSize goflags.Size
duration time.Duration
rls goflags.RateLimitMap
}

func main() {
Expand All @@ -28,6 +29,7 @@ func main() {
flagSet.CreateGroup("info", "Info",
flagSet.StringVarP(&testOptions.name, "name", "n", "", "name of the user"),
flagSet.StringSliceVarP(&testOptions.Email, "email", "e", nil, "email of the user", goflags.CommaSeparatedStringSliceOptions),
flagSet.RateLimitMapVarP(&testOptions.rls, "rate-limits", "rls", nil, "rate limits in format k=v/d i.e hackertarget=10/s", goflags.CommaSeparatedStringSliceOptions),
)
flagSet.CreateGroup("additional", "Additional",
flagSet.StringVarP(&testOptions.Phone, "phone", "ph", "", "phone of the user"),
Expand All @@ -40,4 +42,7 @@ func main() {
if err := flagSet.Parse(); err != nil {
log.Fatal(err)
}

// ratelimits value is
fmt.Printf("Got RateLimits: %+v\n", testOptions.rls)
}
170 changes: 170 additions & 0 deletions ratelimit_var.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package goflags

import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"unicode"

stringsutil "github.com/projectdiscovery/utils/strings"
timeutil "github.com/projectdiscovery/utils/time"
)

var (
MaxRateLimitTime = time.Minute // anything above time.Minute is not practical (for our use case)
rateLimitOptionMap map[*RateLimitMap]Options
)

func init() {
rateLimitOptionMap = make(map[*RateLimitMap]Options)
}

type RateLimit struct {
MaxCount uint
Duration time.Duration
}

type RateLimitMap struct {
kv map[string]RateLimit
}

// Set inserts a value to the map. Format: key=value
func (rateLimitMap *RateLimitMap) Set(value string) error {
if rateLimitMap.kv == nil {
rateLimitMap.kv = make(map[string]RateLimit)
}

option, ok := rateLimitOptionMap[rateLimitMap]
if !ok {
option = StringSliceOptions
}
rateLimits, err := ToStringSlice(value, option)
if err != nil {
return err
}

for _, rateLimit := range rateLimits {
var k, v string
if idxSep := strings.Index(rateLimit, kvSep); idxSep > 0 {
k = rateLimit[:idxSep]
v = rateLimit[idxSep+1:]
}

// note:
// - inserting multiple times the same key will override the previous v
// - empty string is legitimate rateLimit
if k != "" {
rateLimit, err := parseRateLimit(v)
if err != nil {
return err
}
rateLimitMap.kv[k] = rateLimit
}
}
return nil
}

// Del removes the specified key
func (rateLimitMap *RateLimitMap) Del(key string) error {
if rateLimitMap.kv == nil {
return errors.New("empty runtime map")
}
delete(rateLimitMap.kv, key)
return nil
}

// IsEmpty specifies if the underlying map is empty
func (rateLimitMap *RateLimitMap) IsEmpty() bool {
return rateLimitMap.kv == nil || len(rateLimitMap.kv) == 0
}

// AsMap returns the internal map as reference - changes are allowed
func (rateLimitMap *RateLimitMap) AsMap() map[string]RateLimit {
return rateLimitMap.kv
}

func (rateLimitMap RateLimitMap) String() string {
defaultBuilder := &strings.Builder{}
defaultBuilder.WriteString("{")

var items string
for k, v := range rateLimitMap.kv {
items += fmt.Sprintf("\"%s\":\"%d/%s\",", k, v.MaxCount, v.Duration.String())
}
defaultBuilder.WriteString(stringsutil.TrimSuffixAny(items, ",", ":"))
defaultBuilder.WriteString("}")
return defaultBuilder.String()
}

// RateLimitMapVar adds a ratelimit flag with a longname
func (flagSet *FlagSet) RateLimitMapVar(field *RateLimitMap, long string, defaultValue []string, usage string, options Options) *FlagData {
return flagSet.RateLimitMapVarP(field, long, "", defaultValue, usage, options)
}

// RateLimitMapVarP adds a ratelimit flag with a short name and long name.
// It is equivalent to RateLimitMapVar, and also allows specifying ratelimits in days (e.g., "hackertarget=2/d" 2 requests per day, which is equivalent to 24h).
func (flagSet *FlagSet) RateLimitMapVarP(field *RateLimitMap, long, short string, defaultValue StringSlice, usage string, options Options) *FlagData {
if field == nil {
panic(fmt.Errorf("field cannot be nil for flag -%v", long))
}

rateLimitOptionMap[field] = options
for _, defaultItem := range defaultValue {
values, _ := ToStringSlice(defaultItem, options)
for _, value := range values {
if err := field.Set(value); err != nil {
panic(fmt.Errorf("failed to set default value for flag -%v: %v", long, err))
}
}
}

flagData := &FlagData{
usage: usage,
long: long,
defaultValue: defaultValue,
skipMarshal: true,
}
if short != "" {
flagData.short = short
flagSet.CommandLine.Var(field, short, usage)
flagSet.flagKeys.Set(short, flagData)
}
flagSet.CommandLine.Var(field, long, usage)
flagSet.flagKeys.Set(long, flagData)
return flagData
}

func parseRateLimit(s string) (RateLimit, error) {
sArr := strings.Split(s, "/")

if len(sArr) < 2 {
return RateLimit{}, errors.New("parse error: expected format k=v/d (e.g., scanme.sh=10/s got " + s)
}

maxCount, err := strconv.ParseUint(sArr[0], 10, 64)
if err != nil {
return RateLimit{}, errors.New("parse error: " + err.Error())
}
timeValue := sArr[1]
if len(timeValue) > 0 {
// check if time is given ex: 1s
// if given value is just s (add prefix 1)
firstChar := timeValue[0]
if !unicode.IsDigit(rune(firstChar)) {
timeValue = "1" + timeValue
}
}

duration, err := timeutil.ParseDuration(timeValue)
if err != nil {
return RateLimit{}, errors.New("parse error: " + err.Error())
}

if MaxRateLimitTime < duration {
return RateLimit{}, fmt.Errorf("duration cannot be more than %v but got %v", MaxRateLimitTime, duration)
}

return RateLimit{MaxCount: uint(maxCount), Duration: duration}, nil
}
93 changes: 93 additions & 0 deletions ratelimit_var_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package goflags

import (
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestRateLimitMapVar(t *testing.T) {

t.Run("default-value", func(t *testing.T) {
var rateLimitMap RateLimitMap
flagSet := NewFlagSet()
flagSet.CreateGroup("Config", "Config",
flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", []string{"hackertarget=1/ms"}, "rate limits", CommaSeparatedStringSliceOptions),
)
os.Args = []string{
os.Args[0],
}
err := flagSet.Parse()
assert.Nil(t, err)
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["hackertarget"])
tearDown(t.Name())
})

t.Run("multiple-default-value", func(t *testing.T) {
var rateLimitMap RateLimitMap
flagSet := NewFlagSet()
flagSet.CreateGroup("Config", "Config",
flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", []string{"hackertarget=1/s,github=1/ms"}, "rate limits", CommaSeparatedStringSliceOptions),
)
os.Args = []string{
os.Args[0],
}
err := flagSet.Parse()
assert.Nil(t, err)
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Second}, rateLimitMap.AsMap()["hackertarget"])
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["github"])
tearDown(t.Name())
})

t.Run("valid-rate-limit", func(t *testing.T) {
var rateLimitMap RateLimitMap
flagSet := NewFlagSet()
flagSet.CreateGroup("Config", "Config",
flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits", CommaSeparatedStringSliceOptions),
)
os.Args = []string{
os.Args[0],
"-rls", "hackertarget=10/m",
}
err := flagSet.Parse()
assert.Nil(t, err)
assert.Equal(t, RateLimit{MaxCount: 10, Duration: time.Minute}, rateLimitMap.AsMap()["hackertarget"])

tearDown(t.Name())
})

t.Run("valid-rate-limits", func(t *testing.T) {
var rateLimitMap RateLimitMap
flagSet := NewFlagSet()
flagSet.CreateGroup("Config", "Config",
flagSet.RateLimitMapVarP(&rateLimitMap, "rate-limits", "rls", nil, "rate limits", CommaSeparatedStringSliceOptions),
)
os.Args = []string{
os.Args[0],
"-rls", "hackertarget=1/s,github=1/ms",
}
err := flagSet.Parse()
assert.Nil(t, err)
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Second}, rateLimitMap.AsMap()["hackertarget"])
assert.Equal(t, RateLimit{MaxCount: 1, Duration: time.Millisecond}, rateLimitMap.AsMap()["github"])
tearDown(t.Name())
})

t.Run("without-unit", func(t *testing.T) {
var rateLimitMap RateLimitMap
err := rateLimitMap.Set("hackertarget=1")
assert.NotNil(t, err)
assert.ErrorContains(t, err, "parse error")
tearDown(t.Name())
})

t.Run("invalid-unit", func(t *testing.T) {
var rateLimitMap RateLimitMap
err := rateLimitMap.Set("hackertarget=1/x")
assert.NotNil(t, err)
assert.ErrorContains(t, err, "parse error")
tearDown(t.Name())
})
}

0 comments on commit dbf9da1

Please sign in to comment.