Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CAPPL-474] Cleanup channels #15994

Merged
merged 2 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 59 additions & 22 deletions core/capabilities/webapi/outgoing_connector_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ var _ connector.GatewayConnectorHandler = &OutgoingConnectorHandler{}

type OutgoingConnectorHandler struct {
services.StateMachine
gc connector.GatewayConnector
method string
lggr logger.Logger
responseChs map[string]chan *api.Message
responseChsMu sync.Mutex
rateLimiter *common.RateLimiter
gc connector.GatewayConnector
method string
lggr logger.Logger
rateLimiter *common.RateLimiter
responses *responses
}

func NewOutgoingConnectorHandler(gc connector.GatewayConnector, config ServiceConfig, method string, lgger logger.Logger) (*OutgoingConnectorHandler, error) {
Expand All @@ -44,14 +43,12 @@ func NewOutgoingConnectorHandler(gc connector.GatewayConnector, config ServiceCo
return nil, fmt.Errorf("invalid outgoing connector handler method: %s", method)
}

responseChs := make(map[string]chan *api.Message)
return &OutgoingConnectorHandler{
gc: gc,
method: method,
responseChs: responseChs,
responseChsMu: sync.Mutex{},
rateLimiter: rateLimiter,
lggr: lgger,
gc: gc,
method: method,
responses: newResponses(),
rateLimiter: rateLimiter,
lggr: lgger,
}, nil
}

Expand All @@ -74,10 +71,12 @@ func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context,
return nil, fmt.Errorf("failed to marshal fetch request: %w", err)
}

ch := make(chan *api.Message, 1)
c.responseChsMu.Lock()
c.responseChs[messageID] = ch
c.responseChsMu.Unlock()
ch, err := c.responses.new(messageID)
if err != nil {
return nil, fmt.Errorf("duplicate message received for ID: %s", messageID)
}
defer c.responses.cleanup(messageID)

l := logger.With(c.lggr, "messageID", messageID)
l.Debugw("sending request to gateway")

Expand Down Expand Up @@ -136,16 +135,14 @@ func (c *OutgoingConnectorHandler) HandleGatewayMessage(ctx context.Context, gat
l.Errorw("failed to unmarshal payload", "err", err)
return
}
c.responseChsMu.Lock()
defer c.responseChsMu.Unlock()
ch, ok := c.responseChs[body.MessageId]
ch, ok := c.responses.get(body.MessageId)
if !ok {
l.Errorw("no response channel found")
l.Warnw("no response channel found; this may indicate that the node timed out the request")
return
}
select {
case ch <- msg:
delete(c.responseChs, body.MessageId)
return
case <-ctx.Done():
return
}
Expand Down Expand Up @@ -182,3 +179,43 @@ func validMethod(method string) bool {
return false
}
}

func newResponses() *responses {
return &responses{
chs: map[string]chan *api.Message{},
}
}

type responses struct {
chs map[string]chan *api.Message
mu sync.RWMutex
}

func (r *responses) new(id string) (chan *api.Message, error) {
r.mu.Lock()
defer r.mu.Unlock()

_, ok := r.chs[id]
if ok {
return nil, fmt.Errorf("already have response for id: %s", id)
}

// Buffered so we don't wait if sending
ch := make(chan *api.Message, 1)
r.chs[id] = ch
return ch, nil
}

func (r *responses) cleanup(id string) {
r.mu.Lock()
defer r.mu.Unlock()

delete(r.chs, id)
}

func (r *responses) get(id string) (chan *api.Message, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
ch, ok := r.chs[id]
return ch, ok
}
57 changes: 55 additions & 2 deletions core/capabilities/webapi/outgoing_connector_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

Expand All @@ -19,7 +20,7 @@ import (
)

func TestHandleSingleNodeRequest(t *testing.T) {
t.Run("OK-timeout_is_not_specify_default_timeout_is_expected", func(t *testing.T) {
t.Run("uses default timeout if no timeout is provided", func(t *testing.T) {
ctx := tests.Context(t)
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
Expand Down Expand Up @@ -66,7 +67,7 @@ func TestHandleSingleNodeRequest(t *testing.T) {
require.NoError(t, err)
})

t.Run("OK-timeout_is_specified", func(t *testing.T) {
t.Run("uses timeout", func(t *testing.T) {
ctx := tests.Context(t)
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
Expand Down Expand Up @@ -111,8 +112,60 @@ func TestHandleSingleNodeRequest(t *testing.T) {
URL: testURL,
TimeoutMs: 40000,
})
_, found := connectorHandler.responses.get(msgID)
assert.False(t, found)
require.NoError(t, err)
})

t.Run("cleans up in event of a timeout", func(t *testing.T) {
ctx := tests.Context(t)
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
var defaultConfig = ServiceConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
GlobalBurst: 100,
PerSenderRPS: 100.0,
PerSenderBurst: 100,
},
}
connectorHandler, err := NewOutgoingConnectorHandler(connector, defaultConfig, ghcapabilities.MethodComputeAction, log)
require.NoError(t, err)

msgID := "msgID"
testURL := "http://localhost:8080"
connector.EXPECT().DonID().Return("donID")
connector.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil)
connector.EXPECT().GatewayIDs().Return([]string{"gateway1"})

// build the expected body with the defined timeout
req := ghcapabilities.Request{
URL: testURL,
TimeoutMs: 10,
}
payload, err := json.Marshal(req)
require.NoError(t, err)

expectedBody := &api.MessageBody{
MessageId: msgID,
DonId: connector.DonID(),
Method: ghcapabilities.MethodComputeAction,
Payload: payload,
}

// expect the request body to contain the defined timeout
connector.EXPECT().SignAndSendToGateway(mock.Anything, "gateway1", expectedBody).Run(func(ctx context.Context, gatewayID string, msg *api.MessageBody) {
// don't call HandleGatewayMessage here; i.e. simulate a failure to receive a response
}).Return(nil).Times(1)

_, err = connectorHandler.HandleSingleNodeRequest(ctx, msgID, ghcapabilities.Request{
URL: testURL,
TimeoutMs: 10,
})
_, found := connectorHandler.responses.get(msgID)
assert.False(t, found)
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
}

func gatewayResponse(t *testing.T, msgID string) *api.Message {
Expand Down
10 changes: 8 additions & 2 deletions core/services/gateway/handlers/capabilities/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,14 @@ func (h *handler) handleWebAPIOutgoingMessage(ctx context.Context, msg *api.Mess
},
}
}
// this signature is not verified by the node because
// WS connection between gateway and node are already verified

// Work around the fact that the connection manager expects all messages
// to have a valid signature by reusing the signature that came with the message.
// This is OK to do because:
// - our trust model for Gateways assumes that we can trust the Gateway node. This is a central assumption since
// the Gateway node has access to plaintext secrets sent by DON nodes.
// - the connection between the Gateway and DON Node is already authorized via a DON-side and Gateway-side
// allowlist, and secured via TLS.
respMsg.Signature = msg.Signature

err = h.don.SendToNode(newCtx, nodeAddr, respMsg)
Expand Down
Loading