diff --git a/integration_testing/connection_controller_routines_test.go b/integration_testing/connection_controller_routines_test.go index 67bd28894..521df04a6 100644 --- a/integration_testing/connection_controller_routines_test.go +++ b/integration_testing/connection_controller_routines_test.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/deso-protocol/core/bls" "github.com/deso-protocol/core/cmd" + "github.com/deso-protocol/core/collections" "github.com/deso-protocol/core/lib" "github.com/stretchr/testify/require" "testing" @@ -76,11 +77,11 @@ func TestConnectionControllerInitiatePersistentConnections(t *testing.T) { waitForCountRemoteNodeIndexer(t, node3, 1, 1, 0, 0) waitForCountRemoteNodeIndexer(t, node4, 2, 2, 0, 0) waitForCountRemoteNodeIndexer(t, node5, 2, 2, 0, 0) - node6.Stop() node2.Stop() node3.Stop() node4.Stop() node5.Stop() + node6.Stop() t.Logf("Test #2 passed | Successfully run validator node6 with --connect-ips set to node2, node3, node4, node5") } @@ -254,8 +255,8 @@ func TestConnectionControllerValidatorConnector(t *testing.T) { func TestConnectionControllerNonValidatorConnector(t *testing.T) { require := require.New(t) - // Spawn 4 non-validators node1, node2, node3, node4. Set node1's targetOutboundPeers to 3. Then make node1 - // create outbound connections to node2, node3, and node4, as well as 2 attempted persistent connections to + // Spawn 6 non-validators node1, node2, node3, node4, node5, node6. Set node1's targetOutboundPeers to 3. Then make + // node1 create outbound connections to node2, node3, and node4, as well as 2 attempted persistent connections to // non-existing ips. node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") node1.Config.TargetOutboundPeers = 3 @@ -285,27 +286,41 @@ func TestConnectionControllerNonValidatorConnector(t *testing.T) { waitForCountRemoteNodeIndexer(t, node1, 3, 0, 3, 0) } +func TestConnectionControllerNonValidatorCircularConnectIps(t *testing.T) { + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + + node1.Config.ConnectIPs = []string{"127.0.0.1:18001"} + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + + node1 = startNode(t, node1) + node2 = startNode(t, node2) + defer node1.Stop() + defer node2.Stop() + + waitForCountRemoteNodeIndexer(t, node1, 2, 0, 1, 1) + waitForCountRemoteNodeIndexer(t, node2, 2, 0, 1, 1) +} + func setGetActiveValidatorImplWithValidatorNodes(t *testing.T, validators ...*cmd.Node) { require := require.New(t) - var err error - mapping := make(map[bls.SerializedPublicKey]*lib.ValidatorEntry) + mapping := collections.NewConcurrentMap[bls.SerializedPublicKey, *lib.ValidatorEntry]() for _, validator := range validators { seed := validator.Config.PosValidatorSeed if seed == "" { t.Fatalf("Validator node %s does not have a PosValidatorSeed set", validator.Params.UserAgent) } - blsPriv := &bls.PrivateKey{} - blsPriv, err = blsPriv.FromString(seed) + keystore, err := lib.NewBLSKeystore(seed) require.NoError(err) - mapping[blsPriv.PublicKey().Serialize()] = createSimpleValidatorEntry(validator) + mapping.Set(keystore.GetSigner().GetPublicKey().Serialize(), createSimpleValidatorEntry(validator)) } - setGetActiveValidatorImpl(func() map[bls.SerializedPublicKey]*lib.ValidatorEntry { + setGetActiveValidatorImpl(func() *collections.ConcurrentMap[bls.SerializedPublicKey, *lib.ValidatorEntry] { return mapping }) } -func setGetActiveValidatorImpl(mapping func() map[bls.SerializedPublicKey]*lib.ValidatorEntry) { +func 
setGetActiveValidatorImpl(mapping func() *collections.ConcurrentMap[bls.SerializedPublicKey, *lib.ValidatorEntry]) { lib.GetActiveValidatorImpl = mapping } @@ -348,11 +363,8 @@ func waitForMinNonValidatorCountRemoteNodeIndexer(t *testing.T, node *cmd.Node, userAgent := node.Params.UserAgent rnManager := node.Server.GetConnectionController().GetRemoteNodeManager() condition := func() bool { - if true != checkRemoteNodeIndexerMinNonValidatorCount(rnManager, allCount, validatorCount, - minNonValidatorOutboundCount, minNonValidatorInboundCount) { - return false - } - return true + return checkRemoteNodeIndexerMinNonValidatorCount(rnManager, allCount, validatorCount, + minNonValidatorOutboundCount, minNonValidatorInboundCount) } waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to have at least %d non-validator outbound nodes and %d non-validator inbound nodes", userAgent, minNonValidatorOutboundCount, minNonValidatorInboundCount), condition) diff --git a/integration_testing/connection_controller_utils_test.go b/integration_testing/connection_controller_utils_test.go index 5dfbb7414..f4a46df75 100644 --- a/integration_testing/connection_controller_utils_test.go +++ b/integration_testing/connection_controller_utils_test.go @@ -7,6 +7,7 @@ import ( "github.com/deso-protocol/core/lib" "os" "testing" + "time" ) func waitForValidatorConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { @@ -26,7 +27,7 @@ func waitForValidatorConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) } return true } - waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to outbound non-validator Node (%s)", userAgentN1, userAgentN2), n1ValidatedN2) + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to validator Node (%s)", userAgentN1, userAgentN2), n1ValidatedN2) } func waitForNonValidatorOutboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { @@ -222,5 +223,7 @@ func spawnValidatorNodeProtocol2(t *testing.T, port uint32, id string, blsPriv * node := cmd.NewNode(config) node.Params.UserAgent = id node.Params.ProtocolVersion = lib.ProtocolVersion2 + node.Params.VersionNegotiationTimeout = 1 * time.Second + node.Params.VerackNegotiationTimeout = 1 * time.Second return node } diff --git a/lib/connection_controller.go b/lib/connection_controller.go index 1257cb189..d15c7bcd0 100644 --- a/lib/connection_controller.go +++ b/lib/connection_controller.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcd/addrmgr" "github.com/btcsuite/btcd/wire" "github.com/deso-protocol/core/bls" + "github.com/deso-protocol/core/collections" "github.com/golang/glog" "github.com/pkg/errors" "net" @@ -13,12 +14,12 @@ import ( "time" ) -type GetActiveValidatorsFunc func() map[bls.SerializedPublicKey]*ValidatorEntry +type GetActiveValidatorsFunc func() *collections.ConcurrentMap[bls.SerializedPublicKey, *ValidatorEntry] var GetActiveValidatorImpl GetActiveValidatorsFunc = BasicGetActiveValidators -func BasicGetActiveValidators() map[bls.SerializedPublicKey]*ValidatorEntry { - return make(map[bls.SerializedPublicKey]*ValidatorEntry) +func BasicGetActiveValidators() *collections.ConcurrentMap[bls.SerializedPublicKey, *ValidatorEntry] { + return collections.NewConcurrentMap[bls.SerializedPublicKey, *ValidatorEntry]() } // ConnectionController is a structure that oversees all connections to remote nodes. 
It is responsible for kicking off
@@ -78,14 +79,15 @@ func NewConnectionController(params *DeSoParams, cmgr *ConnectionManager, handsh
 }
 
 func (cc *ConnectionController) Start() {
-	cc.startGroup.Add(2)
+	cc.startGroup.Add(3)
 	cc.initiatePersistentConnections()
 
 	// Start the validator connector
 	go cc.startValidatorConnector()
 	go cc.startNonValidatorConnector()
+	go cc.startRemoteNodeCleanup()
 	cc.startGroup.Wait()
-	cc.exitGroup.Add(2)
+	cc.exitGroup.Add(3)
 }
 
 func (cc *ConnectionController) Stop() {
@@ -151,6 +153,20 @@ func (cc *ConnectionController) startNonValidatorConnector() {
 	}
 }
 
+func (cc *ConnectionController) startRemoteNodeCleanup() {
+	cc.startGroup.Done()
+
+	for {
+		select {
+		case <-cc.exitChan:
+			cc.exitGroup.Done()
+			return
+		case <-time.After(1 * time.Second):
+			cc.rnManager.Cleanup()
+		}
+	}
+}
+
 // ###########################
 // ## Handlers (Peer, DeSoMessage)
 // ###########################
@@ -199,7 +216,7 @@ func (cc *ConnectionController) _handleNewConnectionMessage(origin *Peer, desoMs
 		remoteNode, err = cc.processInboundConnection(msg.Connection)
 		if err != nil {
 			glog.Errorf("ConnectionController.handleNewConnectionMessage: Problem handling inbound connection: %v", err)
-			msg.Connection.Close()
+			cc.cleanupFailedInboundConnection(remoteNode, msg.Connection)
 			return
 		}
 	case ConnectionTypeOutbound:
@@ -215,6 +232,13 @@
 	cc.handshake.InitiateHandshake(remoteNode)
 }
 
+func (cc *ConnectionController) cleanupFailedInboundConnection(remoteNode *RemoteNode, connection Connection) {
+	if remoteNode != nil {
+		cc.rnManager.Disconnect(remoteNode)
+	}
+	connection.Close()
+}
+
 func (cc *ConnectionController) cleanupFailedOutboundConnection(connection Connection) {
 	oc, ok := connection.(*outboundConnection)
 	if !ok {
@@ -226,6 +250,7 @@ func (cc *ConnectionController) cleanupFailedOutboundConnection(connection Conne
 	if rn != nil {
 		cc.rnManager.Disconnect(rn)
 	}
+	oc.Close()
 	cc.cmgr.RemoveAttemptedOutboundAddrs(oc.address)
 }
 
@@ -235,14 +260,14 @@ func (cc *ConnectionController) cleanupFailedOutboundConnection(connection Conne
 
 // refreshValidatorIndex re-indexes validators based on the activeValidatorsMap. It is called periodically by the
 // validator connector.
-func (cc *ConnectionController) refreshValidatorIndex(activeValidatorsMap map[bls.SerializedPublicKey]*ValidatorEntry) {
+func (cc *ConnectionController) refreshValidatorIndex(activeValidatorsMap *collections.ConcurrentMap[bls.SerializedPublicKey, *ValidatorEntry]) {
 	// De-index inactive validators. We skip any checks regarding RemoteNodes connection status, nor do we verify whether
 	// de-indexing the validator would result in an excess number of outbound/inbound connections. Any excess connections
 	// will be cleaned up by the peer connector.
 	validatorRemoteNodeMap := cc.rnManager.GetValidatorIndex().Copy()
 	for pk, rn := range validatorRemoteNodeMap {
 		// If the validator is no longer active, de-index it.
-		if _, ok := activeValidatorsMap[pk]; !ok {
+		if _, ok := activeValidatorsMap.Get(pk); !ok {
 			cc.rnManager.SetNonValidator(rn)
 			cc.rnManager.UnsetValidator(rn)
 		}
@@ -267,7 +292,7 @@ func (cc *ConnectionController) refreshValidatorIndex(activeValidatorsMap map[bl
 		}
 
 		// If the RemoteNode turns out to be in the validator set, index it. 
-		if _, ok := activeValidatorsMap[pk.Serialize()]; ok {
+		if _, ok := activeValidatorsMap.Get(pk.Serialize()); ok {
 			cc.rnManager.SetValidator(rn)
 			cc.rnManager.UnsetNonValidator(rn)
 		}
@@ -276,13 +301,14 @@ func (cc *ConnectionController) refreshValidatorIndex(activeValidatorsMap map[bl
 
 // connectValidators attempts to connect to all active validators that are not already connected. It is called
 // periodically by the validator connector.
-func (cc *ConnectionController) connectValidators(activeValidatorsMap map[bls.SerializedPublicKey]*ValidatorEntry) {
+func (cc *ConnectionController) connectValidators(activeValidatorsMap *collections.ConcurrentMap[bls.SerializedPublicKey, *ValidatorEntry]) {
 	// Look through the active validators and connect to any that we're not already connected to.
 	if cc.blsKeystore == nil {
 		return
 	}
 
-	for pk, validator := range activeValidatorsMap {
+	validators := activeValidatorsMap.Copy()
+	for pk, validator := range validators {
 		_, exists := cc.rnManager.GetValidatorIndex().Get(pk)
 		// If we're already connected to the validator, continue.
 		if exists {
@@ -523,30 +549,41 @@ func (cc *ConnectionController) processOutboundConnection(conn Connection) (*Rem
 		cc.AddrMgr.Good(oc.address)
 	}
 
-	// if this is a non-persistent outbound peer, and we already have enough outbound peers, then don't bother adding this one.
-	if !oc.isPersistent && cc.enoughNonValidatorOutboundConnections() {
-		return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Connected to maximum number of outbound "+
-			"peers (%d)", cc.targetOutboundPeers)
-	}
-
-	// If this is a non-persistent outbound peer and the group key overlaps with another peer we're already connected to then
-	// abort mission. We only connect to one peer per IP group in order to prevent Sybil attacks.
-	if !oc.isPersistent && cc.cmgr.IsFromRedundantOutboundIPAddress(oc.address) {
-		return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Rejecting OUTBOUND NON-PERSISTENT "+
-			"connection with redundant group key (%s).", addrmgr.GroupKey(oc.address))
-	}
-
 	na, err := cc.ConvertIPStringToNetAddress(oc.connection.RemoteAddr().String())
 	if err != nil {
 		return nil, errors.Wrapf(err, "ConnectionController.handleOutboundConnection: Problem calling ipToNetAddr "+
 			"for addr: (%s)", oc.connection.RemoteAddr().String())
 	}
 
+	// Attach the connection before additional validation steps because it is already established.
 	remoteNode, err := cc.rnManager.AttachOutboundConnection(oc.connection, na, oc.attemptId, oc.isPersistent)
 	if remoteNode == nil || err != nil {
 		return nil, errors.Wrapf(err, "ConnectionController.handleOutboundConnection: Problem calling rnManager.AttachOutboundConnection "+
 			"for addr: (%s)", oc.connection.RemoteAddr().String())
 	}
+
+	// If this is a persistent remote node or a validator, we don't need to do any extra connection validation.
+	if remoteNode.IsPersistent() || remoteNode.GetValidatorPublicKey() != nil {
+		return remoteNode, nil
+	}
+
+	// If we get here, it means we're dealing with a non-persistent, non-validator remote node. We perform additional
+	// connection validation.
+
+	// If we already have enough outbound peers, then don't bother adding this one.
+	if cc.enoughNonValidatorOutboundConnections() {
+		return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Connected to maximum number of outbound "+
+			"peers (%d)", cc.targetOutboundPeers)
+	}
+
+	// If the group key overlaps with another peer we're already connected to then abort mission. 
We only connect to + // one peer per IP group in order to prevent Sybil attacks. + if cc.cmgr.IsFromRedundantOutboundIPAddress(oc.address) { + return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Rejecting OUTBOUND NON-PERSISTENT "+ + "connection with redundant group key (%s).", addrmgr.GroupKey(oc.address)) + } + cc.cmgr.AddToGroupKey(na) + return remoteNode, nil } diff --git a/lib/connection_manager.go b/lib/connection_manager.go index f1b1ca52d..1ba4bf8f1 100644 --- a/lib/connection_manager.go +++ b/lib/connection_manager.go @@ -168,6 +168,10 @@ func NewConnectionManager( // Check if the address passed shares a group with any addresses already in our data structures. func (cmgr *ConnectionManager) IsFromRedundantOutboundIPAddress(na *wire.NetAddress) bool { groupKey := addrmgr.GroupKey(na) + // For the sake of running multiple nodes on the same machine, we allow localhost connections. + if groupKey == "local" { + return false + } cmgr.mtxOutboundConnIPGroups.Lock() numGroupsForKey := cmgr.outboundConnIPGroups[groupKey] @@ -185,7 +189,7 @@ func (cmgr *ConnectionManager) IsFromRedundantOutboundIPAddress(na *wire.NetAddr return true } -func (cmgr *ConnectionManager) addToGroupKey(na *wire.NetAddress) { +func (cmgr *ConnectionManager) AddToGroupKey(na *wire.NetAddress) { groupKey := addrmgr.GroupKey(na) cmgr.mtxOutboundConnIPGroups.Lock() @@ -429,7 +433,6 @@ func (cmgr *ConnectionManager) addPeer(pp *Peer) { // number of outbound peers. Also add the peer's address to // our map. if _, ok := peerList[pp.ID]; !ok { - cmgr.addToGroupKey(pp.netAddr) atomic.AddUint32(&cmgr.numOutboundPeers, 1) cmgr.mtxAddrsMaps.Lock() @@ -528,16 +531,6 @@ func (cmgr *ConnectionManager) _logOutboundPeerData() { numInboundPeers := int(atomic.LoadUint32(&cmgr.numInboundPeers)) numPersistentPeers := int(atomic.LoadUint32(&cmgr.numPersistentPeers)) glog.V(1).Infof("Num peers: OUTBOUND(%d) INBOUND(%d) PERSISTENT(%d)", numOutboundPeers, numInboundPeers, numPersistentPeers) - - cmgr.mtxOutboundConnIPGroups.Lock() - for _, vv := range cmgr.outboundConnIPGroups { - if vv != 0 && vv != 1 { - glog.V(1).Infof("_logOutboundPeerData: Peer group count != (0 or 1). "+ - "Is (%d) instead. 
This "+
-				"should never happen.", vv)
-		}
-	}
-	cmgr.mtxOutboundConnIPGroups.Unlock()
 }
 
 func (cmgr *ConnectionManager) AddTimeSample(addrStr string, timeSample time.Time) {
diff --git a/lib/network_connection.go b/lib/network_connection.go
index eb6d4ab55..ffb0bb1f1 100644
--- a/lib/network_connection.go
+++ b/lib/network_connection.go
@@ -33,7 +33,9 @@ func (oc *outboundConnection) Close() {
 	if oc.terminated {
 		return
 	}
-	oc.connection.Close()
+	if oc.connection != nil {
+		oc.connection.Close()
+	}
 	oc.terminated = true
 }
 
@@ -58,7 +60,9 @@ func (ic *inboundConnection) Close() {
 		return
 	}
 
-	ic.connection.Close()
+	if ic.connection != nil {
+		ic.connection.Close()
+	}
 	ic.terminated = true
 }
 
diff --git a/lib/remote_node.go b/lib/remote_node.go
index 5ba651f3f..4ca6f8c12 100644
--- a/lib/remote_node.go
+++ b/lib/remote_node.go
@@ -219,6 +219,14 @@ func (rn *RemoteNode) IsConnected() bool {
 	return rn.connectionStatus == RemoteNodeStatus_Connected
 }
 
+func (rn *RemoteNode) IsVersionSent() bool {
+	return rn.connectionStatus == RemoteNodeStatus_VersionSent
+}
+
+func (rn *RemoteNode) IsVerackSent() bool {
+	return rn.connectionStatus == RemoteNodeStatus_VerackSent
+}
+
 func (rn *RemoteNode) IsHandshakeCompleted() bool {
 	return rn.connectionStatus == RemoteNodeStatus_HandshakeCompleted
 }
@@ -344,9 +352,9 @@ func (rn *RemoteNode) InitiateHandshake(nonce uint64) error {
 		return fmt.Errorf("InitiateHandshake: Remote node is not connected")
 	}
 
+	versionTimeExpected := time.Now().Add(rn.params.VersionNegotiationTimeout)
+	rn.versionTimeExpected = &versionTimeExpected
 	if rn.GetPeer().IsOutbound() {
-		versionTimeExpected := time.Now().Add(rn.params.VersionNegotiationTimeout)
-		rn.versionTimeExpected = &versionTimeExpected
 		if err := rn.sendVersionMessage(nonce); err != nil {
 			return fmt.Errorf("InitiateHandshake: Problem sending version message to peer (id= %d): %v", rn.id, err)
 		}
@@ -397,6 +405,16 @@ func (rn *RemoteNode) newVersionMessage(nonce uint64) *MsgDeSoVersion {
 	return ver
 }
 
+func (rn *RemoteNode) IsTimedOut() bool {
+	if rn.IsConnected() || rn.IsVersionSent() {
+		return rn.versionTimeExpected != nil && rn.versionTimeExpected.Before(time.Now())
+	}
+	if rn.IsVerackSent() {
+		return rn.verackTimeExpected != nil && rn.verackTimeExpected.Before(time.Now())
+	}
+	return false
+}
+
 // HandleVersionMessage is called upon receiving a version message from the RemoteNode's peer. The peer may be the one
 // initiating the handshake, in which case, we should respond with our own version message. To do this, we pass the
 // responseNonce to this function, which we will use in our response version message.
@@ -404,7 +422,7 @@ func (rn *RemoteNode) HandleVersionMessage(verMsg *MsgDeSoVersion, responseNonce
 	rn.mtx.Lock()
 	defer rn.mtx.Unlock()
 
-	if rn.connectionStatus != RemoteNodeStatus_Connected && rn.connectionStatus != RemoteNodeStatus_VersionSent {
+	if !rn.IsConnected() && !rn.IsVersionSent() {
 		return fmt.Errorf("HandleVersionMessage: RemoteNode is not connected or version exchange has already "+
 			"been completed, connectionStatus: %v", rn.connectionStatus)
 	}
@@ -416,7 +434,7 @@ func (rn *RemoteNode) HandleVersionMessage(verMsg *MsgDeSoVersion, responseNonce
 	}
 
 	// Verify that the peer's version message is sent within the version negotiation timeout.
-	if rn.versionTimeExpected != nil && rn.versionTimeExpected.Before(time.Now()) {
+	if rn.versionTimeExpected.Before(time.Now()) {
 		return fmt.Errorf("RemoteNode.HandleVersionMessage: Requesting disconnect for id: (%v) "+
 			"version timeout. 
Time expected: %v, now: %v", rn.id, rn.versionTimeExpected.UnixMicro(), time.Now().UnixMicro())
 	}
diff --git a/lib/remote_node_manager.go b/lib/remote_node_manager.go
index 75d29d897..bff73a985 100644
--- a/lib/remote_node_manager.go
+++ b/lib/remote_node_manager.go
@@ -126,6 +126,17 @@ func (manager *RemoteNodeManager) SendMessage(rn *RemoteNode, desoMessage DeSoMe
 	return rn.SendMessage(desoMessage)
 }
 
+func (manager *RemoteNodeManager) Cleanup() {
+	manager.mtx.Lock()
+	defer manager.mtx.Unlock()
+
+	for _, rn := range manager.GetAllRemoteNodes().GetAll() {
+		if rn.IsTimedOut() {
+			manager.Disconnect(rn)
+		}
+	}
+}
+
 // ###########################
 // ## Create RemoteNode
 // ###########################
@@ -140,7 +151,7 @@
 	remoteNode := manager.newRemoteNode(publicKey)
-	if err := remoteNode.DialPersistentOutboundConnection(netAddr); err != nil {
-		return errors.Wrapf(err, "RemoteNodeManager.CreateValidatorConnection: Problem calling DialPersistentOutboundConnection "+
+	if err := remoteNode.DialOutboundConnection(netAddr); err != nil {
+		return errors.Wrapf(err, "RemoteNodeManager.CreateValidatorConnection: Problem calling DialOutboundConnection "+
 			"for addr: (%s:%v)", netAddr.IP.String(), netAddr.Port)
 	}
@@ -184,7 +195,7 @@ func (manager *RemoteNodeManager) AttachInboundConnection(conn net.Conn,
 
 	remoteNode := manager.newRemoteNode(nil)
 	if err := remoteNode.AttachInboundConnection(conn, na); err != nil {
-		return nil, errors.Wrapf(err, "RemoteNodeManager.AttachInboundConnection: Problem calling AttachInboundConnection "+
+		return remoteNode, errors.Wrapf(err, "RemoteNodeManager.AttachInboundConnection: Problem calling AttachInboundConnection "+
 			"for addr: (%s)", conn.RemoteAddr().String())
 	}
 
diff --git a/lib/server.go b/lib/server.go
index 4d5138f77..a4fa28376 100644
--- a/lib/server.go
+++ b/lib/server.go
@@ -2547,6 +2547,9 @@ func (srv *Server) Stop() {
 	srv.cmgr.Stop()
 	glog.Infof(CLog(Yellow, "Server.Stop: Closed the ConnectionManger"))
 
+	srv.connectionController.Stop()
+	glog.Infof(CLog(Yellow, "Server.Stop: Closed the ConnectionController"))
+
 	// Stop the miner if we have one running.
 	if srv.miner != nil {
 		srv.miner.Stop()
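
Note on the collections.ConcurrentMap migration above: the diff only exercises NewConcurrentMap, Set, Get, and Copy. For readers without the collections package at hand, here is a minimal sketch of a generic map with that surface, assuming a plain RWMutex guard; the real implementation in github.com/deso-protocol/core/collections may differ:

package collections

import "sync"

// ConcurrentMap sketches the mutex-guarded generic map the call sites above
// rely on. Only the methods used in this diff are shown.
type ConcurrentMap[K comparable, V any] struct {
	mtx sync.RWMutex
	m   map[K]V
}

func NewConcurrentMap[K comparable, V any]() *ConcurrentMap[K, V] {
	return &ConcurrentMap[K, V]{m: make(map[K]V)}
}

// Set stores val under key, overwriting any existing entry.
func (cm *ConcurrentMap[K, V]) Set(key K, val V) {
	cm.mtx.Lock()
	defer cm.mtx.Unlock()
	cm.m[key] = val
}

// Get returns the value stored under key and whether it was present.
func (cm *ConcurrentMap[K, V]) Get(key K) (V, bool) {
	cm.mtx.RLock()
	defer cm.mtx.RUnlock()
	val, ok := cm.m[key]
	return val, ok
}

// Copy returns a snapshot of the map. connectValidators ranges over such a
// snapshot so it does not hold the lock while dialing validators.
func (cm *ConcurrentMap[K, V]) Copy() map[K]V {
	cm.mtx.RLock()
	defer cm.mtx.RUnlock()
	out := make(map[K]V, len(cm.m))
	for k, v := range cm.m {
		out[k] = v
	}
	return out
}

The Copy-then-range pattern is the point of the change: ranging over a plain shared map from the validator connector goroutine would race with concurrent updates to the active validator set.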
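
The integration tests funnel their polling through a waitForCondition(t, description, condition) helper that lives elsewhere in integration_testing and is not part of this diff. A plausible sketch of its shape, assuming a poll-until-deadline loop (the 30-second deadline and 100ms interval are illustrative, not the repository's actual values):

package integration_testing

import (
	"testing"
	"time"
)

// waitForCondition polls condition until it returns true, failing the test
// with the supplied description if the deadline passes first.
func waitForCondition(t *testing.T, description string, condition func() bool) {
	deadline := time.After(30 * time.Second)
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-deadline:
			t.Fatalf("Condition timed out: %s", description)
		case <-ticker.C:
			if condition() {
				return
			}
		}
	}
}

The 1-second Version/Verack negotiation timeouts set in spawnValidatorNodeProtocol2 keep these polls short: a peer that stalls mid-handshake gets disconnected by the new cleanup loop well before the helper's deadline.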
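
The early return for groupKey == "local" in IsFromRedundantOutboundIPAddress relies on btcd's addrmgr.GroupKey classifying loopback addresses into the "local" group, which is what lets TestConnectionControllerNonValidatorCircularConnectIps run two nodes on 127.0.0.1 without tripping the one-peer-per-IP-group rule. A quick standalone check of that assumption:

package main

import (
	"fmt"
	"net"

	"github.com/btcsuite/btcd/addrmgr"
	"github.com/btcsuite/btcd/wire"
)

func main() {
	// Loopback addresses map to the "local" group key, so multiple test nodes
	// on 127.0.0.1 are not treated as redundant outbound peers.
	na := wire.NewNetAddressIPPort(net.ParseIP("127.0.0.1"), 18000, wire.SFNodeNetwork)
	fmt.Println(addrmgr.GroupKey(na)) // prints: local
}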