diff --git a/integration_testing/connection_controller_routines_test.go b/integration_testing/connection_controller_routines_test.go new file mode 100644 index 000000000..582031bd8 --- /dev/null +++ b/integration_testing/connection_controller_routines_test.go @@ -0,0 +1,193 @@ +package integration_testing + +import ( + "fmt" + "github.com/deso-protocol/core/bls" + "github.com/deso-protocol/core/cmd" + "github.com/deso-protocol/core/collections" + "github.com/deso-protocol/core/lib" + "github.com/stretchr/testify/require" + "testing" +) + +func TestConnectionControllerInitiatePersistentConnections(t *testing.T) { + require := require.New(t) + t.Cleanup(func() { + setGetActiveValidatorImpl(lib.BasicGetActiveValidators) + }) + + // NonValidator Node1 will set its --connect-ips to two non-validators node2 and node3, + // and two validators node4 and node5. + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + node3 := spawnNonValidatorNodeProtocol2(t, 18002, "node3") + blsPriv4, err := bls.NewPrivateKey() + require.NoError(err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsPriv4) + blsPriv5, err := bls.NewPrivateKey() + require.NoError(err) + node5 := spawnValidatorNodeProtocol2(t, 18004, "node5", blsPriv5) + + node2 = startNode(t, node2) + node3 = startNode(t, node3) + node4 = startNode(t, node4) + node5 = startNode(t, node5) + + setGetActiveValidatorImplWithValidatorNodes(t, node4, node5) + + node1.Config.ConnectIPs = []string{ + node2.Listeners[0].Addr().String(), + node3.Listeners[0].Addr().String(), + node4.Listeners[0].Addr().String(), + node5.Listeners[0].Addr().String(), + } + node1 = startNode(t, node1) + waitForNonValidatorOutboundConnection(t, node1, node2) + waitForNonValidatorOutboundConnection(t, node1, node3) + waitForValidatorConnection(t, node1, node4) + waitForValidatorConnection(t, node1, node5) + waitForValidatorConnection(t, node4, node5) + waitForCountRemoteNodeIndexer(t, node1, 4, 2, 2, 0) + waitForCountRemoteNodeIndexer(t, node2, 1, 0, 0, 1) + waitForCountRemoteNodeIndexer(t, node3, 1, 0, 0, 1) + waitForCountRemoteNodeIndexer(t, node4, 2, 1, 0, 1) + waitForCountRemoteNodeIndexer(t, node5, 2, 1, 0, 1) + node1.Stop() + t.Logf("Test #1 passed | Successfully run non-validator node1 with --connect-ips set to node2, node3, node4, node5") + + // Now try again with a validator node6, with connect-ips set to node2, node3, node4, node5. + blsPriv6, err := bls.NewPrivateKey() + require.NoError(err) + node6 := spawnValidatorNodeProtocol2(t, 18005, "node6", blsPriv6) + node6.Config.ConnectIPs = []string{ + node2.Listeners[0].Addr().String(), + node3.Listeners[0].Addr().String(), + node4.Listeners[0].Addr().String(), + node5.Listeners[0].Addr().String(), + } + node6 = startNode(t, node6) + setGetActiveValidatorImplWithValidatorNodes(t, node4, node5, node6) + waitForNonValidatorOutboundConnection(t, node6, node2) + waitForNonValidatorOutboundConnection(t, node6, node3) + waitForValidatorConnection(t, node6, node4) + waitForValidatorConnection(t, node6, node5) + waitForValidatorConnection(t, node4, node5) + waitForCountRemoteNodeIndexer(t, node6, 4, 2, 2, 0) + waitForCountRemoteNodeIndexer(t, node2, 1, 1, 0, 0) + waitForCountRemoteNodeIndexer(t, node3, 1, 1, 0, 0) + waitForCountRemoteNodeIndexer(t, node4, 2, 2, 0, 0) + waitForCountRemoteNodeIndexer(t, node5, 2, 2, 0, 0) + node2.Stop() + node3.Stop() + node4.Stop() + node5.Stop() + node6.Stop() + t.Logf("Test #2 passed | Successfully run validator node6 with --connect-ips set to node2, node3, node4, node5") +} + +func TestConnectionControllerNonValidatorCircularConnectIps(t *testing.T) { + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + + node1.Config.ConnectIPs = []string{"127.0.0.1:18001"} + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + + node1 = startNode(t, node1) + node2 = startNode(t, node2) + defer node1.Stop() + defer node2.Stop() + + waitForCountRemoteNodeIndexer(t, node1, 2, 0, 1, 1) + waitForCountRemoteNodeIndexer(t, node2, 2, 0, 1, 1) +} + +func setGetActiveValidatorImplWithValidatorNodes(t *testing.T, validators ...*cmd.Node) { + require := require.New(t) + + mapping := collections.NewConcurrentMap[bls.SerializedPublicKey, *lib.ValidatorEntry]() + for _, validator := range validators { + seed := validator.Config.PosValidatorSeed + if seed == "" { + t.Fatalf("Validator node %s does not have a PosValidatorSeed set", validator.Params.UserAgent) + } + keystore, err := lib.NewBLSKeystore(seed) + require.NoError(err) + mapping.Set(keystore.GetSigner().GetPublicKey().Serialize(), createSimpleValidatorEntry(validator)) + } + setGetActiveValidatorImpl(func() *collections.ConcurrentMap[bls.SerializedPublicKey, *lib.ValidatorEntry] { + return mapping + }) +} + +func setGetActiveValidatorImpl(mapping func() *collections.ConcurrentMap[bls.SerializedPublicKey, *lib.ValidatorEntry]) { + lib.GetActiveValidatorImpl = mapping +} + +func createSimpleValidatorEntry(node *cmd.Node) *lib.ValidatorEntry { + return &lib.ValidatorEntry{ + Domains: [][]byte{[]byte(node.Listeners[0].Addr().String())}, + } +} + +func waitForValidatorFullGraph(t *testing.T, validators ...*cmd.Node) { + for ii := 0; ii < len(validators); ii++ { + waitForValidatorConnectionOneWay(t, validators[ii], validators[ii+1:]...) + } +} + +func waitForValidatorConnectionOneWay(t *testing.T, n *cmd.Node, validators ...*cmd.Node) { + if len(validators) == 0 { + return + } + for _, validator := range validators { + waitForValidatorConnection(t, n, validator) + } +} + +func waitForNonValidatorInboundXOROutboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { + userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + conditionInbound := conditionNonValidatorInboundConnectionDynamic(t, node1, node2, true) + conditionOutbound := conditionNonValidatorOutboundConnectionDynamic(t, node1, node2, true) + xorCondition := func() bool { + return conditionInbound() != conditionOutbound() + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to inbound XOR outbound non-validator Node (%s)", + userAgentN1, userAgentN2), xorCondition) +} + +func waitForMinNonValidatorCountRemoteNodeIndexer(t *testing.T, node *cmd.Node, allCount int, validatorCount int, + minNonValidatorOutboundCount int, minNonValidatorInboundCount int) { + + userAgent := node.Params.UserAgent + rnManager := node.Server.GetConnectionController().GetRemoteNodeManager() + condition := func() bool { + return checkRemoteNodeIndexerMinNonValidatorCount(rnManager, allCount, validatorCount, + minNonValidatorOutboundCount, minNonValidatorInboundCount) + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to have at least %d non-validator outbound nodes and %d non-validator inbound nodes", + userAgent, minNonValidatorOutboundCount, minNonValidatorInboundCount), condition) +} + +func checkRemoteNodeIndexerMinNonValidatorCount(manager *lib.RemoteNodeManager, allCount int, validatorCount int, + minNonValidatorOutboundCount int, minNonValidatorInboundCount int) bool { + + if allCount != manager.GetAllRemoteNodes().Count() { + return false + } + if validatorCount != manager.GetValidatorIndex().Count() { + return false + } + if minNonValidatorOutboundCount > manager.GetNonValidatorOutboundIndex().Count() { + return false + } + if minNonValidatorInboundCount > manager.GetNonValidatorInboundIndex().Count() { + return false + } + if allCount != manager.GetValidatorIndex().Count()+ + manager.GetNonValidatorOutboundIndex().Count()+ + manager.GetNonValidatorInboundIndex().Count() { + return false + } + return true +} diff --git a/integration_testing/connection_controller_test.go b/integration_testing/connection_controller_test.go index 01fb01046..58f4be33b 100644 --- a/integration_testing/connection_controller_test.go +++ b/integration_testing/connection_controller_test.go @@ -396,7 +396,8 @@ func TestConnectionControllerPersistentConnection(t *testing.T) { // Create a persistent connection from Node1 to Node2 cc := node1.Server.GetConnectionController() - require.NoError(cc.CreateNonValidatorPersistentOutboundConnection(node2.Listeners[0].Addr().String())) + _, err = cc.CreateNonValidatorPersistentOutboundConnection(node2.Listeners[0].Addr().String()) + require.NoError(err) waitForValidatorConnection(t, node1, node2) waitForNonValidatorInboundConnection(t, node2, node1) node2.Stop() @@ -408,7 +409,8 @@ func TestConnectionControllerPersistentConnection(t *testing.T) { node3 = startNode(t, node3) // Create a persistent connection from Node1 to Node3 - require.NoError(cc.CreateNonValidatorPersistentOutboundConnection(node3.Listeners[0].Addr().String())) + _, err = cc.CreateNonValidatorPersistentOutboundConnection(node3.Listeners[0].Addr().String()) + require.NoError(err) waitForNonValidatorOutboundConnection(t, node1, node3) waitForNonValidatorInboundConnection(t, node3, node1) node3.Stop() @@ -429,7 +431,8 @@ func TestConnectionControllerPersistentConnection(t *testing.T) { // Create a persistent connection from Node4 to Node5 cc = node4.Server.GetConnectionController() - require.NoError(cc.CreateNonValidatorPersistentOutboundConnection(node5.Listeners[0].Addr().String())) + _, err = cc.CreateNonValidatorPersistentOutboundConnection(node5.Listeners[0].Addr().String()) + require.NoError(err) waitForNonValidatorOutboundConnection(t, node4, node5) waitForValidatorConnection(t, node5, node4) node5.Stop() @@ -444,7 +447,8 @@ func TestConnectionControllerPersistentConnection(t *testing.T) { defer node6.Stop() // Create a persistent connection from Node4 to Node6 - require.NoError(cc.CreateNonValidatorPersistentOutboundConnection(node6.Listeners[0].Addr().String())) + _, err = cc.CreateNonValidatorPersistentOutboundConnection(node6.Listeners[0].Addr().String()) + require.NoError(err) waitForValidatorConnection(t, node4, node6) waitForValidatorConnection(t, node6, node4) t.Logf("Test #4 passed | Successfuly created persistent connection from validator Node4 to validator Node6") diff --git a/integration_testing/connection_controller_utils_test.go b/integration_testing/connection_controller_utils_test.go index 4d5594634..74a33b943 100644 --- a/integration_testing/connection_controller_utils_test.go +++ b/integration_testing/connection_controller_utils_test.go @@ -26,14 +26,24 @@ func waitForValidatorConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) } return true } - waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to outbound non-validator Node (%s)", userAgentN1, userAgentN2), n1ValidatedN2) + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to validator Node (%s)", userAgentN1, userAgentN2), n1ValidatedN2) } func waitForNonValidatorOutboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + condition := conditionNonValidatorOutboundConnection(t, node1, node2) + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to outbound non-validator Node (%s)", userAgentN1, userAgentN2), condition) +} + +func conditionNonValidatorOutboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) func() bool { + return conditionNonValidatorOutboundConnectionDynamic(t, node1, node2, false) +} + +func conditionNonValidatorOutboundConnectionDynamic(t *testing.T, node1 *cmd.Node, node2 *cmd.Node, inactiveValidator bool) func() bool { userAgentN2 := node2.Params.UserAgent rnManagerN1 := node1.Server.GetConnectionController().GetRemoteNodeManager() - n1ValidatedN2 := func() bool { + return func() bool { if true != checkRemoteNodeIndexerUserAgent(rnManagerN1, userAgentN2, false, true, false) { return false } @@ -44,19 +54,29 @@ func waitForNonValidatorOutboundConnection(t *testing.T, node1 *cmd.Node, node2 if !rnFromN2.IsHandshakeCompleted() { return false } - if rnFromN2.GetValidatorPublicKey() != nil { - return false + // inactiveValidator should have the public key. + if inactiveValidator { + return rnFromN2.GetValidatorPublicKey() != nil } - return true + return rnFromN2.GetValidatorPublicKey() == nil } - waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to outbound non-validator Node (%s)", userAgentN1, userAgentN2), n1ValidatedN2) } func waitForNonValidatorInboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + condition := conditionNonValidatorInboundConnection(t, node1, node2) + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to inbound non-validator Node (%s)", userAgentN1, userAgentN2), condition) +} + +func conditionNonValidatorInboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) func() bool { + return conditionNonValidatorInboundConnectionDynamic(t, node1, node2, false) +} + +func conditionNonValidatorInboundConnectionDynamic(t *testing.T, node1 *cmd.Node, node2 *cmd.Node, inactiveValidator bool) func() bool { userAgentN2 := node2.Params.UserAgent rnManagerN1 := node1.Server.GetConnectionController().GetRemoteNodeManager() - n1ValidatedN2 := func() bool { + return func() bool { if true != checkRemoteNodeIndexerUserAgent(rnManagerN1, userAgentN2, false, false, true) { return false } @@ -67,12 +87,12 @@ func waitForNonValidatorInboundConnection(t *testing.T, node1 *cmd.Node, node2 * if !rnFromN2.IsHandshakeCompleted() { return false } - if rnFromN2.GetValidatorPublicKey() != nil { - return false + // inactiveValidator should have the public key. + if inactiveValidator { + return rnFromN2.GetValidatorPublicKey() != nil } - return true + return rnFromN2.GetValidatorPublicKey() == nil } - waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to inbound non-validator Node (%s)", userAgentN1, userAgentN2), n1ValidatedN2) } func waitForEmptyRemoteNodeIndexer(t *testing.T, node1 *cmd.Node) { @@ -90,15 +110,15 @@ func waitForEmptyRemoteNodeIndexer(t *testing.T, node1 *cmd.Node) { func waitForCountRemoteNodeIndexer(t *testing.T, node1 *cmd.Node, allCount int, validatorCount int, nonValidatorOutboundCount int, nonValidatorInboundCount int) { - userAgentN1 := node1.Params.UserAgent - rnManagerN1 := node1.Server.GetConnectionController().GetRemoteNodeManager() - n1ValidatedN2 := func() bool { - if true != checkRemoteNodeIndexerCount(rnManagerN1, allCount, validatorCount, nonValidatorOutboundCount, nonValidatorInboundCount) { + userAgent := node1.Params.UserAgent + rnManager := node1.Server.GetConnectionController().GetRemoteNodeManager() + condition := func() bool { + if true != checkRemoteNodeIndexerCount(rnManager, allCount, validatorCount, nonValidatorOutboundCount, nonValidatorInboundCount) { return false } return true } - waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to have appropriate RemoteNodes counts", userAgentN1), n1ValidatedN2) + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to have appropriate RemoteNodes counts", userAgent), condition) } func checkRemoteNodeIndexerUserAgent(manager *lib.RemoteNodeManager, userAgent string, validator bool, diff --git a/integration_testing/tools.go b/integration_testing/tools.go index 2f97e942d..4db913136 100644 --- a/integration_testing/tools.go +++ b/integration_testing/tools.go @@ -69,7 +69,7 @@ func generateConfig(t *testing.T, port uint32, dataDir string, maxPeers uint32) config.MaxSyncBlockHeight = 0 config.ConnectIPs = []string{} config.PrivateMode = true - config.GlogV = 0 + config.GlogV = 2 config.GlogVmodule = "*bitcoin_manager*=0,*balance*=0,*view*=0,*frontend*=0,*peer*=0,*addr*=0,*network*=0,*utils*=0,*connection*=0,*main*=0,*server*=0,*mempool*=0,*miner*=0,*blockchain*=0" config.MaxInboundPeers = maxPeers config.TargetOutboundPeers = maxPeers diff --git a/lib/connection_controller.go b/lib/connection_controller.go index fef9fa887..18d423f46 100644 --- a/lib/connection_controller.go +++ b/lib/connection_controller.go @@ -5,12 +5,23 @@ import ( "github.com/btcsuite/btcd/addrmgr" "github.com/btcsuite/btcd/wire" "github.com/deso-protocol/core/bls" + "github.com/deso-protocol/core/collections" "github.com/golang/glog" "github.com/pkg/errors" "net" "strconv" + "sync" + "time" ) +type GetActiveValidatorsFunc func() *collections.ConcurrentMap[bls.SerializedPublicKey, *ValidatorEntry] + +var GetActiveValidatorImpl GetActiveValidatorsFunc = BasicGetActiveValidators + +func BasicGetActiveValidators() *collections.ConcurrentMap[bls.SerializedPublicKey, *ValidatorEntry] { + return collections.NewConcurrentMap[bls.SerializedPublicKey, *ValidatorEntry]() +} + // ConnectionController is a structure that oversees all connections to remote nodes. It is responsible for kicking off // the initial connections a node makes to the network. It is also responsible for creating RemoteNodes from all // successful outbound and inbound connections. The ConnectionController also ensures that the node is connected to @@ -32,8 +43,12 @@ type ConnectionController struct { // it's aware of at random and provides it to us. AddrMgr *addrmgr.AddrManager - // When --connectips is set, we don't connect to anything from the addrmgr. + // When --connect-ips is set, we don't connect to anything from the addrmgr. connectIps []string + // persistentIpToRemoteNodeIdsMap maps persistent IP addresses, like the --connect-ips, to the RemoteNodeIds of the + // corresponding RemoteNodes. This is used to ensure that we don't connect to the same persistent IP address twice. + // And that we can reconnect to the same persistent IP address if we disconnect from it. + persistentIpToRemoteNodeIdsMap map[string]RemoteNodeId // The target number of non-validator outbound remote nodes we want to have. We will disconnect remote nodes once // we've exceeded this number of outbound connections. @@ -44,11 +59,16 @@ type ConnectionController struct { // When true, only one connection per IP is allowed. Prevents eclipse attacks // among other things. limitOneInboundRemoteNodePerIP bool + + startGroup sync.WaitGroup + exitChan chan struct{} + exitGroup sync.WaitGroup } func NewConnectionController(params *DeSoParams, cmgr *ConnectionManager, handshakeController *HandshakeController, - rnManager *RemoteNodeManager, blsKeystore *BLSKeystore, addrMgr *addrmgr.AddrManager, targetNonValidatorOutboundRemoteNodes uint32, - targetNonValidatorInboundRemoteNodes uint32, limitOneInboundConnectionPerIP bool) *ConnectionController { + rnManager *RemoteNodeManager, blsKeystore *BLSKeystore, addrMgr *addrmgr.AddrManager, connectIps []string, + targetNonValidatorOutboundRemoteNodes uint32, targetNonValidatorInboundRemoteNodes uint32, + limitOneInboundConnectionPerIP bool) *ConnectionController { return &ConnectionController{ params: params, @@ -57,16 +77,45 @@ func NewConnectionController(params *DeSoParams, cmgr *ConnectionManager, handsh handshake: handshakeController, rnManager: rnManager, AddrMgr: addrMgr, + connectIps: connectIps, + persistentIpToRemoteNodeIdsMap: make(map[string]RemoteNodeId), targetNonValidatorOutboundRemoteNodes: targetNonValidatorOutboundRemoteNodes, targetNonValidatorInboundRemoteNodes: targetNonValidatorInboundRemoteNodes, limitOneInboundRemoteNodePerIP: limitOneInboundConnectionPerIP, + exitChan: make(chan struct{}), } } +func (cc *ConnectionController) Start() { + cc.startGroup.Add(1) + go cc.startPersistentConnector() + + cc.startGroup.Wait() + cc.exitGroup.Add(1) +} + +func (cc *ConnectionController) Stop() { + close(cc.exitChan) + cc.exitGroup.Wait() +} + func (cc *ConnectionController) GetRemoteNodeManager() *RemoteNodeManager { return cc.rnManager } +func (cc *ConnectionController) startPersistentConnector() { + cc.startGroup.Done() + for { + select { + case <-cc.exitChan: + cc.exitGroup.Done() + return + case <-time.After(1 * time.Second): + cc.refreshConnectIps() + } + } +} + // ########################### // ## Handlers (Peer, DeSoMessage) // ########################### @@ -77,6 +126,12 @@ func (cc *ConnectionController) _handleDonePeerMessage(origin *Peer, desoMsg DeS } cc.rnManager.DisconnectById(NewRemoteNodeId(origin.ID)) + // Update the persistentIpToRemoteNodeIdsMap. + for ip, id := range cc.persistentIpToRemoteNodeIdsMap { + if id.ToUint64() == origin.ID { + delete(cc.persistentIpToRemoteNodeIdsMap, ip) + } + } } func (cc *ConnectionController) _handleAddrMessage(origin *Peer, desoMsg DeSoMessage) { @@ -114,7 +169,7 @@ func (cc *ConnectionController) _handleNewConnectionMessage(origin *Peer, desoMs remoteNode, err = cc.processInboundConnection(msg.Connection) if err != nil { glog.Errorf("ConnectionController.handleNewConnectionMessage: Problem handling inbound connection: %v", err) - msg.Connection.Close() + cc.cleanupFailedInboundConnection(remoteNode, msg.Connection) return } case ConnectionTypeOutbound: @@ -130,6 +185,13 @@ func (cc *ConnectionController) _handleNewConnectionMessage(origin *Peer, desoMs cc.handshake.InitiateHandshake(remoteNode) } +func (cc *ConnectionController) cleanupFailedInboundConnection(remoteNode *RemoteNode, connection Connection) { + if remoteNode != nil { + cc.rnManager.Disconnect(remoteNode) + } + connection.Close() +} + func (cc *ConnectionController) cleanupFailedOutboundConnection(connection Connection) { oc, ok := connection.(*outboundConnection) if !ok { @@ -141,13 +203,34 @@ func (cc *ConnectionController) cleanupFailedOutboundConnection(connection Conne if rn != nil { cc.rnManager.Disconnect(rn) } + oc.Close() cc.cmgr.RemoveAttemptedOutboundAddrs(oc.address) } // ########################### -// ## Connections +// ## Persistent Connections // ########################### +func (cc *ConnectionController) refreshConnectIps() { + // Connect to addresses passed via the --connect-ips flag. These addresses are persistent in the sense that if we + // disconnect from one, we will try to reconnect to the same one. + for _, connectIp := range cc.connectIps { + if _, ok := cc.persistentIpToRemoteNodeIdsMap[connectIp]; ok { + continue + } + + glog.Infof("ConnectionController.initiatePersistentConnections: Connecting to connectIp: %v", connectIp) + id, err := cc.CreateNonValidatorPersistentOutboundConnection(connectIp) + if err != nil { + glog.Errorf("ConnectionController.initiatePersistentConnections: Problem connecting "+ + "to connectIp %v: %v", connectIp, err) + continue + } + + cc.persistentIpToRemoteNodeIdsMap[connectIp] = id + } +} + func (cc *ConnectionController) CreateValidatorConnection(ipStr string, publicKey *bls.PublicKey) error { netAddr, err := cc.ConvertIPStringToNetAddress(ipStr) if err != nil { @@ -156,10 +239,10 @@ func (cc *ConnectionController) CreateValidatorConnection(ipStr string, publicKe return cc.rnManager.CreateValidatorConnection(netAddr, publicKey) } -func (cc *ConnectionController) CreateNonValidatorPersistentOutboundConnection(ipStr string) error { +func (cc *ConnectionController) CreateNonValidatorPersistentOutboundConnection(ipStr string) (RemoteNodeId, error) { netAddr, err := cc.ConvertIPStringToNetAddress(ipStr) if err != nil { - return err + return 0, err } return cc.rnManager.CreateNonValidatorPersistentOutboundConnection(netAddr) } @@ -235,8 +318,8 @@ func (cc *ConnectionController) processOutboundConnection(conn Connection) (*Rem } if oc.failed { - return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Failed to connect to peer (%s)", - oc.address.IP.String()) + return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Failed to connect to peer (%s:%v)", + oc.address.IP.String(), oc.address.Port) } if !oc.isPersistent { @@ -263,11 +346,35 @@ func (cc *ConnectionController) processOutboundConnection(conn Connection) (*Rem "for addr: (%s)", oc.connection.RemoteAddr().String()) } + // Attach the connection before additional validation steps because it is already established. remoteNode, err := cc.rnManager.AttachOutboundConnection(oc.connection, na, oc.attemptId, oc.isPersistent) if remoteNode == nil || err != nil { return nil, errors.Wrapf(err, "ConnectionController.handleOutboundConnection: Problem calling rnManager.AttachOutboundConnection "+ "for addr: (%s)", oc.connection.RemoteAddr().String()) } + + // If this is a persistent remote node or a validator, we don't need to do any extra connection validation. + if remoteNode.IsPersistent() || remoteNode.GetValidatorPublicKey() != nil { + return remoteNode, nil + } + + // If we get here, it means we're dealing with a non-persistent or non-validator remote node. We perform additional + // connection validation. + + // If we already have enough outbound peers, then don't bother adding this one. + if cc.enoughNonValidatorOutboundConnections() { + return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Connected to maximum number of outbound "+ + "peers (%d)", cc.targetNonValidatorOutboundRemoteNodes) + } + + // If the group key overlaps with another peer we're already connected to then abort mission. We only connect to + // one peer per IP group in order to prevent Sybil attacks. + if cc.cmgr.IsFromRedundantOutboundIPAddress(oc.address) { + return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Rejecting OUTBOUND NON-PERSISTENT "+ + "connection with redundant group key (%s).", addrmgr.GroupKey(oc.address)) + } + cc.cmgr.AddToGroupKey(na) + return remoteNode, nil } diff --git a/lib/connection_manager.go b/lib/connection_manager.go index 7c6f510ac..1ba4bf8f1 100644 --- a/lib/connection_manager.go +++ b/lib/connection_manager.go @@ -168,6 +168,10 @@ func NewConnectionManager( // Check if the address passed shares a group with any addresses already in our data structures. func (cmgr *ConnectionManager) IsFromRedundantOutboundIPAddress(na *wire.NetAddress) bool { groupKey := addrmgr.GroupKey(na) + // For the sake of running multiple nodes on the same machine, we allow localhost connections. + if groupKey == "local" { + return false + } cmgr.mtxOutboundConnIPGroups.Lock() numGroupsForKey := cmgr.outboundConnIPGroups[groupKey] @@ -185,7 +189,7 @@ func (cmgr *ConnectionManager) IsFromRedundantOutboundIPAddress(na *wire.NetAddr return true } -func (cmgr *ConnectionManager) addToGroupKey(na *wire.NetAddress) { +func (cmgr *ConnectionManager) AddToGroupKey(na *wire.NetAddress) { groupKey := addrmgr.GroupKey(na) cmgr.mtxOutboundConnIPGroups.Lock() @@ -429,7 +433,6 @@ func (cmgr *ConnectionManager) addPeer(pp *Peer) { // number of outbound peers. Also add the peer's address to // our map. if _, ok := peerList[pp.ID]; !ok { - cmgr.addToGroupKey(pp.netAddr) atomic.AddUint32(&cmgr.numOutboundPeers, 1) cmgr.mtxAddrsMaps.Lock() @@ -528,16 +531,6 @@ func (cmgr *ConnectionManager) _logOutboundPeerData() { numInboundPeers := int(atomic.LoadUint32(&cmgr.numInboundPeers)) numPersistentPeers := int(atomic.LoadUint32(&cmgr.numPersistentPeers)) glog.V(1).Infof("Num peers: OUTBOUND(%d) INBOUND(%d) PERSISTENT(%d)", numOutboundPeers, numInboundPeers, numPersistentPeers) - - cmgr.mtxOutboundConnIPGroups.Lock() - for _, vv := range cmgr.outboundConnIPGroups { - if vv != 0 && vv != 1 { - glog.V(1).Infof("_logOutboundPeerData: Peer group count != (0 or 1). "+ - "Is (%d) instead. This "+ - "should never happen.", vv) - } - } - cmgr.mtxOutboundConnIPGroups.Unlock() } func (cmgr *ConnectionManager) AddTimeSample(addrStr string, timeSample time.Time) { @@ -617,8 +610,13 @@ func (cmgr *ConnectionManager) Start() { select { case oc := <-cmgr.outboundConnectionChan: - glog.V(2).Infof("ConnectionManager.Start: Successfully established an outbound connection with "+ - "(addr= %v)", oc.connection.RemoteAddr()) + if oc.failed { + glog.V(2).Infof("ConnectionManager.Start: Failed to establish an outbound connection with "+ + "(id= %v)", oc.attemptId) + } else { + glog.V(2).Infof("ConnectionManager.Start: Successfully established an outbound connection with "+ + "(addr= %v)", oc.connection.RemoteAddr()) + } delete(cmgr.outboundConnectionAttempts, oc.attemptId) cmgr.serverMessageQueue <- &ServerMessage{ Peer: nil, diff --git a/lib/handshake_controller.go b/lib/handshake_controller.go index bde07745a..f355bad93 100644 --- a/lib/handshake_controller.go +++ b/lib/handshake_controller.go @@ -122,7 +122,7 @@ func (hc *HandshakeController) _handleVersionMessage(origin *Peer, desoMsg DeSoM if hc.usedNonces.Contains(msgNonce) { hc.usedNonces.Delete(msgNonce) glog.Errorf("HandshakeController._handleVersionMessage: Disconnecting RemoteNode with id: (%v) "+ - "nonce collision", origin.ID) + "nonce collision, nonce (%v)", origin.ID, msgNonce) hc.rnManager.Disconnect(rn) return } diff --git a/lib/network_connection.go b/lib/network_connection.go index eb6d4ab55..ffb0bb1f1 100644 --- a/lib/network_connection.go +++ b/lib/network_connection.go @@ -33,7 +33,9 @@ func (oc *outboundConnection) Close() { if oc.terminated { return } - oc.connection.Close() + if oc.connection != nil { + oc.connection.Close() + } oc.terminated = true } @@ -58,7 +60,9 @@ func (ic *inboundConnection) Close() { return } - ic.connection.Close() + if ic.connection != nil { + ic.connection.Close() + } ic.terminated = true } diff --git a/lib/remote_node.go b/lib/remote_node.go index f2d849a36..5ba651f3f 100644 --- a/lib/remote_node.go +++ b/lib/remote_node.go @@ -223,6 +223,10 @@ func (rn *RemoteNode) IsHandshakeCompleted() bool { return rn.connectionStatus == RemoteNodeStatus_HandshakeCompleted } +func (rn *RemoteNode) IsTerminated() bool { + return rn.connectionStatus == RemoteNodeStatus_Terminated +} + func (rn *RemoteNode) IsValidator() bool { if !rn.IsHandshakeCompleted() { return false diff --git a/lib/remote_node_manager.go b/lib/remote_node_manager.go index fb269d072..02bed8e3e 100644 --- a/lib/remote_node_manager.go +++ b/lib/remote_node_manager.go @@ -140,7 +140,7 @@ func (manager *RemoteNodeManager) CreateValidatorConnection(netAddr *wire.NetAdd } remoteNode := manager.newRemoteNode(publicKey) - if err := remoteNode.DialPersistentOutboundConnection(netAddr); err != nil { + if err := remoteNode.DialOutboundConnection(netAddr); err != nil { return errors.Wrapf(err, "RemoteNodeManager.CreateValidatorConnection: Problem calling DialPersistentOutboundConnection "+ "for addr: (%s:%v)", netAddr.IP.String(), netAddr.Port) } @@ -149,19 +149,19 @@ func (manager *RemoteNodeManager) CreateValidatorConnection(netAddr *wire.NetAdd return nil } -func (manager *RemoteNodeManager) CreateNonValidatorPersistentOutboundConnection(netAddr *wire.NetAddress) error { +func (manager *RemoteNodeManager) CreateNonValidatorPersistentOutboundConnection(netAddr *wire.NetAddress) (RemoteNodeId, error) { if netAddr == nil { - return fmt.Errorf("RemoteNodeManager.CreateNonValidatorPersistentOutboundConnection: netAddr is nil") + return 0, fmt.Errorf("RemoteNodeManager.CreateNonValidatorPersistentOutboundConnection: netAddr is nil") } remoteNode := manager.newRemoteNode(nil) if err := remoteNode.DialPersistentOutboundConnection(netAddr); err != nil { - return errors.Wrapf(err, "RemoteNodeManager.CreateNonValidatorPersistentOutboundConnection: Problem calling DialPersistentOutboundConnection "+ + return 0, errors.Wrapf(err, "RemoteNodeManager.CreateNonValidatorPersistentOutboundConnection: Problem calling DialPersistentOutboundConnection "+ "for addr: (%s:%v)", netAddr.IP.String(), netAddr.Port) } manager.setRemoteNode(remoteNode) manager.GetNonValidatorOutboundIndex().Set(remoteNode.GetId(), remoteNode) - return nil + return remoteNode.GetId(), nil } func (manager *RemoteNodeManager) CreateNonValidatorOutboundConnection(netAddr *wire.NetAddress) error { @@ -184,7 +184,7 @@ func (manager *RemoteNodeManager) AttachInboundConnection(conn net.Conn, remoteNode := manager.newRemoteNode(nil) if err := remoteNode.AttachInboundConnection(conn, na); err != nil { - return nil, errors.Wrapf(err, "RemoteNodeManager.AttachInboundConnection: Problem calling AttachInboundConnection "+ + return remoteNode, errors.Wrapf(err, "RemoteNodeManager.AttachInboundConnection: Problem calling AttachInboundConnection "+ "for addr: (%s)", conn.RemoteAddr().String()) } @@ -219,7 +219,7 @@ func (manager *RemoteNodeManager) setRemoteNode(rn *RemoteNode) { manager.mtx.Lock() defer manager.mtx.Unlock() - if rn == nil { + if rn == nil || rn.IsTerminated() { return } @@ -230,7 +230,7 @@ func (manager *RemoteNodeManager) SetNonValidator(rn *RemoteNode) { manager.mtx.Lock() defer manager.mtx.Unlock() - if rn == nil { + if rn == nil || rn.IsTerminated() { return } @@ -245,7 +245,7 @@ func (manager *RemoteNodeManager) SetValidator(remoteNode *RemoteNode) { manager.mtx.Lock() defer manager.mtx.Unlock() - if remoteNode == nil { + if remoteNode == nil || remoteNode.IsTerminated() { return } @@ -260,7 +260,7 @@ func (manager *RemoteNodeManager) UnsetValidator(remoteNode *RemoteNode) { manager.mtx.Lock() defer manager.mtx.Unlock() - if remoteNode == nil { + if remoteNode == nil || remoteNode.IsTerminated() { return } @@ -275,7 +275,7 @@ func (manager *RemoteNodeManager) UnsetNonValidator(rn *RemoteNode) { manager.mtx.Lock() defer manager.mtx.Unlock() - if rn == nil { + if rn == nil || rn.IsTerminated() { return } diff --git a/lib/server.go b/lib/server.go index d4c371955..a4fa28376 100644 --- a/lib/server.go +++ b/lib/server.go @@ -499,8 +499,8 @@ func NewServer( rnManager := NewRemoteNodeManager(srv, _chain, _cmgr, _blsKeystore, _params, _minFeeRateNanosPerKB, nodeServices) srv.handshakeController = NewHandshakeController(rnManager) - srv.connectionController = NewConnectionController(_params, _cmgr, srv.handshakeController, rnManager, - _blsKeystore, _desoAddrMgr, _targetOutboundPeers, _maxInboundPeers, _limitOneInboundConnectionPerIP) + srv.connectionController = NewConnectionController(_params, _cmgr, srv.handshakeController, rnManager, _blsKeystore, + _desoAddrMgr, _connectIps, _targetOutboundPeers, _maxInboundPeers, _limitOneInboundConnectionPerIP) if srv.stateChangeSyncer != nil { srv.stateChangeSyncer.BlockHeight = uint64(_chain.headerTip().Height) @@ -2547,6 +2547,9 @@ func (srv *Server) Stop() { srv.cmgr.Stop() glog.Infof(CLog(Yellow, "Server.Stop: Closed the ConnectionManger")) + srv.connectionController.Stop() + glog.Infof(CLog(Yellow, "Server.Stop: Closed the ConnectionController")) + // Stop the miner if we have one running. if srv.miner != nil { srv.miner.Stop() @@ -2629,6 +2632,8 @@ func (srv *Server) Start() { if srv.miner != nil && len(srv.miner.PublicKeys) > 0 { go srv.miner.Start() } + + srv.connectionController.Start() } // SyncPrefixProgress keeps track of sync progress on an individual prefix. It is used in