Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(capabilities/webapi): use a round robin selector in handler #16170

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/capabilities/compute/compute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func setup(t *testing.T, config Config) testHarness {
registry := capabilities.NewRegistry(log)
connector := gcmocks.NewGatewayConnector(t)
idGeneratorFn := func() string { return validRequestUUID }
connector.EXPECT().GatewayIDs().Return([]string{"gateway1"})
connectorHandler, err := webapi.NewOutgoingConnectorHandler(connector, config.ServiceConfig, ghcapabilities.MethodComputeAction, log)
require.NoError(t, err)

Expand Down
36 changes: 16 additions & 20 deletions core/capabilities/webapi/outgoing_connector_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"sort"
"sync"
"time"

Expand All @@ -26,11 +25,12 @@ var _ connector.GatewayConnectorHandler = &OutgoingConnectorHandler{}

type OutgoingConnectorHandler struct {
services.StateMachine
gc connector.GatewayConnector
method string
lggr logger.Logger
rateLimiter *common.RateLimiter
responses *responses
gc connector.GatewayConnector
gatewaySelector *RoundRobinSelector
method string
lggr logger.Logger
rateLimiter *common.RateLimiter
responses *responses
}

func NewOutgoingConnectorHandler(gc connector.GatewayConnector, config ServiceConfig, method string, lgger logger.Logger) (*OutgoingConnectorHandler, error) {
Expand All @@ -44,11 +44,12 @@ func NewOutgoingConnectorHandler(gc connector.GatewayConnector, config ServiceCo
}

return &OutgoingConnectorHandler{
gc: gc,
method: method,
responses: newResponses(),
rateLimiter: rateLimiter,
lggr: lgger,
gc: gc,
gatewaySelector: NewRoundRobinSelector(gc.GatewayIDs()),
cedric-cordenier marked this conversation as resolved.
Show resolved Hide resolved
method: method,
responses: newResponses(),
rateLimiter: rateLimiter,
lggr: lgger,
}, nil
}

Expand Down Expand Up @@ -87,15 +88,10 @@ func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context,
Payload: payload,
}

// simply, send request to first available gateway node from sorted list
// this allows for deterministic selection of gateway node receiver for easier debugging
gatewayIDs := c.gc.GatewayIDs()
if len(gatewayIDs) == 0 {
return nil, errors.New("no gateway nodes available")
selectedGateway, err := c.gatewaySelector.NextGateway()
if err != nil {
return nil, fmt.Errorf("failed to select gateway: %w", err)
}
sort.Strings(gatewayIDs)

selectedGateway := gatewayIDs[0]

l.Infow("selected gateway, awaiting connection", "gatewayID", selectedGateway)

Expand All @@ -109,7 +105,7 @@ func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context,

select {
case resp := <-ch:
l.Debugw("received response from gateway")
l.Debugw("received response from gateway", "gatewayID", selectedGateway)
return resp, nil
case <-ctx.Done():
return nil, ctx.Err()
Expand Down
135 changes: 90 additions & 45 deletions core/capabilities/webapi/outgoing_connector_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,16 @@ import (
func TestHandleSingleNodeRequest(t *testing.T) {
t.Run("uses default timeout if no timeout is provided", func(t *testing.T) {
ctx := tests.Context(t)
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
var defaultConfig = ServiceConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
GlobalBurst: 100,
PerSenderRPS: 100.0,
PerSenderBurst: 100,
},
}
connectorHandler, err := NewOutgoingConnectorHandler(connector, defaultConfig, ghcapabilities.MethodComputeAction, log)
require.NoError(t, err)

msgID := "msgID"
testURL := "http://localhost:8080"
connector.EXPECT().DonID().Return("donID")
connector.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil)
connector.EXPECT().GatewayIDs().Return([]string{"gateway1"})
connector, connectorHandler := newFunction(
t,
func(gc *gcmocks.GatewayConnector) {
gc.EXPECT().DonID().Return("donID")
gc.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil)
gc.EXPECT().GatewayIDs().Return([]string{"gateway1"})
},
)

// build the expected body with the default timeout
req := ghcapabilities.Request{
Expand Down Expand Up @@ -67,26 +59,68 @@ func TestHandleSingleNodeRequest(t *testing.T) {
require.NoError(t, err)
})

t.Run("uses timeout", func(t *testing.T) {
t.Run("subsequent request uses gateway 2", func(t *testing.T) {
ctx := tests.Context(t)
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
var defaultConfig = ServiceConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
GlobalBurst: 100,
PerSenderRPS: 100.0,
PerSenderBurst: 100,
msgID := "msgID"
testURL := "http://localhost:8080"
connector, connectorHandler := newFunction(
t,
func(gc *gcmocks.GatewayConnector) {
gc.EXPECT().DonID().Return("donID")
gc.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil).Once()
gc.EXPECT().AwaitConnection(matches.AnyContext, "gateway2").Return(nil).Once()
gc.EXPECT().GatewayIDs().Return([]string{"gateway1", "gateway2"})
},
)

// build the expected body with the default timeout
req := ghcapabilities.Request{
URL: testURL,
TimeoutMs: defaultFetchTimeoutMs,
}
connectorHandler, err := NewOutgoingConnectorHandler(connector, defaultConfig, ghcapabilities.MethodComputeAction, log)
payload, err := json.Marshal(req)
require.NoError(t, err)

expectedBody := &api.MessageBody{
MessageId: msgID,
DonId: connector.DonID(),
Method: ghcapabilities.MethodComputeAction,
Payload: payload,
}

// expect call to be made to gateway 1
connector.EXPECT().SignAndSendToGateway(mock.Anything, "gateway1", expectedBody).Run(func(ctx context.Context, gatewayID string, msg *api.MessageBody) {
connectorHandler.HandleGatewayMessage(ctx, "gateway1", gatewayResponse(t, msgID))
}).Return(nil).Times(1)

_, err = connectorHandler.HandleSingleNodeRequest(ctx, msgID, ghcapabilities.Request{
URL: testURL,
})
require.NoError(t, err)

// expect call to be made to gateway 2
connector.EXPECT().SignAndSendToGateway(mock.Anything, "gateway2", expectedBody).Run(func(ctx context.Context, gatewayID string, msg *api.MessageBody) {
connectorHandler.HandleGatewayMessage(ctx, "gateway2", gatewayResponse(t, msgID))
}).Return(nil).Times(1)

_, err = connectorHandler.HandleSingleNodeRequest(ctx, msgID, ghcapabilities.Request{
URL: testURL,
})
require.NoError(t, err)
})

t.Run("uses timeout", func(t *testing.T) {
ctx := tests.Context(t)
msgID := "msgID"
testURL := "http://localhost:8080"
connector.EXPECT().DonID().Return("donID")
connector.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil)
connector.EXPECT().GatewayIDs().Return([]string{"gateway1"})
connector, connectorHandler := newFunction(
t,
func(gc *gcmocks.GatewayConnector) {
gc.EXPECT().DonID().Return("donID")
gc.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil)
gc.EXPECT().GatewayIDs().Return([]string{"gateway1"})
},
)

// build the expected body with the defined timeout
req := ghcapabilities.Request{
Expand Down Expand Up @@ -119,24 +153,16 @@ func TestHandleSingleNodeRequest(t *testing.T) {

t.Run("cleans up in event of a timeout", func(t *testing.T) {
ctx := tests.Context(t)
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
var defaultConfig = ServiceConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
GlobalBurst: 100,
PerSenderRPS: 100.0,
PerSenderBurst: 100,
},
}
connectorHandler, err := NewOutgoingConnectorHandler(connector, defaultConfig, ghcapabilities.MethodComputeAction, log)
require.NoError(t, err)

msgID := "msgID"
testURL := "http://localhost:8080"
connector.EXPECT().DonID().Return("donID")
connector.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil)
connector.EXPECT().GatewayIDs().Return([]string{"gateway1"})
connector, connectorHandler := newFunction(
t,
func(gc *gcmocks.GatewayConnector) {
gc.EXPECT().DonID().Return("donID")
gc.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil)
gc.EXPECT().GatewayIDs().Return([]string{"gateway1"})
},
)

// build the expected body with the defined timeout
req := ghcapabilities.Request{
Expand Down Expand Up @@ -168,6 +194,25 @@ func TestHandleSingleNodeRequest(t *testing.T) {
})
}

func newFunction(t *testing.T, mockFn func(*gcmocks.GatewayConnector)) (*gcmocks.GatewayConnector, *OutgoingConnectorHandler) {
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
var defaultConfig = ServiceConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
GlobalBurst: 100,
PerSenderRPS: 100.0,
PerSenderBurst: 100,
},
}

mockFn(connector)

connectorHandler, err := NewOutgoingConnectorHandler(connector, defaultConfig, ghcapabilities.MethodComputeAction, log)
require.NoError(t, err)
return connector, connectorHandler
}

func gatewayResponse(t *testing.T, msgID string) *api.Message {
headers := map[string]string{"Content-Type": "application/json"}
body := []byte("response body")
Expand Down
35 changes: 35 additions & 0 deletions core/capabilities/webapi/round_robin_selector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package webapi

import (
"sync"

"github.com/pkg/errors"
)

var ErrNoGateways = errors.New("no gateways available")

type RoundRobinSelector struct {
items []string
index int
mu sync.Mutex
}

func NewRoundRobinSelector(items []string) *RoundRobinSelector {
return &RoundRobinSelector{
items: items,
index: 0,
}
}

func (r *RoundRobinSelector) NextGateway() (string, error) {
r.mu.Lock()
defer r.mu.Unlock()

if len(r.items) == 0 {
return "", ErrNoGateways
}

item := r.items[r.index]
r.index = (r.index + 1) % len(r.items)
return item, nil
}
64 changes: 64 additions & 0 deletions core/capabilities/webapi/round_robin_selector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package webapi

import (
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRoundRobinSelector(t *testing.T) {
gateways := []string{"gateway1", "gateway2", "gateway3"}
rr := NewRoundRobinSelector(gateways)

expectedOrder := []string{"gateway1", "gateway2", "gateway3", "gateway1", "gateway2", "gateway3"}

for i, expected := range expectedOrder {
got, err := rr.NextGateway()
require.NoError(t, err, "unexpected error on iteration %d", i)
assert.Equal(t, expected, got, "unexpected gateway at iteration %d", i)
}
}

func TestRoundRobinSelector_Empty(t *testing.T) {
rr := NewRoundRobinSelector([]string{})

_, err := rr.NextGateway()
assert.ErrorIs(t, err, ErrNoGateways, "expected ErrNoGateways when no gateways are available")
}

func TestRoundRobinSelector_Concurrency(t *testing.T) {
gateways := []string{"gateway1", "gateway2", "gateway3"}
rr := NewRoundRobinSelector(gateways)

var wg sync.WaitGroup
numRequests := 100
results := make(chan string, numRequests)

for i := 0; i < numRequests; i++ {
wg.Add(1)
go func() {
defer wg.Done()
gw, err := rr.NextGateway()
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
results <- gw
}()
}

wg.Wait()
close(results)

counts := make(map[string]int)
for result := range results {
counts[result]++
}

expectedCount := numRequests / len(gateways)
for _, gateway := range gateways {
assert.InDelta(t, expectedCount, counts[gateway], 1, "unexpected request distribution for %s", gateway)
}
}
1 change: 1 addition & 0 deletions core/capabilities/webapi/target/target_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func setup(t *testing.T, config webapi.ServiceConfig) testHarness {
registry := registrymock.NewCapabilitiesRegistry(t)
connector := gcmocks.NewGatewayConnector(t)
lggr := logger.Test(t)
connector.EXPECT().GatewayIDs().Return([]string{"gateway1"})
connectorHandler, err := webapi.NewOutgoingConnectorHandler(connector, config, ghcapabilities.MethodWebAPITarget, lggr)
require.NoError(t, err)

Expand Down
2 changes: 2 additions & 0 deletions core/services/workflows/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,7 @@ func TestEngine_WithCustomComputeStep(t *testing.T) {
}

connector := gcmocks.NewGatewayConnector(t)
connector.EXPECT().GatewayIDs().Return([]string{"gateway1"})
handler, err := webapi.NewOutgoingConnectorHandler(
connector,
cfg.ServiceConfig,
Expand Down Expand Up @@ -1647,6 +1648,7 @@ func TestEngine_CustomComputePropagatesBreaks(t *testing.T) {
},
}
connector := gcmocks.NewGatewayConnector(t)
connector.EXPECT().GatewayIDs().Return([]string{"gateway1"})
handler, err := webapi.NewOutgoingConnectorHandler(
connector,
cfg.ServiceConfig,
Expand Down
2 changes: 1 addition & 1 deletion core/services/workflows/syncer/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func TestNewFetcherService(t *testing.T) {

t.Run("OK-valid_request", func(t *testing.T) {
connector.EXPECT().AddHandler([]string{capabilities.MethodWorkflowSyncer}, mock.Anything).Return(nil)
connector.EXPECT().GatewayIDs().Return([]string{"gateway1", "gateway2"})

fetcher := NewFetcherService(lggr, wrapper)
require.NoError(t, fetcher.Start(ctx))
Expand All @@ -56,7 +57,6 @@ func TestNewFetcherService(t *testing.T) {
}).Return(nil).Times(1)
connector.EXPECT().DonID().Return(donID)
connector.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil)
connector.EXPECT().GatewayIDs().Return([]string{"gateway1", "gateway2"})

payload, err := fetcher.Fetch(ctx, url, 0)
require.NoError(t, err)
Expand Down
Loading
Loading