diff --git a/eth/p2p/discoveryv5/protocol.nim b/eth/p2p/discoveryv5/protocol.nim index 3e41517e..e8c6b808 100644 --- a/eth/p2p/discoveryv5/protocol.nim +++ b/eth/p2p/discoveryv5/protocol.nim @@ -124,6 +124,9 @@ const defaultResponseTimeout* = 4.seconds ## timeout for the response of a request-response ## call + ## Ban durations for banned nodes in the routing table + NodeBanDurationInvalidResponse = 15.minutes + type OptAddress* = object ip*: Opt[IpAddress] @@ -142,6 +145,7 @@ type bindAddress: OptAddress ## UDP binding address pendingRequests: Table[AESGCMNonce, PendingRequest] routingTable*: RoutingTable + banNodes: bool codec*: Codec awaitedMessages: Table[(NodeId, RequestId), Future[Opt[Message]]] refreshLoop: Future[void] @@ -157,6 +161,7 @@ type responseTimeout: Duration rng*: ref HmacDrbgContext + PendingRequest = object node: Node message: seq[byte] @@ -192,10 +197,13 @@ proc addNode*(d: Protocol, node: Node): bool = ## ## Returns true only when `Node` was added as a new entry to a bucket in the ## routing table. - if d.routingTable.addNode(node) == Added: + let r = d.routingTable.addNode(node) + if r == Added: return true - else: - return false + + if r == Banned: + debug "Banned node not added to routing table", nodeId = node.id + return false proc addNode*(d: Protocol, r: Record): bool = ## Add `Node` from a `Record` to discovery routing table. @@ -429,6 +437,30 @@ proc sendWhoareyou(d: Protocol, toId: NodeId, a: Address, else: debug "Node with this id already has ongoing handshake, ignoring packet" +proc replaceNode(d: Protocol, n: Node) = + if n.record notin d.bootstrapRecords: + d.routingTable.replaceNode(n) + else: + # For now we never remove bootstrap nodes. It might make sense to actually + # do so and to retry them only in case we drop to a really low amount of + # peers in the routing table. 
+ debug "Message request to bootstrap node failed", enr = toURI(n.record) + +proc banNode*(d: Protocol, n: Node, banPeriod: chronos.Duration) = + if n.record notin d.bootstrapRecords: + if d.banNodes: + d.routingTable.banNode(n.id, banPeriod) # banNode also replaces the node + else: + d.routingTable.replaceNode(n) + else: + # For now we never remove bootstrap nodes. It might make sense to actually + # do so and to retry them only in case we drop to a really low amount of + # peers in the routing table. + debug "Message request to bootstrap node failed", enr = toURI(n.record) + +proc isBanned*(d: Protocol, nodeId: NodeId): bool = + d.banNodes and d.routingTable.isBanned(nodeId) + proc receive*(d: Protocol, a: Address, packet: openArray[byte]) = discv5_network_bytes.inc(packet.len.int64, labelValues = [$Direction.In]) @@ -437,6 +469,10 @@ proc receive*(d: Protocol, a: Address, packet: openArray[byte]) = let packet = decoded[] case packet.flag of OrdinaryMessage: + if d.isBanned(packet.srcId): + trace "Ignoring received OrdinaryMessage from banned node", nodeId = packet.srcId + return + if packet.messageOpt.isSome(): let message = packet.messageOpt.get() trace "Received message packet", srcId = packet.srcId, address = a, @@ -464,6 +500,10 @@ proc receive*(d: Protocol, a: Address, packet: openArray[byte]) = else: debug "Timed out or unrequested whoareyou packet", address = a of HandshakeMessage: + if d.isBanned(packet.srcIdHs): + trace "Ignoring received HandshakeMessage from banned node", nodeId = packet.srcIdHs + return + trace "Received handshake message packet", srcId = packet.srcIdHs, address = a, kind = packet.message.kind d.handleMessage(packet.srcIdHs, a, packet.message, packet.node) @@ -494,14 +534,7 @@ proc processClient(transp: DatagramTransport, raddr: TransportAddress): proto.receive(Address(ip: raddr.toIpAddress(), port: raddr.port), buf) -proc replaceNode(d: Protocol, n: Node) = - if n.record notin d.bootstrapRecords: - d.routingTable.replaceNode(n) - 
else: - # For now we never remove bootstrap nodes. It might make sense to actually - # do so and to retry them only in case we drop to a really low amount of - # peers in the routing table. - debug "Message request to bootstrap node failed", enr = toURI(n.record) + # TODO: This could be improved to do the clean-up immediately in case a non # whoareyou response does arrive, but we would need to store the AuthTag @@ -546,9 +579,11 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): break return ok(res) else: + d.banNode(fromNode, NodeBanDurationInvalidResponse) discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"]) return err("Invalid response to find node message") else: + d.replaceNode(fromNode) discovery_message_requests_outgoing.inc(labelValues = ["no_response"]) return err("Nodes message not received in time") @@ -574,6 +609,10 @@ proc ping*(d: Protocol, toNode: Node): ## Send a discovery ping message. ## ## Returns the received pong message or an error. + + if d.isBanned(toNode.id): + return err("toNode is banned") + let reqId = d.sendMessage(toNode, PingMessage(enrSeq: d.localNode.record.seqNum)) let resp = await d.waitMessage(toNode, reqId) @@ -583,7 +622,7 @@ proc ping*(d: Protocol, toNode: Node): d.routingTable.setJustSeen(toNode) return ok(resp.get().pong) else: - d.replaceNode(toNode) + d.banNode(toNode, NodeBanDurationInvalidResponse) discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"]) return err("Invalid response to ping message") else: @@ -597,15 +636,18 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]): ## ## Returns the received nodes or an error. ## Received ENRs are already validated and converted to `Node`. 
+ + if d.isBanned(toNode.id): + return err("toNode is banned") + let reqId = d.sendMessage(toNode, FindNodeMessage(distances: distances)) let nodes = await d.waitNodes(toNode, reqId) if nodes.isOk: let res = verifyNodesRecords(nodes.get(), toNode, findNodeResultLimit, distances) d.routingTable.setJustSeen(toNode) - return ok(res) + return ok(res.filterIt(not d.isBanned(it.id))) else: - d.replaceNode(toNode) return err(nodes.error) proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]): @@ -613,6 +655,10 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]): ## Send a discovery talkreq message. ## ## Returns the received talkresp message or an error. + + if d.isBanned(toNode.id): + return err("toNode is banned") + let reqId = d.sendMessage(toNode, TalkReqMessage(protocol: protocol, request: request)) let resp = await d.waitMessage(toNode, reqId) @@ -622,7 +668,7 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]): d.routingTable.setJustSeen(toNode) return ok(resp.get().talkResp.response) else: - d.replaceNode(toNode) + d.banNode(toNode, NodeBanDurationInvalidResponse) discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"]) return err("Invalid response to talk request message") else: @@ -797,6 +843,12 @@ proc resolve*(d: Protocol, id: NodeId): Future[Opt[Node]] {.async: (raises: [Can if id == d.localNode.id: return Opt.some(d.localNode) + # No point in trying to resolve a banned node because it won't exist in the + # routing table and it will be filtered out of any responses in the lookup call + if d.isBanned(id): + debug "Not resolving banned node", nodeId = id + return Opt.none(Node) + let node = d.getNode(id) if node.isSome(): let request = await d.findNode(node.get(), @[0'u16]) @@ -882,6 +934,9 @@ proc refreshLoop(d: Protocol) {.async: (raises: []).} = trace "Discovered nodes in random target query", nodes = randomQuery.len debug "Total nodes in discv5 routing table", total = 
d.routingTable.len() + # Remove the expired bans from routing table to limit memory usage + d.routingTable.cleanupExpiredBans() + await sleepAsync(refreshInterval) except CancelledError: trace "refreshLoop canceled" @@ -985,6 +1040,7 @@ proc newProtocol*( bindPort: Port, bindIp = IPv4_any(), enrAutoUpdate = false, + banNodes = false, config = defaultDiscoveryConfig, rng = newRng()): Protocol = @@ -1034,6 +1090,7 @@ proc newProtocol*( enrAutoUpdate: enrAutoUpdate, routingTable: RoutingTable.init( node, config.bitsPerHop, config.tableIpLimits, rng), + banNodes: banNodes, handshakeTimeout: config.handshakeTimeout, responseTimeout: config.responseTimeout, rng: rng) diff --git a/eth/p2p/discoveryv5/routing_table.nim b/eth/p2p/discoveryv5/routing_table.nim index 8d10621b..773c42c4 100644 --- a/eth/p2p/discoveryv5/routing_table.nim +++ b/eth/p2p/discoveryv5/routing_table.nim @@ -195,7 +195,7 @@ func ipLimitDec(r: var RoutingTable, b: KBucket, n: Node) = r.ipLimits.dec(ip) func getNode*(r: RoutingTable, id: NodeId): Opt[Node] -proc replaceNode*(r: var RoutingTable, n: Node) +proc replaceNode*(r: var RoutingTable, n: Node) {.gcsafe.} proc banNode*(r: var RoutingTable, nodeId: NodeId, period: chronos.Duration) = ## Ban a node from the routing table for the given period. 
The node is removed diff --git a/tests/p2p/discv5_test_helper.nim b/tests/p2p/discv5_test_helper.nim index 781663a4..ba02b718 100644 --- a/tests/p2p/discv5_test_helper.nim +++ b/tests/p2p/discv5_test_helper.nim @@ -22,7 +22,8 @@ proc initDiscoveryNode*( address: Address, bootstrapRecords: openArray[Record] = [], localEnrFields: openArray[(string, seq[byte])] = [], - previousRecord = Opt.none(enr.Record)): + previousRecord = Opt.none(enr.Record), + banNodes = false): discv5_protocol.Protocol = # set bucketIpLimit to allow bucket split let config = DiscoveryConfig.init(1000, 24, 5) @@ -36,7 +37,8 @@ proc initDiscoveryNode*( localEnrFields = localEnrFields, previousRecord = previousRecord, config = config, - rng = rng) + rng = rng, + banNodes = banNodes) protocol.open() diff --git a/tests/p2p/test_discoveryv5.nim b/tests/p2p/test_discoveryv5.nim index 8c8e5e1c..eec0db0f 100644 --- a/tests/p2p/test_discoveryv5.nim +++ b/tests/p2p/test_discoveryv5.nim @@ -926,3 +926,116 @@ suite "Discovery v5 Tests": await node1.closeWait() await node2.closeWait() + + asyncTest "Banned nodes are removed and cannot be added": + let + node = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302), banNodes = true) + targetNode = generateNode(PrivateKey.random(rng[])) + + # add the node + check: + node.addNode(targetNode) == true + node.getNode(targetNode.id).isSome() + + # banning the node should remove it from the routing table + node.banNode(targetNode, 1.minutes) + check node.getNode(targetNode.id).isNone() + + # cannot add a banned node + check: + node.addNode(targetNode) == false + node.getNode(targetNode.id).isNone() + + await node.closeWait() + + asyncTest "FindNode filters out banned nodes": + let + mainNode = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20301), + banNodes = true) + testNode = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302), + @[mainNode.localNode.record], banNodes = true) + + # Generate 100 random nodes and add 
to our main node's routing table + for i in 0 ..< 100: + discard mainNode.addSeenNode(generateNode(PrivateKey.random(rng[]))) + + let + neighbours = mainNode.neighbours(mainNode.localNode.id) + closest = neighbours[0] + closestDistance = logDistance(closest.id, mainNode.localNode.id) + + block: + # the closest node is returned + let discovered = await testNode.findNode(mainNode.localNode, @[closestDistance]) + check discovered.isOk + check closest in discovered[] + + # ban the closest node + mainNode.banNode(closest, 1.minutes) + + block: + # the banned node is not returned + let discovered = await testNode.findNode(mainNode.localNode, @[closestDistance]) + check discovered.isOk + check closest notin discovered[] + + await mainNode.closeWait() + await testNode.closeWait() + + asyncTest "Cannot send messages to banned nodes": + let + node1 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302), + banNodes = true) + node2 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20301), + banNodes = true) + + # ban node2 in node1's routing table + node1.banNode(node2.localNode, 1.minutes) + + block: + let pong = await node1.ping(node2.localNode) + check: + pong.isErr() + pong.error() == "toNode is banned" + + block: + let nodes = await node1.findNode(node2.localNode, @[0.uint16]) + check: + nodes.isErr() + nodes.error() == "toNode is banned" + + block: + let node = await node1.resolve(node2.localNode.id) + check node.isNone() + + await node2.closeWait() + await node1.closeWait() + + asyncTest "Ignore messages from banned nodes": + let + node1 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302), + banNodes = true) + node2 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20301), + banNodes = true) + + # ban node1 in node2's routing table + node2.banNode(node1.localNode, 1.minutes) + + block: + let pong = await node1.ping(node2.localNode) + check: + pong.isErr() + pong.error() == "Pong message not received in time" 
+ + block: + let nodes = await node1.findNode(node2.localNode, @[0.uint16]) + check: + nodes.isErr() + nodes.error() == "Nodes message not received in time" + + block: + let node = await node1.resolve(node2.localNode.id) + check node.isNone() + + await node2.closeWait() + await node1.closeWait()