Skip to content

Commit

Permalink
Skip header value matching if entry is empty slice (#104)
Browse files Browse the repository at this point in the history
* Skip header value matching if entry is empty slice

* Test function name tweak

* Fix race cond for test by atomic int & sync map

* Fix test build error

Repleace atomic.Int64 and sync.Map with a naive map with mutex lock.
Doing so because the build tool is at go version 1.14,
while atomic.Int64 is introduced at go 1.19 I think.

* Shorten test function name to pass ci lint

* Simplify test case to pass cyclomatic complexity constraint
  • Loading branch information
Xinyu-bot authored Sep 21, 2022
1 parent 376acbc commit 0b61d55
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tollbooth.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ func ShouldSkipLimiter(lmt *limiter.Limiter, r *http.Request) bool {
requestHeadersDefinedInLimiter = false

for headerKey, headerValues := range lmtHeaders {
if len(headerValues) == 0 {
requestHeadersDefinedInLimiter = true
continue
}
for _, headerValue := range headerValues {
if r.Header.Get(headerKey) == headerValue {
requestHeadersDefinedInLimiter = true
Expand Down
118 changes: 118 additions & 0 deletions tollbooth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -487,3 +488,120 @@ func isInSlice(key string, keys []string) bool {
}
return false
}

type LockMap struct {
m map[string]int64
sync.Mutex
}

func (lm *LockMap) Set(key string, value int64) {
lm.Lock()
lm.m[key] = value
lm.Unlock()
}

func (lm *LockMap) Get(key string) (int64, bool) {
lm.Lock()
value, ok := lm.m[key]
lm.Unlock()
return value, ok
}

func (lm *LockMap) Add(key string, incr int64) {
lm.Lock()
if val, ok := lm.m[key]; ok {
lm.m[key] = val + incr
} else {
lm.m[key] = incr
}
lm.Unlock()
}

func TestLimitHandlerEmptyHeader(t *testing.T) {
lmt := limiter.New(nil).SetMax(1).SetBurst(1)
lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"})
lmt.SetMethods([]string{"POST"})
lmt.SetHeader("user_id", []string{})

counterMap := &LockMap{m: map[string]int64{}}
lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) {
_, _ = w, r
counterMap.Add(r.Header.Get("user_id"), 1)
})

handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r
w.Write([]byte(`hello world`))
}))

req, err := http.NewRequest("POST", "/doesntmatter", nil)
if err != nil {
t.Fatal(err)
}

req.Header.Set("X-Real-IP", "2601:7:1c82:4097:59a0:a80b:2841:b8c8")
req.Header.Set("user_id", "0")

rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
{ // Should not be limited
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
}
// check RateLimit headers
if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Limit")]; len(value) < 1 || value[0] != "1" {
t.Errorf("handler returned wrong value: got %s want %s", value, "1")
}
if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Reset")]; len(value) < 1 || value[0] != "1" {
t.Errorf("handler returned wrong value: got %s want %s", value, "1")
}
if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Remaining")]; len(value) < 1 || value[0] != "0" {
t.Errorf("handler returned wrong value: got %s want %s", value, "0")
}
}

wg := sync.WaitGroup{}
wg.Add(1)

// same user_id, should be limited
go func() {
defer wg.Done()

req1, _ := http.NewRequest("POST", "/doesntmatter", nil)
req1.Header.Set("X-Real-IP", "2601:7:1c82:4097:59a0:a80b:2841:b8c8")
req1.Header.Set("user_id", "0")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req1)
// Should be limited
{
if status := rr.Code; status != http.StatusTooManyRequests {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusTooManyRequests)
}
// check X-Rate-Limit headers
if value := rr.Result().Header[http.CanonicalHeaderKey("X-Rate-Limit-Limit")]; len(value) < 1 || value[0] != "1.00" {
t.Errorf("X-Rate-Limit-Limit has wrong value: got %s want %v", value, "1.00")
}
if value := rr.Result().Header[http.CanonicalHeaderKey("X-Rate-Limit-Duration")]; len(value) < 1 || value[0] != "1" {
t.Errorf("X-Rate-Limit-Duration has wrong value: got %s want %v", value, "1")
}
// check RateLimit headers
if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Limit")]; len(value) < 1 || value[0] != "1" {
t.Errorf("RateLimit-Limit has wrong value: got %s want %v", value, "1")
}
if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Reset")]; len(value) < 1 || value[0] != "1" {
t.Errorf("RateLimit-Reset has wrong value: got %s want %v", value, "1")
}
if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Remaining")]; len(value) < 1 || value[0] != "0" {
t.Errorf("RateLimit-Remaining has wrong value: got %s want %v", value, "0")
}
// OnLimitReached should be called
if aint, ok := counterMap.Get(req1.Header.Get("user_id")); ok {
if aint == 0 {
t.Errorf("onLimitReached was not called")
}
}
}
}()

wg.Wait() // Block until go func is done.
}

0 comments on commit 0b61d55

Please sign in to comment.