From fa23778d4266d1e661e3cb33b546f53806f5d5a8 Mon Sep 17 00:00:00 2001 From: Umputun Date: Mon, 27 Nov 2023 12:05:17 -0600 Subject: [PATCH] make LBSelector interface and implement all the current methods plus roundrobin --- README.md | 2 +- app/main.go | 13 ++-- app/proxy/lb_selector.go | 45 +++++++++++++ app/proxy/lb_selector_test.go | 121 ++++++++++++++++++++++++++++++++++ app/proxy/proxy.go | 14 ++-- app/proxy/proxy_test.go | 2 +- 6 files changed, 184 insertions(+), 13 deletions(-) create mode 100644 app/proxy/lb_selector.go create mode 100644 app/proxy/lb_selector_test.go diff --git a/README.md b/README.md index 8737c778..f7361e42 100644 --- a/README.md +++ b/README.md @@ -364,7 +364,7 @@ This is the list of all options supporting multiple elements: -x, --header= outgoing proxy headers to add [$HEADER] --drop-header= incoming headers to drop [$DROP_HEADERS] --basic-htpasswd= htpasswd file for basic auth [$BASIC_HTPASSWD] - --lb-type=[random|failover] load balancer type (default: random) [$LB_TYPE] + --lb-type=[random|failover|roundrobin] load balancer type (default: random) [$LB_TYPE] --signature enable reproxy signature headers [$SIGNATURE] --remote-lookup-headers enable remote lookup headers [$REMOTE_LOOKUP_HEADERS] --dbg debug mode [$DEBUG] diff --git a/app/main.go b/app/main.go index cb078f98..f69cdb11 100644 --- a/app/main.go +++ b/app/main.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "math" - "math/rand" "net/http" "net/rpc" "os" @@ -36,7 +35,7 @@ var opts struct { DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","` AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"` RemoteLookupHeaders bool `long:"remote-lookup-headers" env:"REMOTE_LOOKUP_HEADERS" description:"enable remote lookup headers"` - LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint + LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" choice:"roundrobin" default:"random"` // nolint SSL struct { Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` // nolint @@ -414,14 +413,16 @@ func makeSSLConfig() (config proxy.SSLConfig, err error) { return config, err } -func makeLBSelector() func(len int) int { +func makeLBSelector() proxy.LBSelector { switch opts.LBType { case "random": - return rand.Intn + return &proxy.RandomSelector{} case "failover": - return func(int) int { return 0 } // dead server won't be in the list, we can safely pick the first one + return &proxy.FailoverSelector{} + case "roundrobin": + return &proxy.RoundRobinSelector{} default: - return func(int) int { return 0 } + return &proxy.FailoverSelector{} } } diff --git a/app/proxy/lb_selector.go b/app/proxy/lb_selector.go new file mode 100644 index 00000000..a20bf178 --- /dev/null +++ b/app/proxy/lb_selector.go @@ -0,0 +1,45 @@ +package proxy + +import ( + "math/rand" + "sync" +) + +// RoundRobinSelector is a simple round-robin selector, thread-safe +type RoundRobinSelector struct { + lastSelected int + mu sync.Mutex +} + +// Select returns next backend index +func (r *RoundRobinSelector) Select(n int) int { + r.mu.Lock() + defer r.mu.Unlock() + selected := r.lastSelected + r.lastSelected = (r.lastSelected + 1) % n + return selected +} + +// RandomSelector is a random selector, thread-safe +type RandomSelector struct{} + +// Select returns random backend index +func (r *RandomSelector) Select(n int) int { + return rand.Intn(n) //nolint:gosec // no need for crypto/rand here +} + +// FailoverSelector is a selector with failover, thread-safe +type FailoverSelector struct{} + +// Select returns next backend index +func (r *FailoverSelector) Select(_ int) int { + return 0 // dead server won't be in the list, we can safely pick the first one +} + +// LBSelectorFunc is a functional adapted for LBSelector to select backend from the list +type LBSelectorFunc func(n int) int + +// Select returns backend index +func (f LBSelectorFunc) Select(n int) int { + return f(n) +} diff --git a/app/proxy/lb_selector_test.go b/app/proxy/lb_selector_test.go new file mode 100644 index 00000000..c5a78e19 --- /dev/null +++ b/app/proxy/lb_selector_test.go @@ -0,0 +1,121 @@ +package proxy + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRoundRobinSelector_Select(t *testing.T) { + selector := &RoundRobinSelector{} + + testCases := []struct { + name string + len int + expected int + }{ + {"First call", 3, 0}, + {"Second call", 3, 1}, + {"Third call", 3, 2}, + {"Back to zero", 3, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := selector.Select(tc.len) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestRoundRobinSelector_SelectConcurrent(t *testing.T) { + selector := &RoundRobinSelector{} + l := 3 + numGoroutines := 1000 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + results := &sync.Map{} + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + result := selector.Select(l) + results.Store(result, struct{}{}) + }() + } + + wg.Wait() + + // check that all possible results are present in the map. + for i := 0; i < l; i++ { + _, ok := results.Load(i) + assert.True(t, ok, "expected to find %d in the results", i) + } +} + +func TestRandomSelector_Select(t *testing.T) { + selector := &RandomSelector{} + + testCases := []struct { + name string + len int + }{ + {"First call", 5}, + {"Second call", 5}, + {"Third call", 5}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := selector.Select(tc.len) + assert.True(t, result >= 0 && result < tc.len) + }) + } +} + +func TestFailoverSelector_Select(t *testing.T) { + selector := &FailoverSelector{} + + testCases := []struct { + name string + len int + expected int + }{ + {"First call", 5, 0}, + {"Second call", 5, 0}, + {"Third call", 5, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := selector.Select(tc.len) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestLBSelectorFunc_Select(t *testing.T) { + selector := LBSelectorFunc(func(n int) int { + return n - 1 // simple selection logic for testing + }) + + testCases := []struct { + name string + len int + expected int + }{ + {"First call", 5, 4}, + {"Second call", 3, 2}, + {"Third call", 1, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := selector.Select(tc.len) + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index b3af185b..3d247409 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "io" - "math/rand" "net" "net/http" "net/http/httputil" @@ -47,7 +46,7 @@ type Http struct { // nolint golint Metrics MiddlewareProvider PluginConductor MiddlewareProvider Reporter Reporter - LBSelector func(len int) int + LBSelector LBSelector OnlyFrom *OnlyFrom BasicAuthEnabled bool BasicAuthAllowed []string @@ -75,6 +74,11 @@ type Reporter interface { Report(w http.ResponseWriter, code int) } +// LBSelector defines load balancer strategy +type LBSelector interface { + Select(len int) int // return index of picked server +} + // Timeouts consolidate timeouts for both server and transport type Timeouts struct { // server timeouts @@ -101,7 +105,7 @@ func (h *Http) Run(ctx context.Context) error { } if h.LBSelector == nil { - h.LBSelector = rand.Intn + h.LBSelector = &RandomSelector{} } var httpServer, httpsServer *http.Server @@ -277,7 +281,7 @@ func (h *Http) proxyHandler() http.HandlerFunc { // and if match found sets it to the request context. Context used by proxy handler as well as by plugin conductor func (h *Http) matchHandler(next http.Handler) http.Handler { - getMatch := func(mm discovery.Matches, picker func(len int) int) (m discovery.MatchedRoute, ok bool) { + getMatch := func(mm discovery.Matches, picker LBSelector) (m discovery.MatchedRoute, ok bool) { if len(mm.Routes) == 0 { return m, false } @@ -294,7 +298,7 @@ func (h *Http) matchHandler(next http.Handler) http.Handler { case 1: return matches[0], true default: - return matches[picker(len(matches))], true + return matches[picker.Select(len(matches))], true } } diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index 4d916b18..25144c53 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -874,7 +874,7 @@ func TestHttp_matchHandler(t *testing.T) { client := http.Client{} for _, tt := range tbl { t.Run(tt.name, func(t *testing.T) { - h := Http{Matcher: matcherMock, LBSelector: func(len int) int { return 0 }} + h := Http{Matcher: matcherMock, LBSelector: &FailoverSelector{}} handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Logf("req: %+v", r) t.Logf("dst: %v", r.Context().Value(ctxURL))