Skip to content

Commit

Permalink
check router is closing in requests (#3157)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephen Buttolph <[email protected]>
  • Loading branch information
ceyonur and StephenButtolph authored Jul 1, 2024
1 parent ffed367 commit ae35eeb
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 0 deletions.
67 changes: 67 additions & 0 deletions snow/networking/router/chain_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
var (
errUnknownChain = errors.New("received message for unknown chain")
errUnallowedNode = errors.New("received message from non-allowed node")
errClosing = errors.New("router is closing")

_ Router = (*ChainRouter)(nil)
_ benchlist.Benchable = (*ChainRouter)(nil)
Expand Down Expand Up @@ -63,6 +64,7 @@ type ChainRouter struct {
clock mockable.Clock
log logging.Logger
lock sync.Mutex
closing bool
chainHandlers map[ids.ID]handler.Handler

// It is only safe to call [RegisterResponse] with the router lock held. Any
Expand Down Expand Up @@ -154,6 +156,18 @@ func (cr *ChainRouter) RegisterRequest(
engineType p2p.EngineType,
) {
cr.lock.Lock()
if cr.closing {
cr.log.Debug("dropping request",
zap.Stringer("nodeID", nodeID),
zap.Stringer("requestingChainID", requestingChainID),
zap.Stringer("respondingChainID", respondingChainID),
zap.Uint32("requestID", requestID),
zap.Stringer("messageOp", op),
zap.Error(errClosing),
)
cr.lock.Unlock()
return
}
// When we receive a response message type (Chits, Put, Accepted, etc.)
// we validate that we actually sent the corresponding request.
// Give this request a unique ID so we can do that validation.
Expand Down Expand Up @@ -244,6 +258,17 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes
cr.lock.Lock()
defer cr.lock.Unlock()

if cr.closing {
cr.log.Debug("dropping message",
zap.Stringer("messageOp", op),
zap.Stringer("nodeID", nodeID),
zap.Stringer("chainID", destinationChainID),
zap.Error(errClosing),
)
msg.OnFinishedHandling()
return
}

// Get the chain, if it exists
chain, exists := cr.chainHandlers[destinationChainID]
if !exists {
Expand Down Expand Up @@ -356,6 +381,7 @@ func (cr *ChainRouter) Shutdown(ctx context.Context) {
cr.lock.Lock()
prevChains := cr.chainHandlers
cr.chainHandlers = map[ids.ID]handler.Handler{}
cr.closing = true
cr.lock.Unlock()

for _, chain := range prevChains {
Expand Down Expand Up @@ -388,6 +414,13 @@ func (cr *ChainRouter) AddChain(ctx context.Context, chain handler.Handler) {
defer cr.lock.Unlock()

chainID := chain.Context().ChainID
if cr.closing {
cr.log.Debug("dropping add chain request",
zap.Stringer("chainID", chainID),
zap.Error(errClosing),
)
return
}
cr.log.Debug("registering chain with chain router",
zap.Stringer("chainID", chainID),
)
Expand Down Expand Up @@ -446,6 +479,14 @@ func (cr *ChainRouter) Connected(nodeID ids.NodeID, nodeVersion *version.Applica
cr.lock.Lock()
defer cr.lock.Unlock()

if cr.closing {
cr.log.Debug("dropping connected message",
zap.Stringer("nodeID", nodeID),
zap.Error(errClosing),
)
return
}

connectedPeer, exists := cr.peers[nodeID]
if !exists {
connectedPeer = &peer{
Expand Down Expand Up @@ -493,6 +534,14 @@ func (cr *ChainRouter) Disconnected(nodeID ids.NodeID) {
cr.lock.Lock()
defer cr.lock.Unlock()

if cr.closing {
cr.log.Debug("dropping disconnected message",
zap.Stringer("nodeID", nodeID),
zap.Error(errClosing),
)
return
}

peer := cr.peers[nodeID]
delete(cr.peers, nodeID)
if _, benched := cr.benched[nodeID]; benched {
Expand Down Expand Up @@ -522,6 +571,15 @@ func (cr *ChainRouter) Benched(chainID ids.ID, nodeID ids.NodeID) {
cr.lock.Lock()
defer cr.lock.Unlock()

if cr.closing {
cr.log.Debug("dropping benched message",
zap.Stringer("nodeID", nodeID),
zap.Stringer("chainID", chainID),
zap.Error(errClosing),
)
return
}

benchedChains, exists := cr.benched[nodeID]
benchedChains.Add(chainID)
cr.benched[nodeID] = benchedChains
Expand Down Expand Up @@ -554,6 +612,15 @@ func (cr *ChainRouter) Unbenched(chainID ids.ID, nodeID ids.NodeID) {
cr.lock.Lock()
defer cr.lock.Unlock()

if cr.closing {
cr.log.Debug("dropping unbenched message",
zap.Stringer("nodeID", nodeID),
zap.Stringer("chainID", chainID),
zap.Error(errClosing),
)
return
}

benchedChains := cr.benched[nodeID]
benchedChains.Remove(chainID)
if benchedChains.Len() != 0 {
Expand Down
111 changes: 111 additions & 0 deletions snow/networking/router/chain_router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,117 @@ func TestShutdown(t *testing.T) {
require.Less(shutdownDuration, 250*time.Millisecond)
}

// TestConnectedAfterShutdownErrorLogRegression is a regression test verifying
// that calling Connected on a ChainRouter after Shutdown is silently dropped
// rather than logged at error level. The router is initialized with
// logging.NoWarn, so any warn/error-level log emitted during the test fails it.
func TestConnectedAfterShutdownErrorLogRegression(t *testing.T) {
	require := require.New(t)

	snowCtx := snowtest.Context(t, snowtest.PChainID)
	chainCtx := snowtest.ConsensusContext(snowCtx)

	chainRouter := ChainRouter{}
	require.NoError(chainRouter.Initialize(
		ids.EmptyNodeID,
		logging.NoWarn{}, // If an error log is emitted, the test will fail
		nil,
		time.Second,
		set.Set[ids.ID]{},
		true,
		set.Set[ids.ID]{},
		nil,
		HealthConfig{},
		prometheus.NewRegistry(),
	))

	resourceTracker, err := tracker.NewResourceTracker(
		prometheus.NewRegistry(),
		resource.NoUsage,
		meter.ContinuousFactory{},
		time.Second,
	)
	require.NoError(err)

	p2pTracker, err := p2p.NewPeerTracker(
		logging.NoLog{},
		"",
		prometheus.NewRegistry(),
		nil,
		version.CurrentApp,
	)
	require.NoError(err)

	h, err := handler.New(
		chainCtx,
		nil,
		nil,
		time.Second,
		testThreadPoolSize,
		resourceTracker,
		validators.UnhandledSubnetConnector,
		subnets.New(chainCtx.NodeID, subnets.Config{}),
		commontracker.NewPeers(),
		p2pTracker,
		prometheus.NewRegistry(),
	)
	require.NoError(err)

	// Minimal engine stub: every lifecycle hook is a no-op so the handler can
	// start and stop cleanly without real consensus machinery.
	engine := common.EngineTest{
		T: t,
		StartF: func(context.Context, uint32) error {
			return nil
		},
		ContextF: func() *snow.ConsensusContext {
			return chainCtx
		},
		HaltF: func(context.Context) {},
		ShutdownF: func(context.Context) error {
			return nil
		},
		ConnectedF: func(context.Context, ids.NodeID, *version.Application) error {
			return nil
		},
	}
	engine.Default(true)
	engine.CantGossip = false

	bootstrapper := &common.BootstrapperTest{
		EngineTest: engine,
		CantClear:  true,
	}

	// The same stub engine backs both the Avalanche and Snowman engine slots.
	h.SetEngineManager(&handler.EngineManager{
		Avalanche: &handler.Engine{
			StateSyncer:  nil,
			Bootstrapper: bootstrapper,
			Consensus:    &engine,
		},
		Snowman: &handler.Engine{
			StateSyncer:  nil,
			Bootstrapper: bootstrapper,
			Consensus:    &engine,
		},
	})
	chainCtx.State.Set(snow.EngineState{
		Type:  engineType,
		State: snow.NormalOp, // assumed bootstrapping is done
	})

	chainRouter.AddChain(context.Background(), h)

	h.Start(context.Background(), false)

	chainRouter.Shutdown(context.Background())

	shutdownDuration, err := h.AwaitStopped(context.Background())
	require.NoError(err)
	require.GreaterOrEqual(shutdownDuration, time.Duration(0))

	// Regression check: Connected after Shutdown must be dropped without an
	// error-level log — an error log here would fail the test via the
	// logging.NoWarn logger configured above.
	chainRouter.Connected(
		ids.GenerateTestNodeID(),
		version.CurrentApp,
		ids.GenerateTestID(),
	)
}

func TestShutdownTimesOut(t *testing.T) {
require := require.New(t)

Expand Down

0 comments on commit ae35eeb

Please sign in to comment.