server: upstream connection rebalancing
lhpqaq committed Jul 27, 2024
1 parent a587abc commit 4091576
Showing 5 changed files with 206 additions and 0 deletions.
16 changes: 16 additions & 0 deletions server/cluster/state.go
@@ -391,3 +391,19 @@ func (s *State) addMetricsNode(status NodeStatus) {
func (s *State) removeMetricsNode(status NodeStatus) {
    s.metrics.Nodes.With(prometheus.Labels{"status": string(status)}).Dec()
}

// TotalAndLocalUpstreams returns the total number of upstream connections
// registered across the cluster and the number held by the local node.
func (s *State) TotalAndLocalUpstreams() (int, int) {
    totalUpstreams := 0
    localUpstreams := 0
    for _, n := range s.NodesMetadata() {
        if n.ID == s.localID {
            localUpstreams = n.Upstreams
        }
        totalUpstreams += n.Upstreams
    }
    return totalUpstreams, localUpstreams
}

// NodesNum returns the number of known nodes in the cluster.
func (s *State) NodesNum() int {
    return len(s.nodes)
}
33 changes: 33 additions & 0 deletions server/config/config.go
@@ -299,6 +299,12 @@ type ClusterConfig struct {
    AbortIfJoinFails bool `json:"abort_if_join_fails" yaml:"abort_if_join_fails"`

    Gossip gossip.Config `json:"gossip" yaml:"gossip"`

    // RebalancingThreshold is the fraction above the cluster-average upstream
    // connection count at which the node starts shedding connections.
    RebalancingThreshold float32 `json:"rebalancing_threshold" yaml:"rebalancing_threshold"`

    // RebalancingRate is the fraction of the node's upstream connections to
    // shed per second while rebalancing.
    RebalancingRate float32 `json:"rebalancing_rate" yaml:"rebalancing_rate"`

    // RebalancingCheckInterval is how often the node checks whether
    // rebalancing is needed.
    RebalancingCheckInterval time.Duration `json:"rebalancing_check_interval" yaml:"rebalancing_check_interval"`
}

func (c *ClusterConfig) Validate() error {
@@ -377,6 +383,30 @@ set.`,
Whether the server node should abort if it is configured with more than one
node to join (excluding itself) but fails to join any members.`,
    )

    fs.Float32Var(
        &c.RebalancingThreshold,
        "cluster.rebalancing-threshold",
        c.RebalancingThreshold,
        `
Threshold for upstream connection rebalancing. Rebalancing starts when the
local node holds 'threshold' (as a fraction) more connections than the
cluster average.`,
    )

    fs.Float32Var(
        &c.RebalancingRate,
        "cluster.rebalancing-rate",
        c.RebalancingRate,
        `
Rate of upstream connection rebalancing, shedding 'rate' (as a fraction) of
the local connections every second.`,
    )

    fs.DurationVar(
        &c.RebalancingCheckInterval,
        "cluster.rebalancing-check-interval",
        c.RebalancingCheckInterval,
        `
Interval at which the node checks whether rebalancing is needed.`,
    )

    c.Gossip.RegisterFlags(fs, "cluster")
}
@@ -447,6 +477,9 @@ func Default() *Config {
                Interval:      time.Millisecond * 100,
                MaxPacketSize: 1400,
            },
            RebalancingThreshold:     0.2,
            RebalancingRate:          0.005,
            RebalancingCheckInterval: time.Second * 5,
        },
        Log: log.Config{
            Level: "info",
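The three options map directly onto the ClusterConfig fields above (yaml keys rebalancing_threshold, rebalancing_rate and rebalancing_check_interval; flags --cluster.rebalancing-threshold, --cluster.rebalancing-rate and --cluster.rebalancing-check-interval), with the defaults set in Default() above. As a minimal sketch, not part of this commit and assuming the package import path github.com/andydunstall/piko/server/config, they can also be overridden programmatically:

package main

import (
    "fmt"
    "time"

    "github.com/andydunstall/piko/server/config"
)

func main() {
    conf := config.Default()

    // Start shedding once this node holds 30% more upstream connections
    // than the cluster average, shed 1% of them per second, and re-check
    // every 10 seconds (replacing the defaults of 0.2, 0.005 and 5s).
    conf.Cluster.RebalancingThreshold = 0.3
    conf.Cluster.RebalancingRate = 0.01
    conf.Cluster.RebalancingCheckInterval = 10 * time.Second

    fmt.Printf("%+v\n", conf.Cluster)
}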
78 changes: 78 additions & 0 deletions server/server.go
@@ -6,6 +6,7 @@ import (
"net"
"strings"
"sync"
"time"

"github.com/hashicorp/go-sockaddr"
"github.com/prometheus/client_golang/prometheus"
@@ -56,6 +57,8 @@ type Server struct {

    registry *prometheus.Registry

    // RebalancingCancel stops the rebalancing daemon started by Start.
    RebalancingCancel context.CancelFunc

    logger log.Logger
}

@@ -262,9 +265,82 @@ func (s *Server) Start() error {
        }
    }

    s.RebalancingDaemon()

    return nil
}

// RebalancingDaemon starts a background goroutine that periodically checks
// whether this node holds more upstream connections than the configured
// threshold above the cluster average and, if so, sheds connections at the
// configured rate until it is back under the threshold.
func (s *Server) RebalancingDaemon() {
    needRebalancing := func() (int, bool) {
        totalConns, localConns := s.clusterState.TotalAndLocalUpstreams()
        averageConns := float32(totalConns) /
            float32(s.clusterState.NodesNum())
        if float32(localConns) > averageConns*
            (1+s.conf.Cluster.RebalancingThreshold) {
            return localConns, true
        }
        return 0, false
    }

    sheddingConnections := func(connsToShed int) {
        i := 0
        s.upstreamServer.ConnsMux.Lock()
        for conn := range s.upstreamServer.Conns {
            conn.Close()
            delete(s.upstreamServer.Conns, conn)
            i++
            if i >= connsToShed {
                break
            }
        }
        s.upstreamServer.ConnsMux.Unlock()
    }

    startShedding := func(connsToShed int, rebalancing *bool, mu *sync.Mutex) {
        mu.Lock()
        *rebalancing = true
        mu.Unlock()
        for {
            sheddingConnections(connsToShed)
            time.Sleep(time.Second)
            if _, needRebalance := needRebalancing(); !needRebalance {
                mu.Lock()
                *rebalancing = false
                mu.Unlock()
                return
            }
        }
    }

    // Create the context before launching the goroutine so Shutdown never
    // observes a nil RebalancingCancel.
    ctx, cancel := context.WithCancel(context.Background())
    s.RebalancingCancel = cancel

    s.runGoroutine(func() {
        ticker := time.NewTicker(s.conf.Cluster.RebalancingCheckInterval)
        defer ticker.Stop()
        rebalancing := false
        var mu sync.Mutex
        for {
            select {
            case <-ctx.Done():
                s.logger.Debug("rebalancing daemon stopped")
                return
            case <-ticker.C:
                mu.Lock()
                if rebalancing {
                    mu.Unlock()
                    continue
                }
                mu.Unlock()

                if localConn, needRebalance := needRebalancing(); needRebalance {
                    // Shed at least one connection per second; a small rate
                    // times a small connection count would otherwise
                    // truncate to zero.
                    connsToShed := int(max(1, float32(localConn)*s.conf.Cluster.RebalancingRate))
                    go startShedding(connsToShed, &rebalancing, &mu)
                }
            }
        }
    })
}

// Shutdown gracefully stops the server node.
func (s *Server) Shutdown() {
    if !s.shutdown.CompareAndSwap(false, true) {
@@ -309,6 +385,8 @@ func (s *Server) Shutdown() {

    s.shutdownUsageReporting()

    // Stop the rebalancing daemon (if it was started).
    if s.RebalancingCancel != nil {
        s.RebalancingCancel()
    }

    s.wg.Wait()

    s.logger.Info("shutdown complete")
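To make the daemon's arithmetic concrete, here is a small self-contained sketch (not from the commit; the node and connection counts are invented for illustration) of the same check and shed-rate calculation that needRebalancing and startShedding perform with the default threshold of 0.2 and rate of 0.005:

package main

import "fmt"

func main() {
    const (
        rebalancingThreshold = 0.2   // default cluster.rebalancing-threshold
        rebalancingRate      = 0.005 // default cluster.rebalancing-rate
    )

    // Example cluster: 5 nodes, 1000 upstream connections in total, with
    // this node holding 340 of them.
    totalConns := 1000
    nodes := 5
    localConns := 340

    average := float64(totalConns) / float64(nodes) // 200
    limit := average * (1 + rebalancingThreshold)   // 240

    if float64(localConns) > limit {
        // Shed at least one connection per second so a small
        // localConns*rate product never truncates to zero.
        connsToShed := int(float64(localConns) * rebalancingRate) // 1 (1.7 truncated)
        if connsToShed < 1 {
            connsToShed = 1
        }
        fmt.Printf("local=%d > limit=%.0f: shedding %d conns/sec\n",
            localConns, limit, connsToShed)
    }
}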
19 changes: 19 additions & 0 deletions server/upstream/server.go
@@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/http"
"sync"

"github.com/andydunstall/yamux"
"github.com/gin-gonic/gin"
@@ -32,6 +33,22 @@ type Server struct {
    cancel func()

    logger log.Logger

    // Conns tracks the open upstream connections accepted by the HTTP
    // server so the rebalancer can shed them. It is keyed by the connection
    // value so the entry added on StateNew is found again on
    // StateClosed/StateHijacked.
    Conns    map[net.Conn]bool
    ConnsMux sync.Mutex
}

func (s *Server) connStateChange(c net.Conn, state http.ConnState) {
    switch state {
    case http.StateNew:
        s.ConnsMux.Lock()
        s.Conns[c] = true
        s.ConnsMux.Unlock()
    case http.StateClosed, http.StateHijacked:
        s.ConnsMux.Lock()
        delete(s.Conns, c)
        s.ConnsMux.Unlock()
    }
}

func NewServer(
Expand All @@ -55,7 +72,9 @@ func NewServer(
        ctx:    ctx,
        cancel: cancel,
        logger: logger,
        Conns:  make(map[net.Conn]bool),
    }
    // Track upstream connection state so the rebalancer can shed connections.
    server.httpServer.ConnState = server.connStateChange

    // Recover from panics.
    router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute))
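The connection tracking above relies on net/http's Server.ConnState hook, which fires as each accepted connection changes state. A minimal standalone sketch of the same pattern (not from the commit, using httptest so it is runnable on its own):

package main

import (
    "fmt"
    "io"
    "net"
    "net/http"
    "net/http/httptest"
    "sync"
)

func main() {
    var (
        mu    sync.Mutex
        conns = map[net.Conn]bool{}
    )

    srv := httptest.NewUnstartedServer(http.HandlerFunc(
        func(w http.ResponseWriter, r *http.Request) {
            fmt.Fprintln(w, "ok")
        }))
    // Add connections on StateNew and drop them once they are closed or
    // hijacked, mirroring connStateChange above.
    srv.Config.ConnState = func(c net.Conn, state http.ConnState) {
        mu.Lock()
        defer mu.Unlock()
        switch state {
        case http.StateNew:
            conns[c] = true
        case http.StateClosed, http.StateHijacked:
            delete(conns, c)
        }
    }
    srv.Start()
    defer srv.Close()

    resp, err := http.Get(srv.URL)
    if err == nil {
        io.Copy(io.Discard, resp.Body)
        resp.Body.Close()
    }

    // The keep-alive connection is still open, so one entry is tracked.
    mu.Lock()
    fmt.Println("tracked connections:", len(conns))
    mu.Unlock()
}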
60 changes: 60 additions & 0 deletions tests/server/cluster_test.go
@@ -8,14 +8,20 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/andydunstall/piko/client"
"github.com/andydunstall/piko/pikotest/cluster"
"github.com/andydunstall/piko/pikotest/cluster/config"
"github.com/andydunstall/piko/pikotest/cluster/proxy"
"github.com/andydunstall/piko/pikotest/workload/upstreams"
uconf "github.com/andydunstall/piko/pikotest/workload/upstreams/config"
"github.com/andydunstall/piko/pkg/log"
)

// Tests proxying traffic across multiple Piko server nodes.
@@ -149,3 +155,57 @@ func TestCluster_Proxy(t *testing.T) {
        wg.Wait()
    })
}

// Tests upstream connection rebalancing across the cluster.
func TestCluster_Rebalancing(t *testing.T) {
    manager := cluster.NewManager()
    defer manager.Close()
    manager.Update(&config.Config{
        Nodes: 3,
    })
    loadBalancer := proxy.NewLoadBalancer(manager)
    defer loadBalancer.Close()
    conf := uconf.Default()

    logger, _ := log.NewLogger("error", conf.Log.Subsystems)

    // Create 1000 upstream connections spread across the cluster.
    for i := 0; i < 1000; i++ {
        upstream, _ := upstreams.NewTCPUpstream("my-endpoint"+strconv.Itoa(i), conf, logger)
        defer upstream.Close()
    }

    getConnectionsCount := func(nodeIndex int) (string, int) {
        state := manager.Nodes()[nodeIndex].ClusterState()
        _, local := state.TotalAndLocalUpstreams()
        id := state.LocalID()
        return id, local
    }

    // Give the cluster a moment to register the upstream connections.
    time.Sleep(2 * time.Second)

    initConns := make(map[string]int)
    for i := 0; i < 3; i++ {
        id, conns := getConnectionsCount(i)
        initConns[id] = conns
    }

    // Add two new nodes.
    manager.Update(&config.Config{
        Nodes: 5,
    })

    // Wait long enough for rebalancing to shed connections to the new nodes.
    time.Sleep(15 * time.Second)

    for i := 0; i < 5; i++ {
        id, conns := getConnectionsCount(i)
        oldConns, ok := initConns[id]
        if !ok {
            // A new node should have picked up some connections.
            assert.Greater(t, conns, 0)
        } else {
            // An existing node should have shed some of its connections.
            assert.Less(t, conns, oldConns)
        }
    }
}
