diff --git a/integration_testing/connection_controller_test.go b/integration_testing/connection_controller_test.go index f550ff42d..102d56587 100644 --- a/integration_testing/connection_controller_test.go +++ b/integration_testing/connection_controller_test.go @@ -288,6 +288,7 @@ func TestConnectionControllerHandshakeTimeouts(t *testing.T) { t.Logf("Test #1 passed | Successfuly disconnected node after version negotiation timeout") // Now let's try timing out verack exchange + node1.Params.VersionNegotiationTimeout = lib.DeSoTestnetParams.VersionNegotiationTimeout dbDir3 := getDirectory(t) defer os.RemoveAll(dbDir3) config3 := generateConfig(t, 18002, dbDir3, 10) @@ -304,6 +305,40 @@ func TestConnectionControllerHandshakeTimeouts(t *testing.T) { waitForEmptyRemoteNodeIndexer(t, node1) waitForEmptyRemoteNodeIndexer(t, node3) t.Logf("Test #2 passed | Successfuly disconnected node after verack exchange timeout") + + // Now let's try timing out handshake between two validators node4 and node5 + dbDir4 := getDirectory(t) + defer os.RemoveAll(dbDir4) + config4 := generateConfig(t, 18003, dbDir4, 10) + config4.SyncType = lib.NodeSyncTypeBlockSync + blsPriv4, err := bls.NewPrivateKey() + require.NoError(err) + config4.PosValidatorSeed = blsPriv4.ToString() + node4 := cmd.NewNode(config4) + node4.Params.UserAgent = "Node4" + node4.Params.ProtocolVersion = lib.ProtocolVersion2 + node4.Params.HandshakeTimeoutMicroSeconds = 0 + node4 = startNode(t, node4) + defer node4.Stop() + + dbDir5 := getDirectory(t) + defer os.RemoveAll(dbDir5) + config5 := generateConfig(t, 18004, dbDir5, 10) + config5.SyncType = lib.NodeSyncTypeBlockSync + blsPriv5, err := bls.NewPrivateKey() + require.NoError(err) + config5.PosValidatorSeed = blsPriv5.ToString() + node5 := cmd.NewNode(config5) + node5.Params.UserAgent = "Node5" + node5.Params.ProtocolVersion = lib.ProtocolVersion2 + node5 = startNode(t, node5) + defer node5.Stop() + + cc = node4.Server.GetConnectionController() + require.NoError(cc.CreateValidatorConnection(node5.Listeners[0].Addr().String(), blsPriv5.PublicKey())) + waitForEmptyRemoteNodeIndexer(t, node4) + waitForEmptyRemoteNodeIndexer(t, node5) + t.Logf("Test #3 passed | Successfuly disconnected validator node after handshake timeout") } func TestConnectionControllerValidatorDuplication(t *testing.T) { @@ -400,6 +435,106 @@ func TestConnectionControllerValidatorDuplication(t *testing.T) { t.Logf("Test #2 passed | Successfuly rejected duplicate validator connection with multiple outbound validators") } +func TestConnectionControllerProtocolDifference(t *testing.T) { + require := require.New(t) + + // Create a ProtocolVersion1 Node1 + dbDir1 := getDirectory(t) + defer os.RemoveAll(dbDir1) + config1 := generateConfig(t, 18000, dbDir1, 10) + config1.SyncType = lib.NodeSyncTypeBlockSync + node1 := cmd.NewNode(config1) + node1.Params.UserAgent = "Node1" + node1.Params.ProtocolVersion = lib.ProtocolVersion1 + node1 = startNode(t, node1) + defer node1.Stop() + + // Create a ProtocolVersion2 NonValidator Node2 + dbDir2 := getDirectory(t) + defer os.RemoveAll(dbDir2) + config2 := generateConfig(t, 18001, dbDir2, 10) + config2.SyncType = lib.NodeSyncTypeBlockSync + node2 := cmd.NewNode(config2) + node2.Params.UserAgent = "Node2" + node2.Params.ProtocolVersion = lib.ProtocolVersion2 + node2 = startNode(t, node2) + + // Create non-validator connection from Node1 to Node2 + cc := node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node2.Listeners[0].Addr().String())) + waitForNonValidatorOutboundConnection(t, node1, node2) + waitForNonValidatorInboundConnection(t, node2, node1) + t.Logf("Test #1 passed | Successfuly connected to a ProtocolVersion1 node with a ProtocolVersion2 non-validator") + + // Create a ProtocolVersion2 Validator Node3 + dbDir3 := getDirectory(t) + defer os.RemoveAll(dbDir3) + config3 := generateConfig(t, 18002, dbDir3, 10) + config3.SyncType = lib.NodeSyncTypeBlockSync + blsPriv3, err := bls.NewPrivateKey() + require.NoError(err) + config3.PosValidatorSeed = blsPriv3.ToString() + node3 := cmd.NewNode(config3) + node3.Params.UserAgent = "Node3" + node3.Params.ProtocolVersion = lib.ProtocolVersion2 + node3 = startNode(t, node3) + + // Create validator connection from Node1 to Node3 + require.NoError(cc.CreateValidatorConnection(node3.Listeners[0].Addr().String(), blsPriv3.PublicKey())) + waitForValidatorConnection(t, node1, node3) + waitForNonValidatorInboundConnection(t, node3, node1) + t.Logf("Test #2 passed | Successfuly connected to a ProtocolVersion1 node with a ProtocolVersion2 validator") + + node2.Stop() + node3.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + + // Create a ProtocolVersion2 validator Node4 + dbDir4 := getDirectory(t) + defer os.RemoveAll(dbDir4) + config4 := generateConfig(t, 18003, dbDir4, 10) + config4.SyncType = lib.NodeSyncTypeBlockSync + blsPriv4, err := bls.NewPrivateKey() + require.NoError(err) + config4.PosValidatorSeed = blsPriv4.ToString() + node4 := cmd.NewNode(config4) + node4.Params.UserAgent = "Node4" + node4.Params.ProtocolVersion = lib.ProtocolVersion2 + node4 = startNode(t, node4) + defer node4.Stop() + + // Attempt to create non-validator connection from Node4 to Node1 + cc = node4.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node4) + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #3 passed | Successfuly rejected outbound connection from ProtocolVersion2 node to ProtcolVersion1 node") + + // Attempt to create validator connection from Node4 to Node1 + require.NoError(cc.CreateValidatorConnection(node1.Listeners[0].Addr().String(), blsPriv4.PublicKey())) + waitForEmptyRemoteNodeIndexer(t, node4) + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #4 passed | Successfuly rejected validator connection from ProtocolVersion2 node to ProtcolVersion1 node") + + // Create a ProtocolVersion2 non-validator Node5 + dbDir5 := getDirectory(t) + defer os.RemoveAll(dbDir5) + config5 := generateConfig(t, 18004, dbDir5, 10) + config5.SyncType = lib.NodeSyncTypeBlockSync + node5 := cmd.NewNode(config5) + node5.Params.UserAgent = "Node5" + node5.Params.ProtocolVersion = lib.ProtocolVersion2 + node5 = startNode(t, node5) + defer node5.Stop() + + // Attempt to create non-validator connection from Node5 to Node1 + cc = node5.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node5) + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #5 passed | Successfuly rejected outbound connection from ProtocolVersion2 node to ProtcolVersion1 node") +} + func waitForValidatorConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { userAgentN1 := node1.Params.UserAgent userAgentN2 := node2.Params.UserAgent diff --git a/lib/remote_node_manager.go b/lib/remote_node_manager.go index fa5634385..fb269d072 100644 --- a/lib/remote_node_manager.go +++ b/lib/remote_node_manager.go @@ -8,6 +8,7 @@ import ( "github.com/golang/glog" "github.com/pkg/errors" "net" + "sync" "sync/atomic" ) @@ -15,6 +16,8 @@ import ( // and stopping remote node connections. It is also responsible for organizing the remote nodes into indices for easy // access, through the RemoteNodeIndexer. type RemoteNodeManager struct { + mtx sync.Mutex + // remoteNodeIndexer is a structure that stores and indexes all created remote nodes. remoteNodeIndexer *RemoteNodeIndexer @@ -90,6 +93,9 @@ func (manager *RemoteNodeManager) DisconnectById(id RemoteNodeId) { } func (manager *RemoteNodeManager) removeRemoteNodeFromIndexer(rn *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if rn == nil { return } @@ -210,6 +216,9 @@ func (manager *RemoteNodeManager) AttachOutboundConnection(conn net.Conn, na *wi // ########################### func (manager *RemoteNodeManager) setRemoteNode(rn *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if rn == nil { return } @@ -218,36 +227,39 @@ func (manager *RemoteNodeManager) setRemoteNode(rn *RemoteNode) { } func (manager *RemoteNodeManager) SetNonValidator(rn *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if rn == nil { return } if rn.IsOutbound() { manager.GetNonValidatorOutboundIndex().Set(rn.GetId(), rn) - } else if rn.IsInbound() { - manager.GetNonValidatorInboundIndex().Set(rn.GetId(), rn) } else { - manager.Disconnect(rn) - return + manager.GetNonValidatorInboundIndex().Set(rn.GetId(), rn) } - - manager.UnsetValidator(rn) } func (manager *RemoteNodeManager) SetValidator(remoteNode *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if remoteNode == nil { return } pk := remoteNode.GetValidatorPublicKey() if pk == nil { - manager.Disconnect(remoteNode) return } manager.GetValidatorIndex().Set(pk.Serialize(), remoteNode) } func (manager *RemoteNodeManager) UnsetValidator(remoteNode *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if remoteNode == nil { return } @@ -260,17 +272,17 @@ func (manager *RemoteNodeManager) UnsetValidator(remoteNode *RemoteNode) { } func (manager *RemoteNodeManager) UnsetNonValidator(rn *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if rn == nil { return } if rn.IsOutbound() { manager.GetNonValidatorOutboundIndex().Remove(rn.GetId()) - } else if rn.IsInbound() { - manager.GetNonValidatorInboundIndex().Remove(rn.GetId()) } else { - glog.Errorf("RemoteNodeManager.UnsetNonValidator: RemoteNode is not outbound or inbound. Disconnecting.") - manager.Disconnect(rn) + manager.GetNonValidatorInboundIndex().Remove(rn.GetId()) } }