diff --git a/internal/balancer/balancer.go b/internal/balancer/balancer.go index b6c4f03f2..45249366b 100644 --- a/internal/balancer/balancer.go +++ b/internal/balancer/balancer.go @@ -42,7 +42,7 @@ type Balancer struct { discoveryRepeater repeater.Repeater localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) - connectionsState atomic.Pointer[connectionsState] + connectionsState atomic.Pointer[state] mu xsync.RWMutex onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info) @@ -152,7 +152,7 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi } info := balancerConfig.Info{SelfLocation: localDC} - state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback) + state := newState(connections, b.config.Filter, info, b.config.AllowFallback) endpointsInfo := make([]endpoint.Info, len(newest)) for i, e := range newest { @@ -319,7 +319,7 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc return nil } -func (b *Balancer) connections() *connectionsState { +func (b *Balancer) connections() *state { return b.connectionsState.Load() } @@ -351,7 +351,7 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) { } }() - c, failedCount = state.GetConnection(ctx) + c, failedCount = state.Next(ctx) if c == nil { return nil, xerrors.WithStackTrace( fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount), diff --git a/internal/balancer/connections_state.go b/internal/balancer/connections_state.go deleted file mode 100644 index fecbc6db1..000000000 --- a/internal/balancer/connections_state.go +++ /dev/null @@ -1,179 +0,0 @@ -package balancer - -import ( - "context" - - balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xrand" -) - -type connectionsState struct { - connByNodeID map[uint32]conn.Conn - - prefer []conn.Conn - fallback []conn.Conn - all []conn.Conn - - rand xrand.Rand -} - -func newConnectionsState( - conns []conn.Conn, - filter balancerConfig.Filter, - info balancerConfig.Info, - allowFallback bool, -) *connectionsState { - res := &connectionsState{ - connByNodeID: connsToNodeIDMap(conns), - rand: xrand.New(xrand.WithLock()), - } - - res.prefer, res.fallback = sortPreferConnections(conns, filter, info, allowFallback) - if allowFallback { - res.all = conns - } else { - res.all = res.prefer - } - - return res -} - -func (s *connectionsState) PreferredCount() int { - return len(s.prefer) -} - -func (s *connectionsState) All() (all []endpoint.Endpoint) { - if s == nil { - return nil - } - - all = make([]endpoint.Endpoint, len(s.all)) - for i, c := range s.all { - all[i] = c.Endpoint() - } - - return all -} - -func (s *connectionsState) GetConnection(ctx context.Context) (_ conn.Conn, failedCount int) { - if err := ctx.Err(); err != nil { - return nil, 0 - } - - if c := s.preferConnection(ctx); c != nil { - return c, 0 - } - - try := func(conns []conn.Conn) conn.Conn { - c, tryFailed := s.selectRandomConnection(conns, false) - failedCount += tryFailed - - return c - } - - if c := try(s.prefer); c != nil { - return c, failedCount - } - - if c := try(s.fallback); c != nil { - return c, failedCount - } - - c, _ := s.selectRandomConnection(s.all, true) - - return c, failedCount -} - -func (s *connectionsState) preferConnection(ctx context.Context) conn.Conn { - if nodeID, hasPreferEndpoint := endpoint.ContextNodeID(ctx); hasPreferEndpoint { - c := s.connByNodeID[nodeID] - if c != nil && isOkConnection(c, true) { - return c - } - } - - return nil -} - -func (s *connectionsState) selectRandomConnection(conns []conn.Conn, allowBanned bool) (c conn.Conn, failedConns int) { - connCount := len(conns) - if connCount == 0 { - // return for empty list need for prevent panic in fast path - return nil, 0 - } - - // fast path - if c := conns[s.rand.Int(connCount)]; isOkConnection(c, allowBanned) { - return c, 0 - } - - // shuffled indexes slices need for guarantee about every connection will check - indexes := make([]int, connCount) - for index := range indexes { - indexes[index] = index - } - s.rand.Shuffle(connCount, func(i, j int) { - indexes[i], indexes[j] = indexes[j], indexes[i] - }) - - for _, index := range indexes { - c := conns[index] - if isOkConnection(c, allowBanned) { - return c, 0 - } - failedConns++ - } - - return nil, failedConns -} - -func connsToNodeIDMap(conns []conn.Conn) (nodes map[uint32]conn.Conn) { - if len(conns) == 0 { - return nil - } - nodes = make(map[uint32]conn.Conn, len(conns)) - for _, c := range conns { - nodes[c.Endpoint().NodeID()] = c - } - - return nodes -} - -func sortPreferConnections( - conns []conn.Conn, - filter balancerConfig.Filter, - info balancerConfig.Info, - allowFallback bool, -) (prefer, fallback []conn.Conn) { - if filter == nil { - return conns, nil - } - - prefer = make([]conn.Conn, 0, len(conns)) - if allowFallback { - fallback = make([]conn.Conn, 0, len(conns)) - } - - for _, c := range conns { - if filter.Allow(info, c.Endpoint()) { - prefer = append(prefer, c) - } else if allowFallback { - fallback = append(fallback, c) - } - } - - return prefer, fallback -} - -func isOkConnection(c conn.Conn, bannedIsOk bool) bool { - switch c.GetState() { - case conn.Online, conn.Created, conn.Offline: - return true - case conn.Banned: - return bannedIsOk - default: - return false - } -} diff --git a/internal/balancer/connections_state_test.go b/internal/balancer/connections_state_test.go deleted file mode 100644 index c8648ee2a..000000000 --- a/internal/balancer/connections_state_test.go +++ /dev/null @@ -1,464 +0,0 @@ -package balancer - -import ( - "context" - "strings" - "testing" - - "github.com/stretchr/testify/require" - - balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/mock" -) - -func TestConnsToNodeIDMap(t *testing.T) { - table := []struct { - name string - source []conn.Conn - res map[uint32]conn.Conn - }{ - { - name: "Empty", - source: nil, - res: nil, - }, - { - name: "Zero", - source: []conn.Conn{ - &mock.Conn{NodeIDField: 0}, - }, - res: map[uint32]conn.Conn{ - 0: &mock.Conn{NodeIDField: 0}, - }, - }, - { - name: "NonZero", - source: []conn.Conn{ - &mock.Conn{NodeIDField: 1}, - &mock.Conn{NodeIDField: 10}, - }, - res: map[uint32]conn.Conn{ - 1: &mock.Conn{NodeIDField: 1}, - 10: &mock.Conn{NodeIDField: 10}, - }, - }, - { - name: "Combined", - source: []conn.Conn{ - &mock.Conn{NodeIDField: 1}, - &mock.Conn{NodeIDField: 0}, - &mock.Conn{NodeIDField: 10}, - }, - res: map[uint32]conn.Conn{ - 0: &mock.Conn{NodeIDField: 0}, - 1: &mock.Conn{NodeIDField: 1}, - 10: &mock.Conn{NodeIDField: 10}, - }, - }, - } - - for _, test := range table { - t.Run(test.name, func(t *testing.T) { - require.Equal(t, test.res, connsToNodeIDMap(test.source)) - }) - } -} - -type filterFunc func(info balancerConfig.Info, e endpoint.Info) bool - -func (f filterFunc) Allow(info balancerConfig.Info, e endpoint.Info) bool { - return f(info, e) -} - -func (f filterFunc) String() string { - return "Custom" -} - -func TestSortPreferConnections(t *testing.T) { - table := []struct { - name string - source []conn.Conn - allowFallback bool - filter balancerConfig.Filter - prefer []conn.Conn - fallback []conn.Conn - }{ - { - name: "Empty", - source: nil, - allowFallback: false, - filter: nil, - prefer: nil, - fallback: nil, - }, - { - name: "NilFilter", - source: []conn.Conn{ - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "2"}, - }, - allowFallback: false, - filter: nil, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "2"}, - }, - fallback: nil, - }, - { - name: "FilterNoFallback", - source: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "f1"}, - &mock.Conn{AddrField: "t2"}, - &mock.Conn{AddrField: "f2"}, - }, - allowFallback: false, - filter: filterFunc(func(_ balancerConfig.Info, e endpoint.Info) bool { - return strings.HasPrefix(e.Address(), "t") - }), - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "t2"}, - }, - fallback: nil, - }, - { - name: "FilterWithFallback", - source: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "f1"}, - &mock.Conn{AddrField: "t2"}, - &mock.Conn{AddrField: "f2"}, - }, - allowFallback: true, - filter: filterFunc(func(_ balancerConfig.Info, e endpoint.Info) bool { - return strings.HasPrefix(e.Address(), "t") - }), - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "t2"}, - }, - fallback: []conn.Conn{ - &mock.Conn{AddrField: "f1"}, - &mock.Conn{AddrField: "f2"}, - }, - }, - } - - for _, test := range table { - t.Run(test.name, func(t *testing.T) { - prefer, fallback := sortPreferConnections(test.source, test.filter, balancerConfig.Info{}, test.allowFallback) - require.Equal(t, test.prefer, prefer) - require.Equal(t, test.fallback, fallback) - }) - } -} - -func TestSelectRandomConnection(t *testing.T) { - s := newConnectionsState(nil, nil, balancerConfig.Info{}, false) - - t.Run("Empty", func(t *testing.T) { - c, failedCount := s.selectRandomConnection(nil, false) - require.Nil(t, c) - require.Equal(t, 0, failedCount) - }) - - t.Run("One", func(t *testing.T) { - for _, goodState := range []conn.State{conn.Online, conn.Offline, conn.Created} { - c, failedCount := s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: goodState}}, false) - require.Equal(t, &mock.Conn{AddrField: "asd", State: goodState}, c) - require.Equal(t, 0, failedCount) - } - }) - t.Run("OneBanned", func(t *testing.T) { - c, failedCount := s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: conn.Banned}}, false) - require.Nil(t, c) - require.Equal(t, 1, failedCount) - - c, failedCount = s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: conn.Banned}}, true) - require.Equal(t, &mock.Conn{AddrField: "asd", State: conn.Banned}, c) - require.Equal(t, 0, failedCount) - }) - t.Run("Two", func(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online}, - } - first := 0 - second := 0 - for i := 0; i < 100; i++ { - c, _ := s.selectRandomConnection(conns, false) - if c.Endpoint().Address() == "1" { - first++ - } else { - second++ - } - } - require.Equal(t, 100, first+second) - require.InDelta(t, 50, first, 21) - require.InDelta(t, 50, second, 21) - }) - t.Run("TwoBanned", func(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Banned}, - &mock.Conn{AddrField: "2", State: conn.Banned}, - } - totalFailed := 0 - for i := 0; i < 100; i++ { - c, failed := s.selectRandomConnection(conns, false) - require.Nil(t, c) - totalFailed += failed - } - require.Equal(t, 200, totalFailed) - }) - t.Run("ThreeWithBanned", func(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online}, - &mock.Conn{AddrField: "3", State: conn.Banned}, - } - first := 0 - second := 0 - failed := 0 - for i := 0; i < 100; i++ { - c, checkFailed := s.selectRandomConnection(conns, false) - failed += checkFailed - switch c.Endpoint().Address() { - case "1": - first++ - case "2": - second++ - default: - t.Errorf(c.Endpoint().Address()) - } - } - require.Equal(t, 100, first+second) - require.InDelta(t, 50, first, 21) - require.InDelta(t, 50, second, 21) - require.Greater(t, 10, failed) - }) -} - -func TestNewState(t *testing.T) { - table := []struct { - name string - state *connectionsState - res *connectionsState - }{ - { - name: "Empty", - state: newConnectionsState(nil, nil, balancerConfig.Info{}, false), - res: &connectionsState{ - connByNodeID: nil, - prefer: nil, - fallback: nil, - all: nil, - }, - }, - { - name: "NoFilter", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", NodeIDField: 1}, - &mock.Conn{AddrField: "2", NodeIDField: 2}, - }, nil, balancerConfig.Info{}, false), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "1", NodeIDField: 1}, - 2: &mock.Conn{AddrField: "2", NodeIDField: 2}, - }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "1", NodeIDField: 1}, - &mock.Conn{AddrField: "2", NodeIDField: 2}, - }, - fallback: nil, - all: []conn.Conn{ - &mock.Conn{AddrField: "1", NodeIDField: 1}, - &mock.Conn{AddrField: "2", NodeIDField: 2}, - }, - }, - }, - { - name: "FilterDenyFallback", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { - return info.SelfLocation == e.Location() - }), balancerConfig.Info{SelfLocation: "t"}, false), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - 4: &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - }, - fallback: nil, - all: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - }, - }, - }, - { - name: "FilterAllowFallback", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { - return info.SelfLocation == e.Location() - }), balancerConfig.Info{SelfLocation: "t"}, true), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - 4: &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - }, - fallback: []conn.Conn{ - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - all: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - }, - }, - { - name: "WithNodeID", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { - return info.SelfLocation == e.Location() - }), balancerConfig.Info{SelfLocation: "t"}, true), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - 4: &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - }, - fallback: []conn.Conn{ - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - all: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - }, - }, - } - - for _, test := range table { - t.Run(test.name, func(t *testing.T) { - require.NotNil(t, test.state.rand) - test.state.rand = nil - require.Equal(t, test.res, test.state) - }) - } -} - -func TestConnection(t *testing.T) { - t.Run("Empty", func(t *testing.T) { - s := newConnectionsState(nil, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(context.Background()) - require.Nil(t, c) - require.Equal(t, 0, failed) - }) - t.Run("AllGood", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online}, - }, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(context.Background()) - require.NotNil(t, c) - require.Equal(t, 0, failed) - }) - t.Run("WithBanned", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Banned}, - }, nil, balancerConfig.Info{}, false) - c, _ := s.GetConnection(context.Background()) - require.Equal(t, &mock.Conn{AddrField: "1", State: conn.Online}, c) - }) - t.Run("AllBanned", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", State: conn.Banned, LocationField: "t"}, - &mock.Conn{AddrField: "f2", State: conn.Banned, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { - return e.Location() == info.SelfLocation - }), balancerConfig.Info{}, true) - preferred := 0 - fallback := 0 - for i := 0; i < 100; i++ { - c, failed := s.GetConnection(context.Background()) - require.NotNil(t, c) - require.Equal(t, 2, failed) - if c.Endpoint().Address() == "t1" { - preferred++ - } else { - fallback++ - } - } - require.Equal(t, 100, preferred+fallback) - require.InDelta(t, 50, preferred, 21) - require.InDelta(t, 50, fallback, 21) - }) - t.Run("PreferBannedWithFallback", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", State: conn.Banned, LocationField: "t"}, - &mock.Conn{AddrField: "f2", State: conn.Online, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { - return e.Location() == info.SelfLocation - }), balancerConfig.Info{SelfLocation: "t"}, true) - c, failed := s.GetConnection(context.Background()) - require.Equal(t, &mock.Conn{AddrField: "f2", State: conn.Online, LocationField: "f"}, c) - require.Equal(t, 1, failed) - }) - t.Run("PreferNodeID", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, - &mock.Conn{AddrField: "2", State: conn.Online, NodeIDField: 2}, - }, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(endpoint.WithNodeID(context.Background(), 2)) - require.Equal(t, &mock.Conn{AddrField: "2", State: conn.Online, NodeIDField: 2}, c) - require.Equal(t, 0, failed) - }) - t.Run("PreferNodeIDWithBadState", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, - &mock.Conn{AddrField: "2", State: conn.Unknown, NodeIDField: 2}, - }, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(endpoint.WithNodeID(context.Background(), 2)) - require.Equal(t, &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, c) - require.Equal(t, 0, failed) - }) -} diff --git a/internal/balancer/local_dc_test.go b/internal/balancer/local_dc_test.go index 2eab1e9a8..0924c5a18 100644 --- a/internal/balancer/local_dc_test.go +++ b/internal/balancer/local_dc_test.go @@ -151,7 +151,7 @@ func TestLocalDCDiscovery(t *testing.T) { require.NoError(t, err) for i := 0; i < 100; i++ { - conn, _ := r.connections().GetConnection(ctx) + conn, _ := r.connections().Next(ctx) require.Equal(t, "b:234", conn.Endpoint().Address()) require.Equal(t, "b", conn.Endpoint().Location()) } diff --git a/internal/balancer/state.go b/internal/balancer/state.go new file mode 100644 index 000000000..cb0ed5eca --- /dev/null +++ b/internal/balancer/state.go @@ -0,0 +1,141 @@ +package balancer + +import ( + "context" + + balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xrand" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xslices" +) + +type state struct { + index map[uint32]endpoint.Endpoint + + prefer []endpoint.Endpoint + fallback []endpoint.Endpoint + all []endpoint.Endpoint + + rand xrand.Rand +} + +func newState( + endpoints []endpoint.Endpoint, + filter balancerConfig.Filter, + info balancerConfig.Info, + allowFallback bool, +) *state { + res := &state{ + index: xslices.Map(endpoints, func(e endpoint.Endpoint) uint32 { return e.NodeID() }), + rand: xrand.New(xrand.WithLock()), + } + + res.prefer, res.fallback = xslices.Filter(endpoints, func(e endpoint.Endpoint) bool { + return filter.Allow(info, e) + }) + + if allowFallback { + res.all = endpoints + } else { + res.all = res.prefer + res.fallback = nil + } + + return res +} + +func (s *state) PreferredCount() int { + return len(s.prefer) +} + +func (s *state) All() (all []endpoint.Endpoint) { + if s == nil { + return nil + } + + return s.all +} + +func (s *state) Next(ctx context.Context) (_ endpoint.Endpoint, failedCount int) { + if err := ctx.Err(); err != nil { + return nil, 0 + } + + if c := s.preferConnection(ctx); c != nil { + return c, 0 + } + + try := func(endpoints []endpoint.Endpoint) endpoint.Endpoint { + c, tryFailed := s.selectRandomConnection(endpoints, false) + failedCount += tryFailed + + return c + } + + if c := try(s.prefer); c != nil { + return c, failedCount + } + + if c := try(s.fallback); c != nil { + return c, failedCount + } + + c, _ := s.selectRandomConnection(s.all, true) + + return c, failedCount +} + +func (s *state) preferConnection(ctx context.Context) endpoint.Endpoint { + if nodeID, hasPreferEndpoint := endpoint.ContextNodeID(ctx); hasPreferEndpoint { + c := s.index[nodeID] + if c != nil && isOkConnection(c, true) { + return c + } + } + + return nil +} + +func (s *state) selectRandomConnection(endpoints []endpoint.Endpoint, allowBanned bool) (c endpoint.Endpoint, failedConns int) { + connCount := len(endpoints) + if connCount == 0 { + // return for empty list need for prevent panic in fast path + return nil, 0 + } + + // fast path + if c := endpoints[s.rand.Int(connCount)]; isOkConnection(c, allowBanned) { + return c, 0 + } + + // shuffled indexes slices need for guarantee about every connection will check + indexes := make([]int, connCount) + for index := range indexes { + indexes[index] = index + } + s.rand.Shuffle(connCount, func(i, j int) { + indexes[i], indexes[j] = indexes[j], indexes[i] + }) + + for _, index := range indexes { + c := endpoints[index] + if isOkConnection(c, allowBanned) { + return c, 0 + } + failedConns++ + } + + return nil, failedConns +} + +func isOkConnection(c endpoint.Endpoint, bannedIsOk bool) bool { + switch c.GetState() { + case conn.Online, conn.Created, conn.Offline: + return true + case conn.Banned: + return bannedIsOk + default: + return false + } +} diff --git a/internal/balancer/state_test.go b/internal/balancer/state_test.go new file mode 100644 index 000000000..863f43f3e --- /dev/null +++ b/internal/balancer/state_test.go @@ -0,0 +1,416 @@ +package balancer + +import ( + "context" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xslices" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/mock" +) + +type filterFunc func(info balancerConfig.Info, e endpoint.Info) bool + +func (f filterFunc) Allow(info balancerConfig.Info, e endpoint.Info) bool { + return f(info, e) +} + +func (f filterFunc) String() string { + return "Custom" +} + +func TestSortPreferConnections(t *testing.T) { + table := []struct { + name string + source []endpoint.Endpoint + allowFallback bool + filter balancerConfig.Filter + prefer []endpoint.Endpoint + fallback []endpoint.Endpoint + }{ + { + name: "Empty", + source: nil, + allowFallback: false, + filter: nil, + prefer: nil, + fallback: nil, + }, + { + name: "NilFilter", + source: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1"}, + &mock.Endpoint{AddrField: "2"}, + }, + allowFallback: false, + filter: nil, + prefer: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1"}, + &mock.Endpoint{AddrField: "2"}, + }, + fallback: nil, + }, + { + name: "FilterNoFallback", + source: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1"}, + &mock.Endpoint{AddrField: "f1"}, + &mock.Endpoint{AddrField: "t2"}, + &mock.Endpoint{AddrField: "f2"}, + }, + allowFallback: false, + filter: filterFunc(func(_ balancerConfig.Info, e endpoint.Info) bool { + return strings.HasPrefix(e.Address(), "t") + }), + prefer: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1"}, + &mock.Endpoint{AddrField: "t2"}, + }, + fallback: nil, + }, + { + name: "FilterWithFallback", + source: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1"}, + &mock.Endpoint{AddrField: "f1"}, + &mock.Endpoint{AddrField: "t2"}, + &mock.Endpoint{AddrField: "f2"}, + }, + allowFallback: true, + filter: filterFunc(func(_ balancerConfig.Info, e endpoint.Info) bool { + return strings.HasPrefix(e.Address(), "t") + }), + prefer: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1"}, + &mock.Endpoint{AddrField: "t2"}, + }, + fallback: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "f1"}, + &mock.Endpoint{AddrField: "f2"}, + }, + }, + } + + for _, test := range table { + t.Run(test.name, func(t *testing.T) { + prefer, fallback := xslices.Filter(test.source, func(e endpoint.Endpoint) bool { + return test.filter.Allow(balancerConfig.Info{}, e) + }) + require.Equal(t, test.prefer, prefer) + if test.allowFallback { + require.Equal(t, test.fallback, fallback) + } + }) + } +} + +func TestSelectRandomConnection(t *testing.T) { + s := newState(nil, nil, balancerConfig.Info{}, false) + + t.Run("Empty", func(t *testing.T) { + c, failedCount := s.selectRandomConnection(nil, false) + require.Nil(t, c) + require.Equal(t, 0, failedCount) + }) + + t.Run("One", func(t *testing.T) { + for _, goodState := range []conn.State{conn.Online, conn.Offline, conn.Created} { + c, failedCount := s.selectRandomConnection([]endpoint.Endpoint{&mock.Endpoint{AddrField: "asd", State: goodState}}, false) + require.Equal(t, &mock.Endpoint{AddrField: "asd", State: goodState}, c) + require.Equal(t, 0, failedCount) + } + }) + t.Run("OneBanned", func(t *testing.T) { + c, failedCount := s.selectRandomConnection([]endpoint.Endpoint{&mock.Endpoint{AddrField: "asd", State: conn.Banned}}, false) + require.Nil(t, c) + require.Equal(t, 1, failedCount) + + c, failedCount = s.selectRandomConnection([]endpoint.Endpoint{&mock.Endpoint{AddrField: "asd", State: conn.Banned}}, true) + require.Equal(t, &mock.Endpoint{AddrField: "asd", State: conn.Banned}, c) + require.Equal(t, 0, failedCount) + }) + t.Run("Two", func(t *testing.T) { + conns := []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1", State: conn.Online}, + &mock.Endpoint{AddrField: "2", State: conn.Online}, + } + first := 0 + second := 0 + for i := 0; i < 100; i++ { + c, _ := s.selectRandomConnection(conns, false) + if e.Address() == "1" { + first++ + } else { + second++ + } + } + require.Equal(t, 100, first+second) + require.InDelta(t, 50, first, 21) + require.InDelta(t, 50, second, 21) + }) + t.Run("TwoBanned", func(t *testing.T) { + conns := []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1", State: conn.Banned}, + &mock.Endpoint{AddrField: "2", State: conn.Banned}, + } + totalFailed := 0 + for i := 0; i < 100; i++ { + c, failed := s.selectRandomConnection(conns, false) + require.Nil(t, c) + totalFailed += failed + } + require.Equal(t, 200, totalFailed) + }) + t.Run("ThreeWithBanned", func(t *testing.T) { + conns := []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1", State: conn.Online}, + &mock.Endpoint{AddrField: "2", State: conn.Online}, + &mock.Endpoint{AddrField: "3", State: conn.Banned}, + } + first := 0 + second := 0 + failed := 0 + for i := 0; i < 100; i++ { + c, checkFailed := s.selectRandomConnection(conns, false) + failed += checkFailed + switch e.Address() { + case "1": + first++ + case "2": + second++ + default: + t.Errorf(e.Address()) + } + } + require.Equal(t, 100, first+second) + require.InDelta(t, 50, first, 21) + require.InDelta(t, 50, second, 21) + require.Greater(t, 10, failed) + }) +} + +func TestNewState(t *testing.T) { + table := []struct { + name string + state *state + res *state + }{ + { + name: "Empty", + state: newState(nil, nil, balancerConfig.Info{}, false), + res: &state{ + index: nil, + prefer: nil, + fallback: nil, + all: nil, + }, + }, + { + name: "NoFilter", + state: newState([]endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1", NodeIDField: 1}, + &mock.Endpoint{AddrField: "2", NodeIDField: 2}, + }, nil, balancerConfig.Info{}, false), + res: &state{ + index: map[uint32]endpoint.Endpoint{ + 1: &mock.Endpoint{AddrField: "1", NodeIDField: 1}, + 2: &mock.Endpoint{AddrField: "2", NodeIDField: 2}, + }, + prefer: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1", NodeIDField: 1}, + &mock.Endpoint{AddrField: "2", NodeIDField: 2}, + }, + fallback: nil, + all: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1", NodeIDField: 1}, + &mock.Endpoint{AddrField: "2", NodeIDField: 2}, + }, + }, + }, + { + name: "FilterDenyFallback", + state: newState([]endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + &mock.Endpoint{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, + &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + &mock.Endpoint{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { + return info.SelfLocation == e.Location() + }), balancerConfig.Info{SelfLocation: "t"}, false), + res: &state{ + index: map[uint32]endpoint.Endpoint{ + 1: &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + 2: &mock.Endpoint{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, + 3: &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + 4: &mock.Endpoint{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + }, + prefer: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + }, + fallback: nil, + all: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + }, + }, + }, + { + name: "FilterAllowFallback", + state: newState([]endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + &mock.Endpoint{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, + &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + &mock.Endpoint{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { + return info.SelfLocation == e.Location() + }), balancerConfig.Info{SelfLocation: "t"}, true), + res: &state{ + index: map[uint32]endpoint.Endpoint{ + 1: &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + 2: &mock.Endpoint{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, + 3: &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + 4: &mock.Endpoint{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + }, + prefer: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + }, + fallback: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, + &mock.Endpoint{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + }, + all: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + &mock.Endpoint{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, + &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + &mock.Endpoint{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + }, + }, + }, + { + name: "WithNodeID", + state: newState([]endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + &mock.Endpoint{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, + &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + &mock.Endpoint{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { + return info.SelfLocation == e.Location() + }), balancerConfig.Info{SelfLocation: "t"}, true), + res: &state{ + index: map[uint32]endpoint.Endpoint{ + 1: &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + 2: &mock.Endpoint{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, + 3: &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + 4: &mock.Endpoint{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + }, + prefer: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + }, + fallback: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, + &mock.Endpoint{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + }, + all: []endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, + &mock.Endpoint{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, + &mock.Endpoint{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + &mock.Endpoint{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + }, + }, + }, + } + + for _, test := range table { + t.Run(test.name, func(t *testing.T) { + require.NotNil(t, test.state.rand) + test.state.rand = nil + require.Equal(t, test.res, test.state) + }) + } +} + +func TestConnection(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + s := newState(nil, nil, balancerConfig.Info{}, false) + c, failed := s.Next(context.Background()) + require.Nil(t, c) + require.Equal(t, 0, failed) + }) + t.Run("AllGood", func(t *testing.T) { + s := newState([]endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1", State: conn.Online}, + &mock.Endpoint{AddrField: "2", State: conn.Online}, + }, nil, balancerConfig.Info{}, false) + c, failed := s.Next(context.Background()) + require.NotNil(t, c) + require.Equal(t, 0, failed) + }) + t.Run("WithBanned", func(t *testing.T) { + s := newState([]endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1", State: conn.Online}, + &mock.Endpoint{AddrField: "2", State: conn.Banned}, + }, nil, balancerConfig.Info{}, false) + c, _ := s.Next(context.Background()) + require.Equal(t, &mock.Endpoint{AddrField: "1", State: conn.Online}, c) + }) + t.Run("AllBanned", func(t *testing.T) { + s := newState([]endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", State: conn.Banned, LocationField: "t"}, + &mock.Endpoint{AddrField: "f2", State: conn.Banned, LocationField: "f"}, + }, filterFunc(func(info balancerConfig.Info, c endpoint.Endpoint) bool { + return e.Location() == info.SelfLocation + }), balancerConfig.Info{}, true) + preferred := 0 + fallback := 0 + for i := 0; i < 100; i++ { + c, failed := s.Next(context.Background()) + require.NotNil(t, c) + require.Equal(t, 2, failed) + if e.Address() == "t1" { + preferred++ + } else { + fallback++ + } + } + require.Equal(t, 100, preferred+fallback) + require.InDelta(t, 50, preferred, 21) + require.InDelta(t, 50, fallback, 21) + }) + t.Run("PreferBannedWithFallback", func(t *testing.T) { + s := newState([]endpoint.Endpoint{ + &mock.Endpoint{AddrField: "t1", State: conn.Banned, LocationField: "t"}, + &mock.Endpoint{AddrField: "f2", State: conn.Online, LocationField: "f"}, + }, filterFunc(func(info balancerConfig.Info, c endpoint.Endpoint) bool { + return e.Location() == info.SelfLocation + }), balancerConfig.Info{SelfLocation: "t"}, true) + c, failed := s.Next(context.Background()) + require.Equal(t, &mock.Endpoint{AddrField: "f2", State: conn.Online, LocationField: "f"}, c) + require.Equal(t, 1, failed) + }) + t.Run("PreferNodeID", func(t *testing.T) { + s := newState([]endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1", State: conn.Online, NodeIDField: 1}, + &mock.Endpoint{AddrField: "2", State: conn.Online, NodeIDField: 2}, + }, nil, balancerConfig.Info{}, false) + c, failed := s.Next(endpoint.WithNodeID(context.Background(), 2)) + require.Equal(t, &mock.Endpoint{AddrField: "2", State: conn.Online, NodeIDField: 2}, c) + require.Equal(t, 0, failed) + }) + t.Run("PreferNodeIDWithBadState", func(t *testing.T) { + s := newState([]endpoint.Endpoint{ + &mock.Endpoint{AddrField: "1", State: conn.Online, NodeIDField: 1}, + &mock.Endpoint{AddrField: "2", State: conn.Unknown, NodeIDField: 2}, + }, nil, balancerConfig.Info{}, false) + c, failed := s.Next(endpoint.WithNodeID(context.Background(), 2)) + require.Equal(t, &mock.Endpoint{AddrField: "1", State: conn.Online, NodeIDField: 1}, c) + require.Equal(t, 0, failed) + }) +}