Skip to content

Commit

Permalink
[CAPPL-474] Cleanup channels
Browse files Browse the repository at this point in the history
  • Loading branch information
cedric-cordenier committed Jan 21, 2025
1 parent 74f11f9 commit f155ebe
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 26 deletions.
81 changes: 59 additions & 22 deletions core/capabilities/webapi/outgoing_connector_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ var _ connector.GatewayConnectorHandler = &OutgoingConnectorHandler{}

type OutgoingConnectorHandler struct {
services.StateMachine
gc connector.GatewayConnector
method string
lggr logger.Logger
responseChs map[string]chan *api.Message
responseChsMu sync.Mutex
rateLimiter *common.RateLimiter
gc connector.GatewayConnector
method string
lggr logger.Logger
rateLimiter *common.RateLimiter
responses *responses
}

func NewOutgoingConnectorHandler(gc connector.GatewayConnector, config ServiceConfig, method string, lgger logger.Logger) (*OutgoingConnectorHandler, error) {
Expand All @@ -44,14 +43,12 @@ func NewOutgoingConnectorHandler(gc connector.GatewayConnector, config ServiceCo
return nil, fmt.Errorf("invalid outgoing connector handler method: %s", method)
}

responseChs := make(map[string]chan *api.Message)
return &OutgoingConnectorHandler{
gc: gc,
method: method,
responseChs: responseChs,
responseChsMu: sync.Mutex{},
rateLimiter: rateLimiter,
lggr: lgger,
gc: gc,
method: method,
responses: newResponses(),
rateLimiter: rateLimiter,
lggr: lgger,
}, nil
}

Expand All @@ -74,10 +71,12 @@ func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context,
return nil, fmt.Errorf("failed to marshal fetch request: %w", err)
}

ch := make(chan *api.Message, 1)
c.responseChsMu.Lock()
c.responseChs[messageID] = ch
c.responseChsMu.Unlock()
ch, err := c.responses.new(messageID)
if err != nil {
return nil, fmt.Errorf("duplicate message received for ID: %s", messageID)
}
defer c.responses.cleanup(messageID)

l := logger.With(c.lggr, "messageID", messageID)
l.Debugw("sending request to gateway")

Expand Down Expand Up @@ -136,16 +135,14 @@ func (c *OutgoingConnectorHandler) HandleGatewayMessage(ctx context.Context, gat
l.Errorw("failed to unmarshal payload", "err", err)
return
}
c.responseChsMu.Lock()
defer c.responseChsMu.Unlock()
ch, ok := c.responseChs[body.MessageId]
ch, ok := c.responses.get(body.MessageId)
if !ok {
l.Errorw("no response channel found")
l.Warnw("no response channel found; this may indicate that the node timed out the request")
return
}
select {
case ch <- msg:
delete(c.responseChs, body.MessageId)
return
case <-ctx.Done():
return
}
Expand Down Expand Up @@ -182,3 +179,43 @@ func validMethod(method string) bool {
return false
}
}

// newResponses builds an empty responses registry whose channel map is
// already allocated, so entries can be added without a nil-map check.
func newResponses() *responses {
	pending := make(map[string]chan *api.Message)
	return &responses{chs: pending}
}

// responses is a registry of response channels keyed by an id string.
// Entries are added by new, looked up by get and removed by cleanup; each of
// those methods acquires mu before touching chs.
type responses struct {
// chs holds one response channel per registered id.
chs map[string]chan *api.Message
// mu guards chs: the methods that write take the write lock, get takes the read lock.
mu sync.RWMutex
}

// new registers a fresh response channel under id and returns it.
//
// It returns an error, and registers nothing, when a channel is already
// stored for id. The whole check-and-insert runs under the write lock.
func (r *responses) new(id string) (chan *api.Message, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if _, exists := r.chs[id]; exists {
		return nil, fmt.Errorf("already have response for id: %s", id)
	}

	// A capacity of one lets a single send complete even when no receiver
	// is ready yet.
	respCh := make(chan *api.Message, 1)
	r.chs[id] = respCh
	return respCh, nil
}

// cleanup removes the channel stored under id, if there is one. Calling it
// for an id that is not registered is a no-op.
func (r *responses) cleanup(id string) {
	r.mu.Lock()
	delete(r.chs, id)
	r.mu.Unlock()
}

// get looks up the channel stored under id while holding the read lock.
// The boolean result reports whether an entry was found.
func (r *responses) get(id string) (chan *api.Message, bool) {
	r.mu.RLock()
	respCh, found := r.chs[id]
	r.mu.RUnlock()
	return respCh, found
}
57 changes: 55 additions & 2 deletions core/capabilities/webapi/outgoing_connector_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

Expand All @@ -19,7 +20,7 @@ import (
)

func TestHandleSingleNodeRequest(t *testing.T) {
t.Run("OK-timeout_is_not_specify_default_timeout_is_expected", func(t *testing.T) {
t.Run("uses default timeout if no timeout is provided", func(t *testing.T) {
ctx := tests.Context(t)
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
Expand Down Expand Up @@ -66,7 +67,7 @@ func TestHandleSingleNodeRequest(t *testing.T) {
require.NoError(t, err)
})

t.Run("OK-timeout_is_specified", func(t *testing.T) {
t.Run("uses timeout", func(t *testing.T) {
ctx := tests.Context(t)
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
Expand Down Expand Up @@ -111,8 +112,60 @@ func TestHandleSingleNodeRequest(t *testing.T) {
URL: testURL,
TimeoutMs: 40000,
})
_, found := connectorHandler.responses.get(msgID)
assert.False(t, found)
require.NoError(t, err)
})

t.Run("cleans up in event of a timeout", func(t *testing.T) {
ctx := tests.Context(t)
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
var defaultConfig = ServiceConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
GlobalBurst: 100,
PerSenderRPS: 100.0,
PerSenderBurst: 100,
},
}
connectorHandler, err := NewOutgoingConnectorHandler(connector, defaultConfig, ghcapabilities.MethodComputeAction, log)
require.NoError(t, err)

msgID := "msgID"
testURL := "http://localhost:8080"
connector.EXPECT().DonID().Return("donID")
connector.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil)
connector.EXPECT().GatewayIDs().Return([]string{"gateway1"})

// build the expected body with the defined timeout
req := ghcapabilities.Request{
URL: testURL,
TimeoutMs: 10,
}
payload, err := json.Marshal(req)
require.NoError(t, err)

expectedBody := &api.MessageBody{
MessageId: msgID,
DonId: connector.DonID(),
Method: ghcapabilities.MethodComputeAction,
Payload: payload,
}

// expect the request body to contain the defined timeout
connector.EXPECT().SignAndSendToGateway(mock.Anything, "gateway1", expectedBody).Run(func(ctx context.Context, gatewayID string, msg *api.MessageBody) {
// don't call HandleGatewayMessage here; i.e. simulate a failure to receive a response
}).Return(nil).Times(1)

_, err = connectorHandler.HandleSingleNodeRequest(ctx, msgID, ghcapabilities.Request{
URL: testURL,
TimeoutMs: 10,
})
_, found := connectorHandler.responses.get(msgID)
assert.False(t, found)
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
}

func gatewayResponse(t *testing.T, msgID string) *api.Message {
Expand Down
10 changes: 8 additions & 2 deletions core/services/gateway/handlers/capabilities/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,14 @@ func (h *handler) handleWebAPIOutgoingMessage(ctx context.Context, msg *api.Mess
},
}
}
// this signature is not verified by the node because
// WS connection between gateway and node are already verified

// Work around the fact that the connection manager expects all messages
// to have a valid signature by reusing the signature that came with the message.
// This is OK to do because:
// - our trust model for Gateways assumes that we can trust the Gateway node. This is a central assumption since
// the Gateway node has access to plaintext secrets sent by DON nodes.
// - the connection between the Gateway and DON Node is already authorized via a DON-side and Gateway-side
// allowlist, and secured via TLS.
respMsg.Signature = msg.Signature

err = h.don.SendToNode(newCtx, nodeAddr, respMsg)
Expand Down

0 comments on commit f155ebe

Please sign in to comment.