Skip to content

Commit

Permalink
make LBSelector interface and implement all the current methods plus …
Browse files Browse the repository at this point in the history
…roundrobin
  • Loading branch information
umputun committed Nov 27, 2023
1 parent 8bde167 commit fa23778
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ This is the list of all options supporting multiple elements:
-x, --header= outgoing proxy headers to add [$HEADER]
--drop-header= incoming headers to drop [$DROP_HEADERS]
--basic-htpasswd= htpasswd file for basic auth [$BASIC_HTPASSWD]
--lb-type=[random|failover] load balancer type (default: random) [$LB_TYPE]
--lb-type=[random|failover|roundrobin] load balancer type (default: random) [$LB_TYPE]
--signature enable reproxy signature headers [$SIGNATURE]
--remote-lookup-headers enable remote lookup headers [$REMOTE_LOOKUP_HEADERS]
--dbg debug mode [$DEBUG]
Expand Down
13 changes: 7 additions & 6 deletions app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"io"
"math"
"math/rand"
"net/http"
"net/rpc"
"os"
Expand Down Expand Up @@ -36,7 +35,7 @@ var opts struct {
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"`
RemoteLookupHeaders bool `long:"remote-lookup-headers" env:"REMOTE_LOOKUP_HEADERS" description:"enable remote lookup headers"`
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" choice:"roundrobin" default:"random"` // nolint

SSL struct {
Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` // nolint
Expand Down Expand Up @@ -414,14 +413,16 @@ func makeSSLConfig() (config proxy.SSLConfig, err error) {
return config, err
}

func makeLBSelector() func(len int) int {
func makeLBSelector() proxy.LBSelector {
switch opts.LBType {
case "random":
return rand.Intn
return &proxy.RandomSelector{}
case "failover":
return func(int) int { return 0 } // dead server won't be in the list, we can safely pick the first one
return &proxy.FailoverSelector{}
case "roundrobin":
return &proxy.RoundRobinSelector{}
default:
return func(int) int { return 0 }
return &proxy.FailoverSelector{}
}
}

Expand Down
45 changes: 45 additions & 0 deletions app/proxy/lb_selector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package proxy

import (
"math/rand"
"sync"
)

// RoundRobinSelector is a simple round-robin selector, thread-safe
type RoundRobinSelector struct {
lastSelected int
mu sync.Mutex
}

// Select returns next backend index
func (r *RoundRobinSelector) Select(n int) int {
r.mu.Lock()
defer r.mu.Unlock()
selected := r.lastSelected
r.lastSelected = (r.lastSelected + 1) % n
return selected
}

// RandomSelector is a random selector, thread-safe
type RandomSelector struct{}

// Select returns random backend index
func (r *RandomSelector) Select(n int) int {
return rand.Intn(n) //nolint:gosec // no need for crypto/rand here
}

// FailoverSelector is a selector with failover, thread-safe
type FailoverSelector struct{}

// Select returns next backend index
func (r *FailoverSelector) Select(_ int) int {
return 0 // dead server won't be in the list, we can safely pick the first one
}

// LBSelectorFunc is a functional adapted for LBSelector to select backend from the list
type LBSelectorFunc func(n int) int

// Select returns backend index
func (f LBSelectorFunc) Select(n int) int {
return f(n)
}
121 changes: 121 additions & 0 deletions app/proxy/lb_selector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package proxy

import (
"sync"
"testing"

"github.com/stretchr/testify/assert"
)

func TestRoundRobinSelector_Select(t *testing.T) {
selector := &RoundRobinSelector{}

testCases := []struct {
name string
len int
expected int
}{
{"First call", 3, 0},
{"Second call", 3, 1},
{"Third call", 3, 2},
{"Back to zero", 3, 0},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := selector.Select(tc.len)
assert.Equal(t, tc.expected, result)
})
}
}

func TestRoundRobinSelector_SelectConcurrent(t *testing.T) {
selector := &RoundRobinSelector{}
l := 3
numGoroutines := 1000

var wg sync.WaitGroup
wg.Add(numGoroutines)

results := &sync.Map{}

for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
result := selector.Select(l)
results.Store(result, struct{}{})
}()
}

wg.Wait()

// check that all possible results are present in the map.
for i := 0; i < l; i++ {
_, ok := results.Load(i)
assert.True(t, ok, "expected to find %d in the results", i)
}
}

func TestRandomSelector_Select(t *testing.T) {
selector := &RandomSelector{}

testCases := []struct {
name string
len int
}{
{"First call", 5},
{"Second call", 5},
{"Third call", 5},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := selector.Select(tc.len)
assert.True(t, result >= 0 && result < tc.len)
})
}
}

func TestFailoverSelector_Select(t *testing.T) {
selector := &FailoverSelector{}

testCases := []struct {
name string
len int
expected int
}{
{"First call", 5, 0},
{"Second call", 5, 0},
{"Third call", 5, 0},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := selector.Select(tc.len)
assert.Equal(t, tc.expected, result)
})
}
}

func TestLBSelectorFunc_Select(t *testing.T) {
selector := LBSelectorFunc(func(n int) int {
return n - 1 // simple selection logic for testing
})

testCases := []struct {
name string
len int
expected int
}{
{"First call", 5, 4},
{"Second call", 3, 2},
{"Third call", 1, 0},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := selector.Select(tc.len)
assert.Equal(t, tc.expected, result)
})
}
}
14 changes: 9 additions & 5 deletions app/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
Expand Down Expand Up @@ -47,7 +46,7 @@ type Http struct { // nolint golint
Metrics MiddlewareProvider
PluginConductor MiddlewareProvider
Reporter Reporter
LBSelector func(len int) int
LBSelector LBSelector
OnlyFrom *OnlyFrom
BasicAuthEnabled bool
BasicAuthAllowed []string
Expand Down Expand Up @@ -75,6 +74,11 @@ type Reporter interface {
Report(w http.ResponseWriter, code int)
}

// LBSelector defines load balancer strategy
type LBSelector interface {
Select(len int) int // return index of picked server
}

// Timeouts consolidate timeouts for both server and transport
type Timeouts struct {
// server timeouts
Expand All @@ -101,7 +105,7 @@ func (h *Http) Run(ctx context.Context) error {
}

if h.LBSelector == nil {
h.LBSelector = rand.Intn
h.LBSelector = &RandomSelector{}
}

var httpServer, httpsServer *http.Server
Expand Down Expand Up @@ -277,7 +281,7 @@ func (h *Http) proxyHandler() http.HandlerFunc {
// and if match found sets it to the request context. Context used by proxy handler as well as by plugin conductor
func (h *Http) matchHandler(next http.Handler) http.Handler {

getMatch := func(mm discovery.Matches, picker func(len int) int) (m discovery.MatchedRoute, ok bool) {
getMatch := func(mm discovery.Matches, picker LBSelector) (m discovery.MatchedRoute, ok bool) {
if len(mm.Routes) == 0 {
return m, false
}
Expand All @@ -294,7 +298,7 @@ func (h *Http) matchHandler(next http.Handler) http.Handler {
case 1:
return matches[0], true
default:
return matches[picker(len(matches))], true
return matches[picker.Select(len(matches))], true
}
}

Expand Down
2 changes: 1 addition & 1 deletion app/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ func TestHttp_matchHandler(t *testing.T) {
client := http.Client{}
for _, tt := range tbl {
t.Run(tt.name, func(t *testing.T) {
h := Http{Matcher: matcherMock, LBSelector: func(len int) int { return 0 }}
h := Http{Matcher: matcherMock, LBSelector: &FailoverSelector{}}
handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("req: %+v", r)
t.Logf("dst: %v", r.Context().Value(ctxURL))
Expand Down

0 comments on commit fa23778

Please sign in to comment.