diff --git a/integration_testing/blocksync_test.go b/integration_testing/blocksync_test.go index 8be96d735..be87aae3a 100644 --- a/integration_testing/blocksync_test.go +++ b/integration_testing/blocksync_test.go @@ -40,9 +40,8 @@ func TestSimpleBlockSync(t *testing.T) { // wait for node1 to sync blocks waitForNodeToFullySync(node1) - // bridge the nodes together. - bridge := NewConnectionBridge(node1, node2) - require.NoError(bridge.Start()) + // TODO: Dial an outbound connection from node2 to node1 + // Fix other integration tests. // wait for node2 to sync blocks. waitForNodeToFullySync(node2) @@ -99,6 +98,7 @@ func TestSimpleSyncRestart(t *testing.T) { compareNodesByDB(t, node1, node2, 0) fmt.Println("Random restart successful! Random height was", randomHeight) fmt.Println("Databases match!") + bridge.Disconnect() node1.Stop() node2.Stop() } @@ -153,7 +153,7 @@ func TestSimpleSyncDisconnectWithSwitchingToNewPeer(t *testing.T) { randomHeight := randomUint32Between(t, 10, config2.MaxSyncBlockHeight) fmt.Println("Random height for a restart (re-use if test failed):", randomHeight) - disconnectAtBlockHeight(t, node2, bridge12, randomHeight) + disconnectAtBlockHeight(node2, bridge12, randomHeight) // bridge the nodes together. bridge23 := NewConnectionBridge(node2, node3) @@ -167,6 +167,8 @@ func TestSimpleSyncDisconnectWithSwitchingToNewPeer(t *testing.T) { compareNodesByDB(t, node3, node2, 0) fmt.Println("Random restart successful! Random height was", randomHeight) fmt.Println("Databases match!") + bridge12.Disconnect() + bridge23.Disconnect() node1.Stop() node2.Stop() node3.Stop() diff --git a/integration_testing/connection_bridge.go b/integration_testing/connection_bridge.go index 4c3b28dde..f6a9897ed 100644 --- a/integration_testing/connection_bridge.go +++ b/integration_testing/connection_bridge.go @@ -13,6 +13,7 @@ import ( "time" ) +// TODO: DEPRECATE // ConnectionBridge is a bidirectional communication channel between two nodes. 
A bridge creates a pair of inbound and // outbound peers for each of the nodes to handle communication. In total, it creates four peers. // @@ -111,13 +112,14 @@ func (bridge *ConnectionBridge) createInboundConnection(node *cmd.Node) *lib.Pee } // This channel is redundant in our setting. - messagesFromPeer := make(chan *lib.ServerMessage) + messagesFromPeer := make(chan *lib.ServerMessage, 100) + newPeerChan := make(chan *lib.Peer, 100) + donePeerChan := make(chan *lib.Peer, 100) // Because it is an inbound Peer of the node, it is simultaneously a "fake" outbound Peer of the bridge. // Hence, we will mark the _isOutbound parameter as "true" in NewPeer. - peer := lib.NewPeer(conn, true, netAddress, true, - 10000, 0, &lib.DeSoMainnetParams, - messagesFromPeer, nil, nil, lib.NodeSyncTypeAny) - peer.ID = uint64(lib.RandInt64(math.MaxInt64)) + peer := lib.NewPeer(uint64(lib.RandInt64(math.MaxInt64)), conn, true, + netAddress, true, 10000, 0, &lib.DeSoMainnetParams, + messagesFromPeer, nil, nil, lib.NodeSyncTypeAny, newPeerChan, donePeerChan) return peer } @@ -139,27 +141,28 @@ func (bridge *ConnectionBridge) createOutboundConnection(node *cmd.Node, otherNo fmt.Println("createOutboundConnection: Got a connection from remote:", conn.RemoteAddr().String(), "on listener:", ll.Addr().String()) - na, err := lib.IPToNetAddr(conn.RemoteAddr().String(), otherNode.Server.GetConnectionManager().AddrMgr, - otherNode.Params) - messagesFromPeer := make(chan *lib.ServerMessage) - peer := lib.NewPeer(conn, false, na, false, - 10000, 0, bridge.nodeB.Params, - messagesFromPeer, nil, nil, lib.NodeSyncTypeAny) - peer.ID = uint64(lib.RandInt64(math.MaxInt64)) + addrMgr := addrmgr.New("", net.LookupIP) + na, err := lib.IPToNetAddr(conn.RemoteAddr().String(), addrMgr, otherNode.Params) + messagesFromPeer := make(chan *lib.ServerMessage, 100) + newPeerChan := make(chan *lib.Peer, 100) + donePeerChan := make(chan *lib.Peer, 100) + peer := lib.NewPeer(uint64(lib.RandInt64(math.MaxInt64)), conn, 
+ false, na, false, 10000, 0, bridge.nodeB.Params, + messagesFromPeer, nil, nil, lib.NodeSyncTypeAny, newPeerChan, donePeerChan) bridge.newPeerChan <- peer //} }(ll) // Make the provided node to make an outbound connection to our listener. - netAddress, _ := lib.IPToNetAddr(ll.Addr().String(), addrmgr.New("", net.LookupIP), &lib.DeSoMainnetParams) - fmt.Println("createOutboundConnection: IP:", netAddress.IP, "Port:", netAddress.Port) - go node.Server.GetConnectionManager().ConnectPeer(nil, netAddress) + addrMgr := addrmgr.New("", net.LookupIP) + addr, _ := lib.IPToNetAddr(ll.Addr().String(), addrMgr, node.Params) + go node.Server.GetConnectionManager().DialOutboundConnection(addr, uint64(lib.RandInt64(math.MaxInt64))) } // getVersionMessage simulates a version message that the provided node would have sent. func (bridge *ConnectionBridge) getVersionMessage(node *cmd.Node) *lib.MsgDeSoVersion { ver := lib.NewMessage(lib.MsgTypeVersion).(*lib.MsgDeSoVersion) - ver.Version = node.Params.ProtocolVersion + ver.Version = node.Params.ProtocolVersion.ToUint64() ver.TstampSecs = time.Now().Unix() ver.Nonce = uint64(lib.RandInt64(math.MaxInt64)) ver.UserAgent = node.Params.UserAgent @@ -172,12 +175,29 @@ func (bridge *ConnectionBridge) getVersionMessage(node *cmd.Node) *lib.MsgDeSoVe } if node.Server != nil { - ver.LatestBlockHeight = uint32(node.Server.GetBlockchain().BlockTip().Header.Height) + ver.LatestBlockHeight = node.Server.GetBlockchain().BlockTip().Header.Height } ver.MinFeeRateNanosPerKB = node.Config.MinFeerate return ver } +func ReadWithTimeout(readFunc func() error, readTimeout time.Duration) error { + errChan := make(chan error) + go func() { + errChan <- readFunc() + }() + select { + case err := <-errChan: + { + return err + } + case <-time.After(readTimeout): + { + return fmt.Errorf("ReadWithTimeout: Timed out reading message") + } + } +} + // startConnection starts the connection by performing version and verack exchange with // the provided connection, 
pretending to be the otherNode. func (bridge *ConnectionBridge) startConnection(connection *lib.Peer, otherNode *cmd.Node) error { @@ -192,7 +212,7 @@ func (bridge *ConnectionBridge) startConnection(connection *lib.Peer, otherNode } // Wait for a response to the version message. - if err := connection.ReadWithTimeout( + if err := ReadWithTimeout( func() error { msg, err := connection.ReadDeSoMessage() if err != nil { @@ -215,7 +235,7 @@ func (bridge *ConnectionBridge) startConnection(connection *lib.Peer, otherNode // Now prepare the verack message. verackMsg := lib.NewMessage(lib.MsgTypeVerack) - verackMsg.(*lib.MsgDeSoVerack).Nonce = connection.VersionNonceReceived + verackMsg.(*lib.MsgDeSoVerack).NonceReceived = connection.VersionNonceReceived // And send it to the connection. if err := connection.WriteDeSoMessage(verackMsg); err != nil { @@ -223,7 +243,7 @@ func (bridge *ConnectionBridge) startConnection(connection *lib.Peer, otherNode } // And finally wait for connection's response to the verack message. - if err := connection.ReadWithTimeout( + if err := ReadWithTimeout( func() error { msg, err := connection.ReadDeSoMessage() if err != nil { @@ -234,9 +254,9 @@ func (bridge *ConnectionBridge) startConnection(connection *lib.Peer, otherNode return fmt.Errorf("message is not verack! 
Type: %v", msg.GetMsgType()) } verackMsg := msg.(*lib.MsgDeSoVerack) - if verackMsg.Nonce != connection.VersionNonceSent { + if verackMsg.NonceReceived != connection.VersionNonceSent { return fmt.Errorf("verack message nonce doesn't match (received: %v, sent: %v)", - verackMsg.Nonce, connection.VersionNonceSent) + verackMsg.NonceReceived, connection.VersionNonceSent) } return nil }, lib.DeSoMainnetParams.VersionNegotiationTimeout); err != nil { diff --git a/integration_testing/connection_controller_test.go b/integration_testing/connection_controller_test.go new file mode 100644 index 000000000..01fb01046 --- /dev/null +++ b/integration_testing/connection_controller_test.go @@ -0,0 +1,451 @@ +package integration_testing + +import ( + "github.com/deso-protocol/core/bls" + "github.com/deso-protocol/core/lib" + "github.com/stretchr/testify/require" + "testing" +) + +func TestConnectionControllerNonValidator(t *testing.T) { + require := require.New(t) + + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1 = startNode(t, node1) + defer node1.Stop() + + // Make sure NonValidator Node1 can create an outbound connection to NonValidator Node2 + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + node2 = startNode(t, node2) + + cc := node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node2.Listeners[0].Addr().String())) + waitForNonValidatorOutboundConnection(t, node1, node2) + waitForNonValidatorInboundConnection(t, node2, node1) + + node2.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #1 passed | Successfully created outbound connection from NonValidator Node1 to NonValidator Node2") + + // Make sure NonValidator Node1 can create an outbound connection to validator Node3 + blsPriv3, err := bls.NewPrivateKey() + require.NoError(err) + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsPriv3) + node3 = startNode(t, node3) + + cc = node1.Server.GetConnectionController() + 
require.NoError(cc.CreateNonValidatorOutboundConnection(node3.Listeners[0].Addr().String())) + waitForValidatorConnection(t, node1, node3) + waitForNonValidatorInboundConnection(t, node3, node1) + + node3.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #2 passed | Successfully created outbound connection from NonValidator Node1 to Validator Node3") + + // Make sure NonValidator Node1 can create a non-validator connection to validator Node4 + blsPriv4, err := bls.NewPrivateKey() + require.NoError(err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsPriv4) + node4 = startNode(t, node4) + defer node4.Stop() + + cc = node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node4.Listeners[0].Addr().String())) + waitForValidatorConnection(t, node1, node4) + waitForNonValidatorInboundConnection(t, node4, node1) + t.Logf("Test #3 passed | Successfully created outbound connection from NonValidator Node1 to Validator Node4") +} + +func TestConnectionControllerValidator(t *testing.T) { + require := require.New(t) + + blsPriv1, err := bls.NewPrivateKey() + require.NoError(err) + node1 := spawnValidatorNodeProtocol2(t, 18000, "node1", blsPriv1) + node1 = startNode(t, node1) + defer node1.Stop() + + // Make sure Validator Node1 can create an outbound connection to Validator Node2 + blsPriv2, err := bls.NewPrivateKey() + blsPub2 := blsPriv2.PublicKey() + require.NoError(err) + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsPriv2) + node2 = startNode(t, node2) + + cc := node1.Server.GetConnectionController() + require.NoError(cc.CreateValidatorConnection(node2.Listeners[0].Addr().String(), blsPub2)) + waitForValidatorConnection(t, node1, node2) + waitForValidatorConnection(t, node2, node1) + + node2.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #1 passed | Successfully created outbound connection from Validator Node1 to Validator Node2") + + // Make sure Validator Node1 can create 
an outbound connection to NonValidator Node3 + node3 := spawnNonValidatorNodeProtocol2(t, 18002, "node3") + node3 = startNode(t, node3) + + cc = node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node3.Listeners[0].Addr().String())) + waitForNonValidatorOutboundConnection(t, node1, node3) + waitForValidatorConnection(t, node3, node1) + + node3.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #2 passed | Successfully created outbound connection from Validator Node1 to NonValidator Node3") + + // Make sure Validator Node1 can create an outbound non-validator connection to Validator Node4 + blsPriv4, err := bls.NewPrivateKey() + require.NoError(err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsPriv4) + node4 = startNode(t, node4) + defer node4.Stop() + + cc = node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node4.Listeners[0].Addr().String())) + waitForValidatorConnection(t, node1, node4) + waitForValidatorConnection(t, node4, node1) + t.Logf("Test #3 passed | Successfully created non-validator outbound connection from Validator Node1 to Validator Node4") +} + +func TestConnectionControllerHandshakeDataErrors(t *testing.T) { + require := require.New(t) + + blsPriv1, err := bls.NewPrivateKey() + require.NoError(err) + node1 := spawnValidatorNodeProtocol2(t, 18000, "node1", blsPriv1) + + // This node should have ProtocolVersion2, but it has ProtocolVersion1 as we want it to disconnect. 
+ blsPriv2, err := bls.NewPrivateKey() + require.NoError(err) + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsPriv2) + node2.Params.ProtocolVersion = lib.ProtocolVersion1 + + node1 = startNode(t, node1) + node2 = startNode(t, node2) + defer node1.Stop() + defer node2.Stop() + + cc := node2.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node2) + t.Logf("Test #1 passed | Successfuly disconnected node with SFValidator flag and ProtocolVersion1 mismatch") + + // This node shouldn't have ProtocolVersion3, which is beyond latest ProtocolVersion2, meaning nodes should disconnect. + blsPriv3, err := bls.NewPrivateKey() + require.NoError(err) + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsPriv3) + node3.Params.ProtocolVersion = lib.ProtocolVersionType(3) + node3 = startNode(t, node3) + defer node3.Stop() + + cc = node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node3.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node3) + t.Logf("Test #2 passed | Successfuly disconnected node with ProtocolVersion3") + + // This node shouldn't have ProtocolVersion0, which is outdated. + node4 := spawnNonValidatorNodeProtocol2(t, 18003, "node4") + node4.Params.ProtocolVersion = lib.ProtocolVersion0 + node4 = startNode(t, node4) + defer node4.Stop() + + cc = node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node4.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node4) + t.Logf("Test #3 passed | Successfuly disconnected node with ProtocolVersion0") + + // This node will have a different public key than the one it's supposed to have. 
+ blsPriv5, err := bls.NewPrivateKey() + require.NoError(err) + blsPriv5Wrong, err := bls.NewPrivateKey() + require.NoError(err) + node5 := spawnValidatorNodeProtocol2(t, 18004, "node5", blsPriv5) + node5 = startNode(t, node5) + defer node5.Stop() + + cc = node1.Server.GetConnectionController() + require.NoError(cc.CreateValidatorConnection(node5.Listeners[0].Addr().String(), blsPriv5Wrong.PublicKey())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node5) + t.Logf("Test #4 passed | Successfuly disconnected node with public key mismatch") + + // This node will be missing SFPosValidator flag while being connected as a validator. + blsPriv6, err := bls.NewPrivateKey() + require.NoError(err) + node6 := spawnNonValidatorNodeProtocol2(t, 18005, "node6") + node6 = startNode(t, node6) + defer node6.Stop() + + cc = node1.Server.GetConnectionController() + require.NoError(cc.CreateValidatorConnection(node6.Listeners[0].Addr().String(), blsPriv6.PublicKey())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node6) + t.Logf("Test #5 passed | Successfuly disconnected supposed validator node with missing SFPosValidator flag") + + // This node will have ProtocolVersion1 and be connected as an outbound non-validator node. 
+ node7 := spawnNonValidatorNodeProtocol2(t, 18006, "node7") + node7.Params.ProtocolVersion = lib.ProtocolVersion1 + node7 = startNode(t, node7) + defer node7.Stop() + + cc = node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node7.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node7) + t.Logf("Test #6 passed | Successfuly disconnected outbound non-validator node with ProtocolVersion1") +} + +func TestConnectionControllerHandshakeTimeouts(t *testing.T) { + require := require.New(t) + + // Set version negotiation timeout to 0 to make sure that the node will be disconnected + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1.Params.VersionNegotiationTimeout = 0 + node1 = startNode(t, node1) + defer node1.Stop() + + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + node2 = startNode(t, node2) + defer node2.Stop() + + cc := node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node2.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node2) + t.Logf("Test #1 passed | Successfuly disconnected node after version negotiation timeout") + + // Now let's try timing out verack exchange + node1.Params.VersionNegotiationTimeout = lib.DeSoTestnetParams.VersionNegotiationTimeout + node3 := spawnNonValidatorNodeProtocol2(t, 18002, "node3") + node3.Params.VerackNegotiationTimeout = 0 + node3 = startNode(t, node3) + defer node3.Stop() + + cc = node3.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node3) + t.Logf("Test #2 passed | Successfuly disconnected node after verack exchange timeout") + + // Now let's try timing out handshake between two validators node4 and node5 + blsPriv4, err := 
bls.NewPrivateKey() + require.NoError(err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsPriv4) + node4.Params.HandshakeTimeoutMicroSeconds = 0 + node4 = startNode(t, node4) + defer node4.Stop() + + blsPriv5, err := bls.NewPrivateKey() + require.NoError(err) + node5 := spawnValidatorNodeProtocol2(t, 18004, "node5", blsPriv5) + node5 = startNode(t, node5) + defer node5.Stop() + + cc = node4.Server.GetConnectionController() + require.NoError(cc.CreateValidatorConnection(node5.Listeners[0].Addr().String(), blsPriv5.PublicKey())) + waitForEmptyRemoteNodeIndexer(t, node4) + waitForEmptyRemoteNodeIndexer(t, node5) + t.Logf("Test #3 passed | Successfuly disconnected validator node after handshake timeout") +} + +func TestConnectionControllerValidatorDuplication(t *testing.T) { + require := require.New(t) + + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1 = startNode(t, node1) + defer node1.Stop() + + // Create a validator Node2 + blsPriv2, err := bls.NewPrivateKey() + require.NoError(err) + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsPriv2) + node2 = startNode(t, node2) + + // Create a duplicate validator Node3 + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsPriv2) + node3 = startNode(t, node3) + + // Create validator connection from Node1 to Node2 and from Node1 to Node3 + cc := node1.Server.GetConnectionController() + require.NoError(cc.CreateValidatorConnection(node2.Listeners[0].Addr().String(), blsPriv2.PublicKey())) + // This should fail out right because Node3 has a duplicate public key. + require.Error(cc.CreateValidatorConnection(node3.Listeners[0].Addr().String(), blsPriv2.PublicKey())) + waitForValidatorConnection(t, node1, node2) + waitForNonValidatorInboundConnection(t, node2, node1) + + // Now create an outbound connection from Node3 to Node1, which should pass handshake, but then fail because + // Node1 already has a validator connection to Node2 with the same public key. 
+ cc3 := node3.Server.GetConnectionController() + require.NoError(cc3.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node3) + waitForCountRemoteNodeIndexer(t, node1, 1, 1, 0, 0) + t.Logf("Test #1 passed | Successfuly rejected duplicate validator connection with inbound/outbound validators") + + node3.Stop() + node2.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + + // Create two more validators Node4, Node5 with duplicate public keys + blsPriv4, err := bls.NewPrivateKey() + require.NoError(err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsPriv4) + node4 = startNode(t, node4) + defer node4.Stop() + + node5 := spawnValidatorNodeProtocol2(t, 18004, "node5", blsPriv4) + node5 = startNode(t, node5) + defer node5.Stop() + + // Create validator connections from Node4 to Node1 and from Node5 to Node1 + cc4 := node4.Server.GetConnectionController() + require.NoError(cc4.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForValidatorConnection(t, node1, node4) + waitForNonValidatorOutboundConnection(t, node4, node1) + cc5 := node5.Server.GetConnectionController() + require.NoError(cc5.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node5) + waitForCountRemoteNodeIndexer(t, node1, 1, 1, 0, 0) + t.Logf("Test #2 passed | Successfuly rejected duplicate validator connection with multiple outbound validators") +} + +func TestConnectionControllerProtocolDifference(t *testing.T) { + require := require.New(t) + + // Create a ProtocolVersion1 Node1 + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1.Params.ProtocolVersion = lib.ProtocolVersion1 + node1 = startNode(t, node1) + defer node1.Stop() + + // Create a ProtocolVersion2 NonValidator Node2 + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + node2 = startNode(t, node2) + + // Create non-validator connection from Node1 to Node2 + cc := 
node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node2.Listeners[0].Addr().String())) + waitForNonValidatorOutboundConnection(t, node1, node2) + waitForNonValidatorInboundConnection(t, node2, node1) + t.Logf("Test #1 passed | Successfuly connected to a ProtocolVersion1 node with a ProtocolVersion2 non-validator") + + // Create a ProtocolVersion2 Validator Node3 + blsPriv3, err := bls.NewPrivateKey() + require.NoError(err) + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsPriv3) + node3 = startNode(t, node3) + + // Create validator connection from Node1 to Node3 + require.NoError(cc.CreateValidatorConnection(node3.Listeners[0].Addr().String(), blsPriv3.PublicKey())) + waitForValidatorConnection(t, node1, node3) + waitForNonValidatorInboundConnection(t, node3, node1) + t.Logf("Test #2 passed | Successfuly connected to a ProtocolVersion1 node with a ProtocolVersion2 validator") + + node2.Stop() + node3.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + + // Create a ProtocolVersion2 validator Node4 + blsPriv4, err := bls.NewPrivateKey() + require.NoError(err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsPriv4) + node4 = startNode(t, node4) + defer node4.Stop() + + // Attempt to create non-validator connection from Node4 to Node1 + cc = node4.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node4) + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #3 passed | Successfuly rejected outbound connection from ProtocolVersion2 node to ProtcolVersion1 node") + + // Attempt to create validator connection from Node4 to Node1 + require.NoError(cc.CreateValidatorConnection(node1.Listeners[0].Addr().String(), blsPriv4.PublicKey())) + waitForEmptyRemoteNodeIndexer(t, node4) + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #4 passed | Successfuly rejected validator connection from 
ProtocolVersion2 node to ProtcolVersion1 node") + + // Create a ProtocolVersion2 non-validator Node5 + node5 := spawnNonValidatorNodeProtocol2(t, 18004, "node5") + node5 = startNode(t, node5) + defer node5.Stop() + + // Attempt to create non-validator connection from Node5 to Node1 + cc = node5.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node5) + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #5 passed | Successfuly rejected outbound connection from ProtocolVersion2 node to ProtcolVersion1 node") +} + +func TestConnectionControllerPersistentConnection(t *testing.T) { + require := require.New(t) + + // Create a NonValidator Node1 + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1 = startNode(t, node1) + + // Create a Validator Node2 + blsPriv2, err := bls.NewPrivateKey() + require.NoError(err) + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsPriv2) + node2 = startNode(t, node2) + + // Create a persistent connection from Node1 to Node2 + cc := node1.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorPersistentOutboundConnection(node2.Listeners[0].Addr().String())) + waitForValidatorConnection(t, node1, node2) + waitForNonValidatorInboundConnection(t, node2, node1) + node2.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #1 passed | Successfuly created persistent connection from non-validator Node1 to validator Node2") + + // Create a Non-validator Node3 + node3 := spawnNonValidatorNodeProtocol2(t, 18002, "node3") + node3 = startNode(t, node3) + + // Create a persistent connection from Node1 to Node3 + require.NoError(cc.CreateNonValidatorPersistentOutboundConnection(node3.Listeners[0].Addr().String())) + waitForNonValidatorOutboundConnection(t, node1, node3) + waitForNonValidatorInboundConnection(t, node3, node1) + node3.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + 
node1.Stop() + t.Logf("Test #2 passed | Successfuly created persistent connection from non-validator Node1 to non-validator Node3") + + // Create a Validator Node4 + blsPriv4, err := bls.NewPrivateKey() + require.NoError(err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsPriv4) + node4 = startNode(t, node4) + defer node4.Stop() + + // Create a non-validator Node5 + node5 := spawnNonValidatorNodeProtocol2(t, 18004, "node5") + node5 = startNode(t, node5) + + // Create a persistent connection from Node4 to Node5 + cc = node4.Server.GetConnectionController() + require.NoError(cc.CreateNonValidatorPersistentOutboundConnection(node5.Listeners[0].Addr().String())) + waitForNonValidatorOutboundConnection(t, node4, node5) + waitForValidatorConnection(t, node5, node4) + node5.Stop() + waitForEmptyRemoteNodeIndexer(t, node4) + t.Logf("Test #3 passed | Successfuly created persistent connection from validator Node4 to non-validator Node5") + + // Create a Validator Node6 + blsPriv6, err := bls.NewPrivateKey() + require.NoError(err) + node6 := spawnValidatorNodeProtocol2(t, 18005, "node6", blsPriv6) + node6 = startNode(t, node6) + defer node6.Stop() + + // Create a persistent connection from Node4 to Node6 + require.NoError(cc.CreateNonValidatorPersistentOutboundConnection(node6.Listeners[0].Addr().String())) + waitForValidatorConnection(t, node4, node6) + waitForValidatorConnection(t, node6, node4) + t.Logf("Test #4 passed | Successfuly created persistent connection from validator Node4 to validator Node6") +} diff --git a/integration_testing/connection_controller_utils_test.go b/integration_testing/connection_controller_utils_test.go new file mode 100644 index 000000000..4d5594634 --- /dev/null +++ b/integration_testing/connection_controller_utils_test.go @@ -0,0 +1,206 @@ +package integration_testing + +import ( + "fmt" + "github.com/deso-protocol/core/bls" + "github.com/deso-protocol/core/cmd" + "github.com/deso-protocol/core/lib" + "os" + "testing" +) + +func 
waitForValidatorConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { + userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + rnManagerN1 := node1.Server.GetConnectionController().GetRemoteNodeManager() + n1ValidatedN2 := func() bool { + if true != checkRemoteNodeIndexerUserAgent(rnManagerN1, userAgentN2, true, false, false) { + return false + } + rnFromN2 := getRemoteNodeWithUserAgent(node1, userAgentN2) + if rnFromN2 == nil { + return false + } + if !rnFromN2.IsHandshakeCompleted() { + return false + } + return true + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to outbound non-validator Node (%s)", userAgentN1, userAgentN2), n1ValidatedN2) +} + +func waitForNonValidatorOutboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { + userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + rnManagerN1 := node1.Server.GetConnectionController().GetRemoteNodeManager() + n1ValidatedN2 := func() bool { + if true != checkRemoteNodeIndexerUserAgent(rnManagerN1, userAgentN2, false, true, false) { + return false + } + rnFromN2 := getRemoteNodeWithUserAgent(node1, userAgentN2) + if rnFromN2 == nil { + return false + } + if !rnFromN2.IsHandshakeCompleted() { + return false + } + if rnFromN2.GetValidatorPublicKey() != nil { + return false + } + return true + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to outbound non-validator Node (%s)", userAgentN1, userAgentN2), n1ValidatedN2) +} + +func waitForNonValidatorInboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { + userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + rnManagerN1 := node1.Server.GetConnectionController().GetRemoteNodeManager() + n1ValidatedN2 := func() bool { + if true != checkRemoteNodeIndexerUserAgent(rnManagerN1, userAgentN2, false, false, true) { + return false + } + rnFromN2 := getRemoteNodeWithUserAgent(node1, userAgentN2) + if rnFromN2 == nil { + 
return false + } + if !rnFromN2.IsHandshakeCompleted() { + return false + } + if rnFromN2.GetValidatorPublicKey() != nil { + return false + } + return true + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to inbound non-validator Node (%s)", userAgentN1, userAgentN2), n1ValidatedN2) +} + +func waitForEmptyRemoteNodeIndexer(t *testing.T, node1 *cmd.Node) { + userAgentN1 := node1.Params.UserAgent + rnManagerN1 := node1.Server.GetConnectionController().GetRemoteNodeManager() + n1ValidatedN2 := func() bool { + if true != checkRemoteNodeIndexerEmpty(rnManagerN1) { + return false + } + return true + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to disconnect from all RemoteNodes", userAgentN1), n1ValidatedN2) +} + +func waitForCountRemoteNodeIndexer(t *testing.T, node1 *cmd.Node, allCount int, validatorCount int, + nonValidatorOutboundCount int, nonValidatorInboundCount int) { + + userAgentN1 := node1.Params.UserAgent + rnManagerN1 := node1.Server.GetConnectionController().GetRemoteNodeManager() + n1ValidatedN2 := func() bool { + if true != checkRemoteNodeIndexerCount(rnManagerN1, allCount, validatorCount, nonValidatorOutboundCount, nonValidatorInboundCount) { + return false + } + return true + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to have appropriate RemoteNodes counts", userAgentN1), n1ValidatedN2) +} + +func checkRemoteNodeIndexerUserAgent(manager *lib.RemoteNodeManager, userAgent string, validator bool, + nonValidatorOutbound bool, nonValidatorInbound bool) bool { + + if true != checkUserAgentInRemoteNodeList(userAgent, manager.GetAllRemoteNodes().GetAll()) { + return false + } + if validator != checkUserAgentInRemoteNodeList(userAgent, manager.GetValidatorIndex().GetAll()) { + return false + } + if nonValidatorOutbound != checkUserAgentInRemoteNodeList(userAgent, manager.GetNonValidatorOutboundIndex().GetAll()) { + return false + } + if nonValidatorInbound != checkUserAgentInRemoteNodeList(userAgent, 
manager.GetNonValidatorInboundIndex().GetAll()) { + return false + } + + return true +} + +func checkRemoteNodeIndexerCount(manager *lib.RemoteNodeManager, allCount int, validatorCount int, + nonValidatorOutboundCount int, nonValidatorInboundCount int) bool { + + if allCount != manager.GetAllRemoteNodes().Count() { + return false + } + if validatorCount != manager.GetValidatorIndex().Count() { + return false + } + if nonValidatorOutboundCount != manager.GetNonValidatorOutboundIndex().Count() { + return false + } + if nonValidatorInboundCount != manager.GetNonValidatorInboundIndex().Count() { + return false + } + + return true +} + +func checkRemoteNodeIndexerEmpty(manager *lib.RemoteNodeManager) bool { + if manager.GetAllRemoteNodes().Count() != 0 { + return false + } + if manager.GetValidatorIndex().Count() != 0 { + return false + } + if manager.GetNonValidatorOutboundIndex().Count() != 0 { + return false + } + if manager.GetNonValidatorInboundIndex().Count() != 0 { + return false + } + return true +} + +func checkUserAgentInRemoteNodeList(userAgent string, rnList []*lib.RemoteNode) bool { + for _, rn := range rnList { + if rn == nil { + continue + } + if rn.GetUserAgent() == userAgent { + return true + } + } + return false +} + +func getRemoteNodeWithUserAgent(node *cmd.Node, userAgent string) *lib.RemoteNode { + rnManager := node.Server.GetConnectionController().GetRemoteNodeManager() + rnList := rnManager.GetAllRemoteNodes().GetAll() + for _, rn := range rnList { + if rn.GetUserAgent() == userAgent { + return rn + } + } + return nil +} + +func spawnNonValidatorNodeProtocol2(t *testing.T, port uint32, id string) *cmd.Node { + dbDir := getDirectory(t) + t.Cleanup(func() { + os.RemoveAll(dbDir) + }) + config := generateConfig(t, port, dbDir, 10) + config.SyncType = lib.NodeSyncTypeBlockSync + node := cmd.NewNode(config) + node.Params.UserAgent = id + node.Params.ProtocolVersion = lib.ProtocolVersion2 + return node +} + +func spawnValidatorNodeProtocol2(t 
*testing.T, port uint32, id string, blsPriv *bls.PrivateKey) *cmd.Node { + dbDir := getDirectory(t) + t.Cleanup(func() { + os.RemoveAll(dbDir) + }) + config := generateConfig(t, port, dbDir, 10) + config.SyncType = lib.NodeSyncTypeBlockSync + config.PosValidatorSeed = blsPriv.ToString() + node := cmd.NewNode(config) + node.Params.UserAgent = id + node.Params.ProtocolVersion = lib.ProtocolVersion2 + return node +} diff --git a/integration_testing/hypersync_test.go b/integration_testing/hypersync_test.go index aad90ee0e..bc4c8a7c0 100644 --- a/integration_testing/hypersync_test.go +++ b/integration_testing/hypersync_test.go @@ -53,6 +53,7 @@ func TestSimpleHyperSync(t *testing.T) { //compareNodesByDB(t, node1, node2, 0) compareNodesByChecksum(t, node1, node2) fmt.Println("Databases match!") + bridge.Disconnect() node1.Stop() node2.Stop() } @@ -122,6 +123,8 @@ func TestHyperSyncFromHyperSyncedNode(t *testing.T) { compareNodesByChecksum(t, node2, node3) fmt.Println("Databases match!") + bridge12.Disconnect() + bridge23.Disconnect() node1.Stop() node2.Stop() node3.Stop() @@ -178,6 +181,7 @@ func TestSimpleHyperSyncRestart(t *testing.T) { compareNodesByChecksum(t, node1, node2) fmt.Println("Random restart successful! Random sync prefix was", syncPrefix) fmt.Println("Databases match!") + bridge.Disconnect() node1.Stop() node2.Stop() } @@ -255,6 +259,8 @@ func TestSimpleHyperSyncDisconnectWithSwitchingToNewPeer(t *testing.T) { compareNodesByChecksum(t, node1, node2) fmt.Println("Random restart successful! 
Random sync prefix was", syncPrefix) fmt.Println("Databases match!") + bridge12.Disconnect() + bridge23.Disconnect() node1.Stop() node2.Stop() node3.Stop() @@ -349,6 +355,7 @@ func TestArchivalMode(t *testing.T) { //compareNodesByDB(t, node1, node2, 0) compareNodesByChecksum(t, node1, node2) fmt.Println("Databases match!") + bridge.Disconnect() node1.Stop() node2.Stop() } @@ -406,6 +413,9 @@ func TestBlockSyncFromArchivalModeHyperSync(t *testing.T) { //compareNodesByDB(t, node1, node2, 0) compareNodesByChecksum(t, node1, node2) fmt.Println("Databases match!") + bridge12.Disconnect() + bridge23.Disconnect() node1.Stop() node2.Stop() + node3.Stop() } diff --git a/integration_testing/migrations_test.go b/integration_testing/migrations_test.go index b0a692b52..1419d483e 100644 --- a/integration_testing/migrations_test.go +++ b/integration_testing/migrations_test.go @@ -59,6 +59,7 @@ func TestEncoderMigrations(t *testing.T) { compareNodesByChecksum(t, node1, node2) fmt.Println("Databases match!") + bridge.Disconnect() node1.Stop() node2.Stop() } diff --git a/integration_testing/mining_test.go b/integration_testing/mining_test.go index 49a23333c..88de5e097 100644 --- a/integration_testing/mining_test.go +++ b/integration_testing/mining_test.go @@ -29,9 +29,7 @@ func TestRegtestMiner(t *testing.T) { // wait for node1 to sync blocks mineHeight := uint32(40) - listener := make(chan bool) - listenForBlockHeight(t, node1, mineHeight, listener) - <-listener + <-listenForBlockHeight(node1, mineHeight) node1.Stop() } diff --git a/integration_testing/tools.go b/integration_testing/tools.go index c73b82873..2f97e942d 100644 --- a/integration_testing/tools.go +++ b/integration_testing/tools.go @@ -150,7 +150,8 @@ func compareNodesByChecksum(t *testing.T, nodeA *cmd.Node, nodeB *cmd.Node) { // compareNodesByState will look through all state records in nodeA and nodeB databases and will compare them. // The nodes pass this comparison iff they have identical states. 
func compareNodesByState(t *testing.T, nodeA *cmd.Node, nodeB *cmd.Node, verbose int) { - compareNodesByStateWithPrefixList(t, nodeA.ChainDB, nodeB.ChainDB, lib.StatePrefixes.StatePrefixesList, verbose) + compareNodesByStateWithPrefixList(t, nodeA.Server.GetBlockchain().DB(), nodeB.Server.GetBlockchain().DB(), + lib.StatePrefixes.StatePrefixesList, verbose) } // compareNodesByDB will look through all records in nodeA and nodeB databases and will compare them. @@ -164,7 +165,8 @@ func compareNodesByDB(t *testing.T, nodeA *cmd.Node, nodeB *cmd.Node, verbose in } prefixList = append(prefixList, []byte{prefix}) } - compareNodesByStateWithPrefixList(t, nodeA.ChainDB, nodeB.ChainDB, prefixList, verbose) + compareNodesByStateWithPrefixList(t, nodeA.Server.GetBlockchain().DB(), nodeB.Server.GetBlockchain().DB(), + prefixList, verbose) } // compareNodesByDB will look through all records in nodeA and nodeB txindex databases and will compare them. @@ -386,25 +388,25 @@ func restartNode(t *testing.T, node *cmd.Node) *cmd.Node { } // listenForBlockHeight busy-waits until the node's block tip reaches provided height. -func listenForBlockHeight(t *testing.T, node *cmd.Node, height uint32, signal chan<- bool) { +func listenForBlockHeight(node *cmd.Node, height uint32) (_listener chan bool) { + listener := make(chan bool) ticker := time.NewTicker(1 * time.Millisecond) go func() { for { <-ticker.C if node.Server.GetBlockchain().BlockTip().Height >= height { - signal <- true + listener <- true break } } }() + return listener } // disconnectAtBlockHeight busy-waits until the node's block tip reaches provided height, and then disconnects // from the provided bridge. 
-func disconnectAtBlockHeight(t *testing.T, syncingNode *cmd.Node, bridge *ConnectionBridge, height uint32) { - listener := make(chan bool) - listenForBlockHeight(t, syncingNode, height, listener) - <-listener +func disconnectAtBlockHeight(syncingNode *cmd.Node, bridge *ConnectionBridge, height uint32) { + <-listenForBlockHeight(syncingNode, height) bridge.Disconnect() } @@ -414,7 +416,7 @@ func restartAtHeightAndReconnectNode(t *testing.T, node *cmd.Node, source *cmd.N height uint32) (_node *cmd.Node, _bridge *ConnectionBridge) { require := require.New(t) - disconnectAtBlockHeight(t, node, currentBridge, height) + disconnectAtBlockHeight(node, currentBridge, height) newNode := restartNode(t, node) // Wait after the restart. time.Sleep(1 * time.Second) @@ -475,3 +477,23 @@ func randomUint32Between(t *testing.T, min, max uint32) uint32 { randomHeight := uint32(randomNumber) % (max - min) return randomHeight + min } + +func waitForCondition(t *testing.T, id string, condition func() bool) { + signalChan := make(chan struct{}) + go func() { + for { + if condition() { + signalChan <- struct{}{} + return + } + time.Sleep(1 * time.Millisecond) + } + }() + + select { + case <-signalChan: + return + case <-time.After(5 * time.Second): + t.Fatalf("Condition timed out | %s", id) + } +} diff --git a/integration_testing/txindex_test.go b/integration_testing/txindex_test.go index aa13fd265..dfd398557 100644 --- a/integration_testing/txindex_test.go +++ b/integration_testing/txindex_test.go @@ -57,6 +57,7 @@ func TestSimpleTxIndex(t *testing.T) { compareNodesByDB(t, node1, node2, 0) compareNodesByTxIndex(t, node1, node2, 0) fmt.Println("Databases match!") + bridge.Disconnect() node1.Stop() node2.Stop() } diff --git a/lib/connection_controller.go b/lib/connection_controller.go new file mode 100644 index 000000000..fef9fa887 --- /dev/null +++ b/lib/connection_controller.go @@ -0,0 +1,329 @@ +package lib + +import ( + "fmt" + "github.com/btcsuite/btcd/addrmgr" + 
"github.com/btcsuite/btcd/wire" + "github.com/deso-protocol/core/bls" + "github.com/golang/glog" + "github.com/pkg/errors" + "net" + "strconv" +) + +// ConnectionController is a structure that oversees all connections to remote nodes. It is responsible for kicking off +// the initial connections a node makes to the network. It is also responsible for creating RemoteNodes from all +// successful outbound and inbound connections. The ConnectionController also ensures that the node is connected to +// the active validators, once the node reaches Proof of Stake. +// TODO: Document more in later PRs +type ConnectionController struct { + // The parameters we are initialized with. + params *DeSoParams + + cmgr *ConnectionManager + blsKeystore *BLSKeystore + + handshake *HandshakeController + + rnManager *RemoteNodeManager + + // The address manager keeps track of peer addresses we're aware of. When + // we need to connect to a new outbound peer, it chooses one of the addresses + // it's aware of at random and provides it to us. + AddrMgr *addrmgr.AddrManager + + // When --connectips is set, we don't connect to anything from the addrmgr. + connectIps []string + + // The target number of non-validator outbound remote nodes we want to have. We will disconnect remote nodes once + // we've exceeded this number of outbound connections. + targetNonValidatorOutboundRemoteNodes uint32 + // The target number of non-validator inbound remote nodes we want to have. We will disconnect remote nodes once + // we've exceeded this number of inbound connections. + targetNonValidatorInboundRemoteNodes uint32 + // When true, only one connection per IP is allowed. Prevents eclipse attacks + // among other things. 
+ limitOneInboundRemoteNodePerIP bool +} + +func NewConnectionController(params *DeSoParams, cmgr *ConnectionManager, handshakeController *HandshakeController, + rnManager *RemoteNodeManager, blsKeystore *BLSKeystore, addrMgr *addrmgr.AddrManager, targetNonValidatorOutboundRemoteNodes uint32, + targetNonValidatorInboundRemoteNodes uint32, limitOneInboundConnectionPerIP bool) *ConnectionController { + + return &ConnectionController{ + params: params, + cmgr: cmgr, + blsKeystore: blsKeystore, + handshake: handshakeController, + rnManager: rnManager, + AddrMgr: addrMgr, + targetNonValidatorOutboundRemoteNodes: targetNonValidatorOutboundRemoteNodes, + targetNonValidatorInboundRemoteNodes: targetNonValidatorInboundRemoteNodes, + limitOneInboundRemoteNodePerIP: limitOneInboundConnectionPerIP, + } +} + +func (cc *ConnectionController) GetRemoteNodeManager() *RemoteNodeManager { + return cc.rnManager +} + +// ########################### +// ## Handlers (Peer, DeSoMessage) +// ########################### + +func (cc *ConnectionController) _handleDonePeerMessage(origin *Peer, desoMsg DeSoMessage) { + if desoMsg.GetMsgType() != MsgTypeDisconnectedPeer { + return + } + + cc.rnManager.DisconnectById(NewRemoteNodeId(origin.ID)) +} + +func (cc *ConnectionController) _handleAddrMessage(origin *Peer, desoMsg DeSoMessage) { + if desoMsg.GetMsgType() != MsgTypeAddr { + return + } + + // TODO +} + +func (cc *ConnectionController) _handleGetAddrMessage(origin *Peer, desoMsg DeSoMessage) { + if desoMsg.GetMsgType() != MsgTypeGetAddr { + return + } + + // TODO +} + +// _handleNewConnectionMessage is called when a new outbound or inbound connection is established. It is responsible +// for creating a RemoteNode from the connection and initiating the handshake. The incoming DeSoMessage is a control message. 
+func (cc *ConnectionController) _handleNewConnectionMessage(origin *Peer, desoMsg DeSoMessage) { + if desoMsg.GetMsgType() != MsgTypeNewConnection { + return + } + + msg, ok := desoMsg.(*MsgDeSoNewConnection) + if !ok { + return + } + + var remoteNode *RemoteNode + var err error + switch msg.Connection.GetConnectionType() { + case ConnectionTypeInbound: + remoteNode, err = cc.processInboundConnection(msg.Connection) + if err != nil { + glog.Errorf("ConnectionController.handleNewConnectionMessage: Problem handling inbound connection: %v", err) + msg.Connection.Close() + return + } + case ConnectionTypeOutbound: + remoteNode, err = cc.processOutboundConnection(msg.Connection) + if err != nil { + glog.Errorf("ConnectionController.handleNewConnectionMessage: Problem handling outbound connection: %v", err) + cc.cleanupFailedOutboundConnection(msg.Connection) + return + } + } + + // If we made it here, we have a valid remote node. We will now initiate the handshake. + cc.handshake.InitiateHandshake(remoteNode) +} + +func (cc *ConnectionController) cleanupFailedOutboundConnection(connection Connection) { + oc, ok := connection.(*outboundConnection) + if !ok { + return + } + + id := NewRemoteNodeId(oc.attemptId) + rn := cc.rnManager.GetRemoteNodeById(id) + if rn != nil { + cc.rnManager.Disconnect(rn) + } + cc.cmgr.RemoveAttemptedOutboundAddrs(oc.address) +} + +// ########################### +// ## Connections +// ########################### + +func (cc *ConnectionController) CreateValidatorConnection(ipStr string, publicKey *bls.PublicKey) error { + netAddr, err := cc.ConvertIPStringToNetAddress(ipStr) + if err != nil { + return err + } + return cc.rnManager.CreateValidatorConnection(netAddr, publicKey) +} + +func (cc *ConnectionController) CreateNonValidatorPersistentOutboundConnection(ipStr string) error { + netAddr, err := cc.ConvertIPStringToNetAddress(ipStr) + if err != nil { + return err + } + return 
cc.rnManager.CreateNonValidatorPersistentOutboundConnection(netAddr) +} + +func (cc *ConnectionController) CreateNonValidatorOutboundConnection(ipStr string) error { + netAddr, err := cc.ConvertIPStringToNetAddress(ipStr) + if err != nil { + return err + } + return cc.rnManager.CreateNonValidatorOutboundConnection(netAddr) +} + +func (cc *ConnectionController) SetTargetOutboundPeers(numPeers uint32) { + cc.targetNonValidatorOutboundRemoteNodes = numPeers +} + +func (cc *ConnectionController) enoughNonValidatorInboundConnections() bool { + return uint32(cc.rnManager.GetNonValidatorInboundIndex().Count()) >= cc.targetNonValidatorInboundRemoteNodes +} + +func (cc *ConnectionController) enoughNonValidatorOutboundConnections() bool { + return uint32(cc.rnManager.GetNonValidatorOutboundIndex().Count()) >= cc.targetNonValidatorOutboundRemoteNodes +} + +// processInboundConnection is called when a new inbound connection is established. At this point, the connection is not validated, +// nor is it assigned to a RemoteNode. This function is responsible for validating the connection and creating a RemoteNode from it. +// Once the RemoteNode is created, we will initiate handshake. +func (cc *ConnectionController) processInboundConnection(conn Connection) (*RemoteNode, error) { + var ic *inboundConnection + var ok bool + if ic, ok = conn.(*inboundConnection); !ok { + return nil, fmt.Errorf("ConnectionController.handleInboundConnection: Connection is not an inboundConnection") + } + + // Reject the peer if we have too many inbound connections already. + if cc.enoughNonValidatorInboundConnections() { + return nil, fmt.Errorf("ConnectionController.handleInboundConnection: Rejecting INBOUND peer (%s) due to max "+ + "inbound peers (%d) hit", ic.connection.RemoteAddr().String(), cc.targetNonValidatorInboundRemoteNodes) + } + + // If we want to limit inbound connections to one per IP address, check to make sure this address isn't already connected. 
+ if cc.limitOneInboundRemoteNodePerIP && + cc.isDuplicateInboundIPAddress(ic.connection.RemoteAddr()) { + + return nil, fmt.Errorf("ConnectionController.handleInboundConnection: Rejecting INBOUND peer (%s) due to "+ + "already having an inbound connection from the same IP with limit_one_inbound_connection_per_ip set", + ic.connection.RemoteAddr().String()) + } + + na, err := cc.ConvertIPStringToNetAddress(ic.connection.RemoteAddr().String()) + if err != nil { + return nil, errors.Wrapf(err, "ConnectionController.handleInboundConnection: Problem calling "+ + "ConvertIPStringToNetAddress for addr: (%s)", ic.connection.RemoteAddr().String()) + } + + remoteNode, err := cc.rnManager.AttachInboundConnection(ic.connection, na) + if remoteNode == nil || err != nil { + return nil, errors.Wrapf(err, "ConnectionController.handleInboundConnection: Problem calling "+ + "AttachInboundConnection for addr: (%s)", ic.connection.RemoteAddr().String()) + } + + return remoteNode, nil +} + +// processOutboundConnection is called when a new outbound connection is established. At this point, the connection is not validated, +// nor is it assigned to a RemoteNode. This function is responsible for validating the connection and creating a RemoteNode from it. +// Once the RemoteNode is created, we will initiate handshake. 
+func (cc *ConnectionController) processOutboundConnection(conn Connection) (*RemoteNode, error) { + var oc *outboundConnection + var ok bool + if oc, ok = conn.(*outboundConnection); !ok { + return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Connection is not an outboundConnection") + } + + if oc.failed { + return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Failed to connect to peer (%s)", + oc.address.IP.String()) + } + + if !oc.isPersistent { + cc.AddrMgr.Connected(oc.address) + cc.AddrMgr.Good(oc.address) + } + + // if this is a non-persistent outbound peer, and we already have enough outbound peers, then don't bother adding this one. + if !oc.isPersistent && cc.enoughNonValidatorOutboundConnections() { + return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Connected to maximum number of outbound "+ + "peers (%d)", cc.targetNonValidatorOutboundRemoteNodes) + } + + // If this is a non-persistent outbound peer and the group key overlaps with another peer we're already connected to then + // abort mission. We only connect to one peer per IP group in order to prevent Sybil attacks. 
+ if !oc.isPersistent && cc.cmgr.IsFromRedundantOutboundIPAddress(oc.address) { + return nil, fmt.Errorf("ConnectionController.handleOutboundConnection: Rejecting OUTBOUND NON-PERSISTENT "+ + "connection with redundant group key (%s).", addrmgr.GroupKey(oc.address)) + } + + na, err := cc.ConvertIPStringToNetAddress(oc.connection.RemoteAddr().String()) + if err != nil { + return nil, errors.Wrapf(err, "ConnectionController.handleOutboundConnection: Problem calling ipToNetAddr "+ + "for addr: (%s)", oc.connection.RemoteAddr().String()) + } + + remoteNode, err := cc.rnManager.AttachOutboundConnection(oc.connection, na, oc.attemptId, oc.isPersistent) + if remoteNode == nil || err != nil { + return nil, errors.Wrapf(err, "ConnectionController.handleOutboundConnection: Problem calling rnManager.AttachOutboundConnection "+ + "for addr: (%s)", oc.connection.RemoteAddr().String()) + } + return remoteNode, nil +} + +func (cc *ConnectionController) ConvertIPStringToNetAddress(ipStr string) (*wire.NetAddress, error) { + netAddr, err := IPToNetAddr(ipStr, cc.AddrMgr, cc.params) + if err != nil { + return nil, errors.Wrapf(err, + "ConnectionController.ConvertIPStringToNetAddress: Problem parsing "+ + "ipString to wire.NetAddress") + } + if netAddr == nil { + return nil, fmt.Errorf("ConnectionController.ConvertIPStringToNetAddress: " + + "address was nil after parsing") + } + return netAddr, nil +} + +func IPToNetAddr(ipStr string, addrMgr *addrmgr.AddrManager, params *DeSoParams) (*wire.NetAddress, error) { + port := params.DefaultSocketPort + host, portstr, err := net.SplitHostPort(ipStr) + if err != nil { + // No port specified so leave port=default and set + // host to the ipStr. 
+ host = ipStr + } else { + pp, err := strconv.ParseUint(portstr, 10, 16) + if err != nil { + return nil, errors.Wrapf(err, "IPToNetAddr: Can not parse port from %s for ip", ipStr) + } + port = uint16(pp) + } + netAddr, err := addrMgr.HostToNetAddress(host, port, 0) + if err != nil { + return nil, errors.Wrapf(err, "IPToNetAddr: Can not parse port from %s for ip", ipStr) + } + return netAddr, nil +} + +func (cc *ConnectionController) isDuplicateInboundIPAddress(addr net.Addr) bool { + netAddr, err := IPToNetAddr(addr.String(), cc.AddrMgr, cc.params) + if err != nil { + // Return true in case we have an error. We do this because it + // will result in the peer connection not being accepted, which + // is desired in this case. + glog.Warningf(errors.Wrapf(err, + "ConnectionController.isDuplicateInboundIPAddress: Problem parsing "+ + "net.Addr to wire.NetAddress so marking as redundant and not "+ + "making connection").Error()) + return true + } + if netAddr == nil { + glog.Warningf("ConnectionController.isDuplicateInboundIPAddress: " + + "address was nil after parsing so marking as redundant and not " + + "making connection") + return true + } + + return cc.cmgr.IsDuplicateInboundIPAddress(netAddr) +} diff --git a/lib/connection_manager.go b/lib/connection_manager.go index 38924bdf9..7c6f510ac 100644 --- a/lib/connection_manager.go +++ b/lib/connection_manager.go @@ -4,7 +4,6 @@ import ( "fmt" "math" "net" - "strconv" "sync" "sync/atomic" "time" @@ -14,7 +13,6 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/decred/dcrd/lru" "github.com/golang/glog" - "github.com/pkg/errors" ) // connection_manager.go contains most of the logic for creating and managing @@ -36,24 +34,10 @@ type ConnectionManager struct { // doesn't need a reference to the Server object. But for now we keep things lazy. srv *Server - // When --connectips is set, we don't connect to anything from the addrmgr. 
- connectIps []string - - // The address manager keeps track of peer addresses we're aware of. When - // we need to connect to a new outbound peer, it chooses one of the addresses - // it's aware of at random and provides it to us. - AddrMgr *addrmgr.AddrManager // The interfaces we listen on for new incoming connections. listeners []net.Listener // The parameters we are initialized with. params *DeSoParams - // The target number of outbound peers we want to have. - targetOutboundPeers uint32 - // The maximum number of inbound peers we allow. - maxInboundPeers uint32 - // When true, only one connection per IP is allowed. Prevents eclipse attacks - // among other things. - limitOneInboundConnectionPerIP bool // When --hypersync is set to true we will attempt fast block synchronization HyperSync bool @@ -136,10 +120,8 @@ type ConnectionManager struct { } func NewConnectionManager( - _params *DeSoParams, _addrMgr *addrmgr.AddrManager, _listeners []net.Listener, + _params *DeSoParams, _listeners []net.Listener, _connectIps []string, _timeSource chainlib.MedianTimeSource, - _targetOutboundPeers uint32, _maxInboundPeers uint32, - _limitOneInboundConnectionPerIP bool, _hyperSync bool, _syncType NodeSyncType, _stallTimeoutSeconds uint64, @@ -150,16 +132,13 @@ func NewConnectionManager( ValidateHyperSyncFlags(_hyperSync, _syncType) return &ConnectionManager{ - srv: _srv, - params: _params, - AddrMgr: _addrMgr, - listeners: _listeners, - connectIps: _connectIps, + srv: _srv, + params: _params, + listeners: _listeners, // We keep track of the last N nonces we've sent in order to detect // self connections. sentNonces: lru.NewCache(1000), timeSource: _timeSource, - //newestBlock: _newestBlock, // Initialize the peer data structures. 
@@ -176,15 +155,13 @@ func NewConnectionManager( newPeerChan: make(chan *Peer, 100), donePeerChan: make(chan *Peer, 100), outboundConnectionChan: make(chan *outboundConnection, 100), + inboundConnectionChan: make(chan *inboundConnection, 100), - targetOutboundPeers: _targetOutboundPeers, - maxInboundPeers: _maxInboundPeers, - limitOneInboundConnectionPerIP: _limitOneInboundConnectionPerIP, - HyperSync: _hyperSync, - SyncType: _syncType, - serverMessageQueue: _serverMessageQueue, - stallTimeoutSeconds: _stallTimeoutSeconds, - minFeeRateNanosPerKB: _minFeeRateNanosPerKB, + HyperSync: _hyperSync, + SyncType: _syncType, + serverMessageQueue: _serverMessageQueue, + stallTimeoutSeconds: _stallTimeoutSeconds, + minFeeRateNanosPerKB: _minFeeRateNanosPerKB, } } @@ -224,40 +201,6 @@ func (cmgr *ConnectionManager) subFromGroupKey(na *wire.NetAddress) { cmgr.mtxOutboundConnIPGroups.Unlock() } -func (cmgr *ConnectionManager) getRandomAddr() *wire.NetAddress { - for tries := 0; tries < 100; tries++ { - addr := cmgr.AddrMgr.GetAddress() - if addr == nil { - glog.V(2).Infof("ConnectionManager.getRandomAddr: addr from GetAddressWithExclusions was nil") - break - } - - // Lock the address map since multiple threads will be trying to read - // and modify it at the same time. - cmgr.mtxAddrsMaps.RLock() - ok := cmgr.connectedOutboundAddrs[addrmgr.NetAddressKey(addr.NetAddress())] - cmgr.mtxAddrsMaps.RUnlock() - if ok { - glog.V(2).Infof("ConnectionManager.getRandomAddr: Not choosing already connected address %v:%v", addr.NetAddress().IP, addr.NetAddress().Port) - continue - } - - // We can only have one outbound address per /16. This is similar to - // Bitcoin and we do it to prevent Sybil attacks. 
- if cmgr.IsFromRedundantOutboundIPAddress(addr.NetAddress()) { - glog.V(2).Infof("ConnectionManager.getRandomAddr: Not choosing address due to redundant group key %v:%v", addr.NetAddress().IP, addr.NetAddress().Port) - continue - } - - glog.V(2).Infof("ConnectionManager.getRandomAddr: Returning %v:%v at %d iterations", - addr.NetAddress().IP, addr.NetAddress().Port, tries) - return addr.NetAddress() - } - - glog.V(2).Infof("ConnectionManager.getRandomAddr: Returning nil") - return nil -} - func _delayRetry(retryCount uint64, persistentAddrForLogging *wire.NetAddress, unit time.Duration) (_retryDuration time.Duration) { // No delay if we haven't tried yet or if the number of retries isn't positive. if retryCount <= 0 { @@ -276,42 +219,6 @@ func _delayRetry(retryCount uint64, persistentAddrForLogging *wire.NetAddress, u return retryDelay } -func (cmgr *ConnectionManager) enoughOutboundPeers() bool { - val := atomic.LoadUint32(&cmgr.numOutboundPeers) - if val > cmgr.targetOutboundPeers { - glog.Errorf("enoughOutboundPeers: Connected to too many outbound "+ - "peers: (%d). Should be "+ - "no more than (%d).", val, cmgr.targetOutboundPeers) - return true - } - - if val == cmgr.targetOutboundPeers { - return true - } - return false -} - -func IPToNetAddr(ipStr string, addrMgr *addrmgr.AddrManager, params *DeSoParams) (*wire.NetAddress, error) { - port := params.DefaultSocketPort - host, portstr, err := net.SplitHostPort(ipStr) - if err != nil { - // No port specified so leave port=default and set - // host to the ipStr. 
- host = ipStr - } else { - pp, err := strconv.ParseUint(portstr, 10, 16) - if err != nil { - return nil, errors.Wrapf(err, "IPToNetAddr: Can not parse port from %s for ip", ipStr) - } - port = uint16(pp) - } - netAddr, err := addrMgr.HostToNetAddress(host, port, 0) - if err != nil { - return nil, errors.Wrapf(err, "IPToNetAddr: Can not parse port from %s for ip", ipStr) - } - return netAddr, nil -} - func (cmgr *ConnectionManager) IsConnectedOutboundIpAddress(netAddr *wire.NetAddress) bool { cmgr.mtxAddrsMaps.RLock() defer cmgr.mtxAddrsMaps.RUnlock() @@ -338,13 +245,15 @@ func (cmgr *ConnectionManager) RemoveAttemptedOutboundAddrs(netAddr *wire.NetAdd // DialPersistentOutboundConnection attempts to connect to a persistent peer. func (cmgr *ConnectionManager) DialPersistentOutboundConnection(persistentAddr *wire.NetAddress, attemptId uint64) (_attemptId uint64) { - glog.V(2).Infof("ConnectionManager.DialPersistentOutboundConnection: Connecting to peer %v", persistentAddr.IP.String()) + glog.V(2).Infof("ConnectionManager.DialPersistentOutboundConnection: Connecting to peer (IP=%v, Port=%v)", + persistentAddr.IP.String(), persistentAddr.Port) return cmgr._dialOutboundConnection(persistentAddr, attemptId, true) } // DialOutboundConnection attempts to connect to a non-persistent peer. 
func (cmgr *ConnectionManager) DialOutboundConnection(addr *wire.NetAddress, attemptId uint64) { - glog.V(2).Infof("ConnectionManager.ConnectOutboundConnection: Connecting to peer %v", addr.IP.String()) + glog.V(2).Infof("ConnectionManager.DialOutboundConnection: Connecting to peer (IP=%v, Port=%v)", + addr.IP.String(), addr.Port) cmgr._dialOutboundConnection(addr, attemptId, false) } @@ -400,7 +309,7 @@ func (cmgr *ConnectionManager) ConnectPeer(id uint64, conn net.Conn, na *wire.Ne return peer } -func (cmgr *ConnectionManager) IsFromRedundantInboundIPAddress(netAddr *wire.NetAddress) bool { +func (cmgr *ConnectionManager) IsDuplicateInboundIPAddress(netAddr *wire.NetAddress) bool { cmgr.mtxPeerMaps.RLock() defer cmgr.mtxPeerMaps.RUnlock() @@ -412,7 +321,7 @@ func (cmgr *ConnectionManager) IsFromRedundantInboundIPAddress(netAddr *wire.Net // nodes on a local machine. // TODO: Should this be a flag? if net.IP([]byte{127, 0, 0, 1}).Equal(netAddr.IP) { - glog.V(1).Infof("ConnectionManager._isFromRedundantInboundIPAddress: Allowing " + + glog.V(1).Infof("ConnectionManager.IsDuplicateInboundIPAddress: Allowing " + "localhost IP address to connect") return false } diff --git a/lib/constants.go b/lib/constants.go index 460b8dadb..0a525332e 100644 --- a/lib/constants.go +++ b/lib/constants.go @@ -498,6 +498,10 @@ func (pvt ProtocolVersionType) Before(version ProtocolVersionType) bool { return pvt.ToUint64() < version.ToUint64() } +func (pvt ProtocolVersionType) After(version ProtocolVersionType) bool { + return pvt.ToUint64() > version.ToUint64() +} + // DeSoParams defines the full list of possible parameters for the // DeSo network. type DeSoParams struct { @@ -564,6 +568,8 @@ type DeSoParams struct { DialTimeout time.Duration // The amount of time we wait to receive a version message from a peer. VersionNegotiationTimeout time.Duration + // The amount of time we wait to receive a verack message from a peer. 
+ VerackNegotiationTimeout time.Duration // The maximum number of addresses to broadcast to peers. MaxAddressesToBroadcast uint32 @@ -1025,6 +1031,7 @@ var DeSoMainnetParams = DeSoParams{ DialTimeout: 30 * time.Second, VersionNegotiationTimeout: 30 * time.Second, + VerackNegotiationTimeout: 30 * time.Second, MaxAddressesToBroadcast: 10, @@ -1296,6 +1303,7 @@ var DeSoTestnetParams = DeSoParams{ DialTimeout: 30 * time.Second, VersionNegotiationTimeout: 30 * time.Second, + VerackNegotiationTimeout: 30 * time.Second, MaxAddressesToBroadcast: 10, diff --git a/lib/pos_handshake_controller.go b/lib/handshake_controller.go similarity index 96% rename from lib/pos_handshake_controller.go rename to lib/handshake_controller.go index 6f4804f2e..bde07745a 100644 --- a/lib/pos_handshake_controller.go +++ b/lib/handshake_controller.go @@ -5,12 +5,15 @@ import ( "github.com/decred/dcrd/lru" "github.com/golang/glog" "math" + "sync" ) // HandshakeController is a structure that handles the handshake process with remote nodes. It is the entry point for // initiating a handshake with a remote node. It is also responsible for handling version/verack messages from remote // nodes. And for handling the handshake complete control message. type HandshakeController struct { + mtxHandshakeComplete sync.Mutex + rnManager *RemoteNodeManager usedNonces lru.Cache } @@ -37,6 +40,10 @@ func (hc *HandshakeController) InitiateHandshake(rn *RemoteNode) { // _handleHandshakeCompleteMessage handles HandshakeComplete control messages, sent by RemoteNodes. func (hc *HandshakeController) _handleHandshakeCompleteMessage(origin *Peer, desoMsg DeSoMessage) { + // Prevent race conditions while handling handshake complete messages. 
+ hc.mtxHandshakeComplete.Lock() + defer hc.mtxHandshakeComplete.Unlock() + if desoMsg.GetMsgType() != MsgTypePeerHandshakeComplete { return } diff --git a/lib/network.go b/lib/network.go index 80d412c4f..75474ea7c 100644 --- a/lib/network.go +++ b/lib/network.go @@ -1543,8 +1543,7 @@ func (msg *MsgDeSoPong) FromBytes(data []byte) error { type ServiceFlag uint64 const ( - // SFFullNodeDeprecated is deprecated, and set on all nodes by default - // now. We basically split it into SFHyperSync and SFArchivalMode. + // SFFullNodeDeprecated is deprecated, and set on all nodes by default now. SFFullNodeDeprecated ServiceFlag = 1 << 0 // SFHyperSync is a flag used to indicate that the peer supports hyper sync. SFHyperSync ServiceFlag = 1 << 1 @@ -1555,6 +1554,10 @@ const ( SFPosValidator ServiceFlag = 1 << 3 ) +func (sf ServiceFlag) HasService(serviceFlag ServiceFlag) bool { + return sf&serviceFlag == serviceFlag +} + type MsgDeSoVersion struct { // What is the current version we're on? Version uint64 @@ -1952,10 +1955,6 @@ func (msg *MsgDeSoVerack) EncodeVerackV0() ([]byte, error) { } func (msg *MsgDeSoVerack) EncodeVerackV1() ([]byte, error) { - if msg.PublicKey == nil || msg.Signature == nil { - return nil, fmt.Errorf("MsgDeSoVerack.EncodeVerackV1: PublicKey and Signature must be set for V1 message") - } - retBytes := []byte{} // Version diff --git a/lib/network_test.go b/lib/network_test.go index 8a971f75a..c0f721a99 100644 --- a/lib/network_test.go +++ b/lib/network_test.go @@ -93,12 +93,12 @@ func TestVerackV1(t *testing.T) { require := require.New(t) networkType := NetworkType_MAINNET - var buf bytes.Buffer + var buf1, buf2 bytes.Buffer nonceReceived := uint64(12345678910) nonceSent := nonceReceived + 1 tstamp := uint64(2345678910) - // First, test that nil public key and signature are not allowed. + // First, test that nil public key and signature are allowed. 
msg := &MsgDeSoVerack{ Version: VerackVersion1, NonceReceived: nonceReceived, @@ -107,8 +107,8 @@ func TestVerackV1(t *testing.T) { PublicKey: nil, Signature: nil, } - _, err := WriteMessage(&buf, msg, networkType) - require.Error(err) + _, err := WriteMessage(&buf1, msg, networkType) + require.NoError(err) payload := append(UintToBuf(nonceReceived), UintToBuf(nonceSent)...) payload = append(payload, UintToBuf(tstamp)...) hash := sha3.Sum256(payload) @@ -118,10 +118,10 @@ func TestVerackV1(t *testing.T) { msg.PublicKey = priv.PublicKey() msg.Signature, err = priv.Sign(hash[:]) require.NoError(err) - _, err = WriteMessage(&buf, msg, networkType) + _, err = WriteMessage(&buf2, msg, networkType) require.NoError(err) - verBytes := buf.Bytes() + verBytes := buf2.Bytes() testMsg, _, err := ReadMessage(bytes.NewReader(verBytes), networkType) require.NoError(err) require.Equal(msg, testMsg) diff --git a/lib/peer.go b/lib/peer.go index 98d2c135e..0af9aa0b7 100644 --- a/lib/peer.go +++ b/lib/peer.go @@ -1192,11 +1192,12 @@ func (pp *Peer) Start() { // If the address manager needs more addresses, then send a GetAddr message // to the peer. This is best-effort. if pp.cmgr != nil { - if pp.cmgr.AddrMgr.NeedMoreAddresses() { + // TODO: Move this to ConnectionController. + /*if pp.cmgr.AddrMgr.NeedMoreAddresses() { go func() { pp.QueueMessage(&MsgDeSoGetAddr{}) }() - } + }*/ } // Send our verack message now that the IO processing machinery has started. 
diff --git a/lib/remote_node.go b/lib/remote_node.go index a357118a2..f2d849a36 100644 --- a/lib/remote_node.go +++ b/lib/remote_node.go @@ -191,6 +191,10 @@ func (rn *RemoteNode) GetValidatorPublicKey() *bls.PublicKey { return rn.validatorPublicKey } +func (rn *RemoteNode) GetServiceFlag() ServiceFlag { + return rn.handshakeMetadata.serviceFlag +} + func (rn *RemoteNode) GetUserAgent() string { return rn.handshakeMetadata.userAgent } @@ -223,7 +227,11 @@ func (rn *RemoteNode) IsValidator() bool { if !rn.IsHandshakeCompleted() { return false } - return rn.GetValidatorPublicKey() != nil + return rn.hasValidatorServiceFlag() +} + +func (rn *RemoteNode) hasValidatorServiceFlag() bool { + return rn.GetServiceFlag().HasService(SFPosValidator) } // DialOutboundConnection dials an outbound connection to the provided netAddr. @@ -359,7 +367,7 @@ func (rn *RemoteNode) sendVersionMessage(nonce uint64) error { return nil } -// newVersionMessage returns a new version message that can be sent to a RemoteNode peer. The message will contain the +// newVersionMessage returns a new version message that can be sent to a RemoteNode. The message will contain the // nonce that is passed in as an argument. func (rn *RemoteNode) newVersionMessage(nonce uint64) *MsgDeSoVersion { ver := NewMessage(MsgTypeVersion).(*MsgDeSoVersion) @@ -412,6 +420,12 @@ func (rn *RemoteNode) HandleVersionMessage(verMsg *MsgDeSoVersion, responseNonce vMeta := rn.handshakeMetadata // Record the version the peer is using. vMeta.advertisedProtocolVersion = NewProtocolVersionType(verMsg.Version) + // Make sure the latest supported protocol version is ProtocolVersion2. + if vMeta.advertisedProtocolVersion.After(ProtocolVersion2) { + return fmt.Errorf("RemoteNode.HandleVersionMessage: Requesting disconnect for id: (%v) "+ + "protocol version too high. Peer version: %v, max version: %v", rn.id, verMsg.Version, ProtocolVersion2) + } + // Decide on the protocol version to use for this connection. 
negotiatedVersion := rn.params.ProtocolVersion if verMsg.Version < rn.params.ProtocolVersion.ToUint64() { @@ -430,6 +444,17 @@ func (rn *RemoteNode) HandleVersionMessage(verMsg *MsgDeSoVersion, responseNonce // Record the services the peer is advertising. vMeta.serviceFlag = verMsg.Services + // If the RemoteNode was connected with an expectation of being a validator, make sure that its advertised ServiceFlag + // indicates that it is a validator. + if !rn.hasValidatorServiceFlag() && rn.validatorPublicKey != nil { + return fmt.Errorf("RemoteNode.HandleVersionMessage: Requesting disconnect for id: (%v). "+ + "Expected validator, but received invalid ServiceFlag: %v", rn.id, verMsg.Services) + } + // If the RemoteNode is on ProtocolVersion1, then it must not have the validator service flag set. + if rn.hasValidatorServiceFlag() && vMeta.advertisedProtocolVersion.Before(ProtocolVersion2) { + return fmt.Errorf("RemoteNode.HandleVersionMessage: Requesting disconnect for id: (%v). "+ + "RemoteNode has SFValidator service flag, but doesn't have ProtocolVersion2 or later", rn.id) + } // Record the tstamp sent by the peer and calculate the time offset. timeConnected := time.Unix(verMsg.TstampSecs, 0) @@ -450,7 +475,7 @@ func (rn *RemoteNode) HandleVersionMessage(verMsg *MsgDeSoVersion, responseNonce vMeta.minTxFeeRateNanosPerKB = verMsg.MinFeeRateNanosPerKB // Respond to the version message if this is an inbound peer. - if !rn.IsOutbound() { + if rn.IsInbound() { if err := rn.sendVersionMessage(responseNonce); err != nil { return errors.Wrapf(err, "RemoteNode.HandleVersionMessage: Problem sending version message to peer (id= %d)", rn.id) } @@ -460,7 +485,7 @@ func (rn *RemoteNode) HandleVersionMessage(verMsg *MsgDeSoVersion, responseNonce // peer's verack message even if it is an inbound peer. Instead, we just send the verack message right away. // Set the latest time by which we should receive a verack message from the peer. 
- verackTimeExpected := time.Now().Add(rn.params.VersionNegotiationTimeout) + verackTimeExpected := time.Now().Add(rn.params.VerackNegotiationTimeout) rn.verackTimeExpected = &verackTimeExpected if err := rn.sendVerack(); err != nil { return errors.Wrapf(err, "RemoteNode.HandleVersionMessage: Problem sending verack message to peer (id= %d)", rn.id) @@ -496,7 +521,6 @@ func (rn *RemoteNode) newVerackMessage() (*MsgDeSoVerack, error) { verack.Version = VerackVersion0 verack.NonceReceived = vMeta.versionNonceReceived case ProtocolVersion2: - // FIXME: resolve the non-validator - validator handshake issues on protocol version 2. // For protocol version 2, we need to send the nonce we received from the peer in their version message. // We also need to send our own nonce, which we generate for our version message. In addition, we need to // send a current timestamp (in microseconds). We then sign the tuple of (nonceReceived, nonceSent, tstampMicro) @@ -507,6 +531,10 @@ func (rn *RemoteNode) newVerackMessage() (*MsgDeSoVerack, error) { verack.NonceSent = vMeta.versionNonceSent tstampMicro := uint64(time.Now().UnixMicro()) verack.TstampMicro = tstampMicro + // If the RemoteNode is not a validator, then we don't need to sign the verack message. + if !rn.nodeServices.HasService(SFPosValidator) { + break + } verack.PublicKey = rn.keystore.GetSigner().GetPublicKey() verack.Signature, err = rn.keystore.GetSigner().SignPoSValidatorHandshake(verack.NonceSent, verack.NonceReceived, tstampMicro) if err != nil { @@ -599,6 +627,11 @@ func (rn *RemoteNode) validateVerackPoS(vrkMsg *MsgDeSoVerack) error { "verack timestamp too far in the past. Time now: %v, verack timestamp: %v", rn.id, timeNowMicro, vrkMsg.TstampMicro) } + // If the RemoteNode is not a validator, then we don't need to verify the verack message's signature. + if !rn.hasValidatorServiceFlag() { + return nil + } + // Make sure the verack message's public key and signature are not nil. 
if vrkMsg.PublicKey == nil || vrkMsg.Signature == nil { return fmt.Errorf("RemoteNode.validateVerackPoS: Requesting disconnect for id: (%v) "+ @@ -617,7 +650,7 @@ func (rn *RemoteNode) validateVerackPoS(vrkMsg *MsgDeSoVerack) error { "verack signature verification failed", rn.id) } - if rn.validatorPublicKey != nil || rn.validatorPublicKey.Serialize() != vrkMsg.PublicKey.Serialize() { + if rn.validatorPublicKey != nil && rn.validatorPublicKey.Serialize() != vrkMsg.PublicKey.Serialize() { return fmt.Errorf("RemoteNode.validateVerackPoS: Requesting disconnect for id: (%v) "+ "verack public key mismatch; message: %v; expected: %v", rn.id, vrkMsg.PublicKey, rn.validatorPublicKey) } diff --git a/lib/remote_node_manager.go b/lib/remote_node_manager.go index a41fe4606..fb269d072 100644 --- a/lib/remote_node_manager.go +++ b/lib/remote_node_manager.go @@ -5,8 +5,10 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/deso-protocol/core/bls" "github.com/deso-protocol/core/collections" + "github.com/golang/glog" "github.com/pkg/errors" "net" + "sync" "sync/atomic" ) @@ -14,6 +16,8 @@ import ( // and stopping remote node connections. It is also responsible for organizing the remote nodes into indices for easy // access, through the RemoteNodeIndexer. type RemoteNodeManager struct { + mtx sync.Mutex + // remoteNodeIndexer is a structure that stores and indexes all created remote nodes. 
remoteNodeIndexer *RemoteNodeIndexer @@ -62,13 +66,19 @@ func (manager *RemoteNodeManager) ProcessCompletedHandshake(remoteNode *RemoteNo if remoteNode.IsValidator() { manager.SetValidator(remoteNode) + manager.UnsetNonValidator(remoteNode) } else { + manager.UnsetValidator(remoteNode) manager.SetNonValidator(remoteNode) } manager.srv.HandleAcceptedPeer(remoteNode.GetPeer()) } func (manager *RemoteNodeManager) Disconnect(rn *RemoteNode) { + if rn == nil { + return + } + glog.V(2).Infof("RemoteNodeManager.Disconnect: Disconnecting from remote node %v", rn.GetId()) rn.Disconnect() manager.removeRemoteNodeFromIndexer(rn) } @@ -83,17 +93,29 @@ func (manager *RemoteNodeManager) DisconnectById(id RemoteNodeId) { } func (manager *RemoteNodeManager) removeRemoteNodeFromIndexer(rn *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if rn == nil { return } indexer := manager.remoteNodeIndexer indexer.GetAllRemoteNodes().Remove(rn.GetId()) - if rn.validatorPublicKey != nil { - indexer.GetValidatorIndex().Remove(rn.validatorPublicKey.Serialize()) - } indexer.GetNonValidatorOutboundIndex().Remove(rn.GetId()) indexer.GetNonValidatorInboundIndex().Remove(rn.GetId()) + + // Try to evict the remote node from the validator index. If the remote node is not a validator, then there is nothing to do. + if rn.GetValidatorPublicKey() == nil { + return + } + // Only remove from the validator index if the fetched remote node is the same as the one we are trying to remove. + // Otherwise, we could have a fun edge-case where a duplicated validator connection ends up removing an + // existing validator connection from the index. 
+ fetchedRn, ok := indexer.GetValidatorIndex().Get(rn.GetValidatorPublicKey().Serialize()) + if ok && fetchedRn.GetId() == rn.GetId() { + indexer.GetValidatorIndex().Remove(rn.GetValidatorPublicKey().Serialize()) + } } func (manager *RemoteNodeManager) SendMessage(rn *RemoteNode, desoMessage DeSoMessage) error { @@ -113,6 +135,10 @@ func (manager *RemoteNodeManager) CreateValidatorConnection(netAddr *wire.NetAdd return fmt.Errorf("RemoteNodeManager.CreateValidatorConnection: netAddr or public key is nil") } + if _, ok := manager.GetValidatorIndex().Get(publicKey.Serialize()); ok { + return fmt.Errorf("RemoteNodeManager.CreateValidatorConnection: RemoteNode already exists for public key: %v", publicKey) + } + remoteNode := manager.newRemoteNode(publicKey) if err := remoteNode.DialPersistentOutboundConnection(netAddr); err != nil { return errors.Wrapf(err, "RemoteNodeManager.CreateValidatorConnection: Problem calling DialPersistentOutboundConnection "+ @@ -190,6 +216,9 @@ func (manager *RemoteNodeManager) AttachOutboundConnection(conn net.Conn, na *wi // ########################### func (manager *RemoteNodeManager) setRemoteNode(rn *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if rn == nil { return } @@ -198,36 +227,39 @@ func (manager *RemoteNodeManager) setRemoteNode(rn *RemoteNode) { } func (manager *RemoteNodeManager) SetNonValidator(rn *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if rn == nil { return } if rn.IsOutbound() { manager.GetNonValidatorOutboundIndex().Set(rn.GetId(), rn) - } else if rn.IsInbound() { - manager.GetNonValidatorInboundIndex().Set(rn.GetId(), rn) } else { - manager.Disconnect(rn) - return + manager.GetNonValidatorInboundIndex().Set(rn.GetId(), rn) } - - manager.UnsetValidator(rn) } func (manager *RemoteNodeManager) SetValidator(remoteNode *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if remoteNode == nil { return } pk := remoteNode.GetValidatorPublicKey() if pk == nil { - 
manager.Disconnect(remoteNode) return } manager.GetValidatorIndex().Set(pk.Serialize(), remoteNode) } func (manager *RemoteNodeManager) UnsetValidator(remoteNode *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if remoteNode == nil { return } @@ -240,16 +272,17 @@ func (manager *RemoteNodeManager) UnsetValidator(remoteNode *RemoteNode) { } func (manager *RemoteNodeManager) UnsetNonValidator(rn *RemoteNode) { + manager.mtx.Lock() + defer manager.mtx.Unlock() + if rn == nil { return } if rn.IsOutbound() { manager.GetNonValidatorOutboundIndex().Remove(rn.GetId()) - } else if rn.IsInbound() { - manager.GetNonValidatorInboundIndex().Remove(rn.GetId()) } else { - manager.Disconnect(rn) + manager.GetNonValidatorInboundIndex().Remove(rn.GetId()) } } diff --git a/lib/server.go b/lib/server.go index d1c82e5b3..d4c371955 100644 --- a/lib/server.go +++ b/lib/server.go @@ -62,7 +62,9 @@ type Server struct { eventManager *EventManager TxIndex *TXIndex + handshakeController *HandshakeController // fastHotStuffEventLoop consensus.FastHotStuffEventLoop + connectionController *ConnectionController // posMempool *PosMemPool TODO: Add the mempool later // All messages received from peers get sent from the ConnectionManager to the @@ -175,6 +177,10 @@ func (srv *Server) ResetRequestQueues() { srv.requestedTransactionsMap = make(map[BlockHash]*GetDataRequestInfo) } +func (srv *Server) GetConnectionController() *ConnectionController { + return srv.connectionController +} + // dataLock must be acquired for writing before calling this function. func (srv *Server) _removeRequest(hash *BlockHash) { // Just be lazy and remove the hash from everything indiscriminately to @@ -445,8 +451,7 @@ func NewServer( // Create a new connection manager but note that it won't be initialized until Start(). 
_incomingMessages := make(chan *ServerMessage, (_targetOutboundPeers+_maxInboundPeers)*3) _cmgr := NewConnectionManager( - _params, _desoAddrMgr, _listeners, _connectIps, timesource, - _targetOutboundPeers, _maxInboundPeers, _limitOneInboundConnectionPerIP, + _params, _listeners, _connectIps, timesource, _hyperSync, _syncType, _stallTimeoutSeconds, _minFeeRateNanosPerKB, _incomingMessages, srv) @@ -481,6 +486,22 @@ func NewServer( hex.EncodeToString(_chain.blockTip().Hash[:]), hex.EncodeToString(BigintToHash(_chain.blockTip().CumWork)[:])) + nodeServices := SFFullNodeDeprecated + if _hyperSync { + nodeServices |= SFHyperSync + } + if archivalMode { + nodeServices |= SFArchivalNode + } + if _blsKeystore != nil { + nodeServices |= SFPosValidator + } + rnManager := NewRemoteNodeManager(srv, _chain, _cmgr, _blsKeystore, _params, _minFeeRateNanosPerKB, nodeServices) + + srv.handshakeController = NewHandshakeController(rnManager) + srv.connectionController = NewConnectionController(_params, _cmgr, srv.handshakeController, rnManager, + _blsKeystore, _desoAddrMgr, _targetOutboundPeers, _maxInboundPeers, _limitOneInboundConnectionPerIP) + if srv.stateChangeSyncer != nil { srv.stateChangeSyncer.BlockHeight = uint64(_chain.headerTip().Height) } @@ -2176,7 +2197,9 @@ func (srv *Server) _handleAddrMessage(pp *Peer, msg *MsgDeSoAddr) { netAddrsReceived = append( netAddrsReceived, addrAsNetAddr) } - srv.cmgr.AddrMgr.AddAddresses(netAddrsReceived, pp.netAddr) + // TODO: temporary + addressMgr := addrmgr.New("", net.LookupIP) + addressMgr.AddAddresses(netAddrsReceived, pp.netAddr) // If the message had <= 10 addrs in it, then queue all the addresses for relaying // on the next cycle. @@ -2207,7 +2230,9 @@ func (srv *Server) _handleGetAddrMessage(pp *Peer, msg *MsgDeSoGetAddr) { glog.V(1).Infof("Server._handleGetAddrMessage: Received GetAddr from peer %v", pp) // When we get a GetAddr message, choose MaxAddrsPerMsg from the AddrMgr // and send them back to the peer. 
- netAddrsFound := srv.cmgr.AddrMgr.AddressCache() + // TODO: temporary + addressMgr := addrmgr.New("", net.LookupIP) + netAddrsFound := addressMgr.AddressCache() if len(netAddrsFound) > MaxAddrsPerAddrMsg { netAddrsFound = netAddrsFound[:MaxAddrsPerAddrMsg] } @@ -2230,9 +2255,12 @@ func (srv *Server) _handleControlMessages(serverMessage *ServerMessage) (_should switch serverMessage.Msg.(type) { // Control messages used internally to signal to the server. case *MsgDeSoPeerHandshakeComplete: - break + srv.handshakeController._handleHandshakeCompleteMessage(serverMessage.Peer, serverMessage.Msg) case *MsgDeSoDisconnectedPeer: srv._handleDonePeer(serverMessage.Peer) + srv.connectionController._handleDonePeerMessage(serverMessage.Peer, serverMessage.Msg) + case *MsgDeSoNewConnection: + srv.connectionController._handleNewConnectionMessage(serverMessage.Peer, serverMessage.Msg) case *MsgDeSoQuit: return true } @@ -2244,6 +2272,10 @@ func (srv *Server) _handlePeerMessages(serverMessage *ServerMessage) { // Handle all non-control message types from our Peers. switch msg := serverMessage.Msg.(type) { // Messages sent among peers. 
+ case *MsgDeSoAddr: + srv.connectionController._handleAddrMessage(serverMessage.Peer, serverMessage.Msg) + case *MsgDeSoGetAddr: + srv.connectionController._handleGetAddrMessage(serverMessage.Peer, serverMessage.Msg) case *MsgDeSoGetHeaders: srv._handleGetHeaders(serverMessage.Peer, msg) case *MsgDeSoHeaderBundle: @@ -2266,6 +2298,10 @@ func (srv *Server) _handlePeerMessages(serverMessage *ServerMessage) { srv._handleMempool(serverMessage.Peer, msg) case *MsgDeSoInv: srv._handleInv(serverMessage.Peer, msg) + case *MsgDeSoVersion: + srv.handshakeController._handleVersionMessage(serverMessage.Peer, serverMessage.Msg) + case *MsgDeSoVerack: + srv.handshakeController._handleVerackMessage(serverMessage.Peer, serverMessage.Msg) } } @@ -2443,10 +2479,12 @@ func (srv *Server) _startAddressRelayer() { } // For the first ten minutes after the server starts, relay our address to all // peers. After the first ten minutes, do it once every 24 hours. + // TODO: temporary + addressMgr := addrmgr.New("", net.LookupIP) glog.V(1).Infof("Server.Start._startAddressRelayer: Relaying our own addr to peers") if numMinutesPassed < 10 || numMinutesPassed%(RebroadcastNodeAddrIntervalMinutes) == 0 { for _, pp := range srv.cmgr.GetAllPeers() { - bestAddress := srv.cmgr.AddrMgr.GetBestLocalAddress(pp.netAddr) + bestAddress := addressMgr.GetBestLocalAddress(pp.netAddr) if bestAddress != nil { glog.V(2).Infof("Server.Start._startAddressRelayer: Relaying address %v to "+ "peer %v", bestAddress.IP.String(), pp)