Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(quic): add support for reusing an existing socket for local dialing #4304

Merged
merged 10 commits into from
Aug 11, 2023
28 changes: 21 additions & 7 deletions transports/quic/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use libp2p_core::{
use libp2p_identity::PeerId;
use socket2::{Domain, Socket, Type};
use std::collections::hash_map::{DefaultHasher, Entry};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, UdpSocket};
use std::time::Duration;
Expand Down Expand Up @@ -155,9 +155,16 @@ impl<P: Provider> GenTransport<P> {
if l.is_closed {
return false;
}
let listen_addr = l.socket_addr();
SocketFamily::is_same(&listen_addr.ip(), &socket_addr.ip())
&& listen_addr.ip().is_loopback() == socket_addr.ip().is_loopback()
SocketFamily::is_same(&l.socket_addr().ip(), &socket_addr.ip())
})
.filter(|l| {
if socket_addr.ip().is_loopback() {
l.listening_addresses
.iter()
.any(|ip_addr| ip_addr.is_loopback())
} else {
true
}
})
.collect();
match listeners.len() {
Expand Down Expand Up @@ -428,6 +435,8 @@ struct Listener<P: Provider> {

/// The stream must be awaken after it has been closed to deliver the last event.
close_listener_waker: Option<Waker>,

listening_addresses: HashSet<IpAddr>,
}

impl<P: Provider> Listener<P> {
Expand All @@ -440,12 +449,14 @@ impl<P: Provider> Listener<P> {
) -> Result<Self, Error> {
let if_watcher;
let pending_event;
let mut listening_addresses = HashSet::new();
let local_addr = socket.local_addr()?;
if local_addr.ip().is_unspecified() {
if_watcher = Some(P::new_if_watcher()?);
pending_event = None;
} else {
if_watcher = None;
listening_addresses.insert(local_addr.ip());
let ma = socketaddr_to_multiaddr(&local_addr, version);
pending_event = Some(TransportEvent::NewAddress {
listener_id,
Expand All @@ -467,6 +478,7 @@ impl<P: Provider> Listener<P> {
is_closed: false,
pending_event,
close_listener_waker: None,
listening_addresses,
})
}

Expand Down Expand Up @@ -513,7 +525,8 @@ impl<P: Provider> Listener<P> {
if let Some(listen_addr) =
ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
{
log::debug!("New listen address: {}", listen_addr);
log::debug!("New listen address: {listen_addr}");
self.listening_addresses.insert(inet.addr());
return Poll::Ready(TransportEvent::NewAddress {
listener_id: self.listener_id,
listen_addr,
Expand All @@ -524,7 +537,8 @@ impl<P: Provider> Listener<P> {
if let Some(listen_addr) =
ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
{
log::debug!("Expired listen address: {}", listen_addr);
log::debug!("Expired listen address: {listen_addr}");
mxinden marked this conversation as resolved.
Show resolved Hide resolved
self.listening_addresses.remove(&inet.addr());
return Poll::Ready(TransportEvent::AddressExpired {
listener_id: self.listener_id,
listen_addr,
Expand Down Expand Up @@ -730,7 +744,7 @@ fn socketaddr_to_multiaddr(socket_addr: &SocketAddr, version: ProtocolVersion) -

#[cfg(test)]
#[cfg(any(feature = "async-std", feature = "tokio"))]
mod test {
mod tests {
use futures::future::poll_fn;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

Expand Down
43 changes: 43 additions & 0 deletions transports/quic/tests/smoke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,49 @@ async fn write_after_peer_dropped_stream() {
stream_b.close().await.expect("Close failed.");
}

/// - A listens on 0.0.0.0:0
/// - B listens on 127.0.0.1:0
/// - A dials B
/// - Source port of A at B is the A's listen port
#[cfg(feature = "tokio")]
#[tokio::test]
async fn test_local_listener_reuse() {
mxinden marked this conversation as resolved.
Show resolved Hide resolved
let (_, mut a_transport) = create_default_transport::<quic::tokio::Provider>();
let (_, mut b_transport) = create_default_transport::<quic::tokio::Provider>();

a_transport
.listen_on(
ListenerId::next(),
"/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap(),
)
.unwrap();

// wait until a listener reports a loopback address
let a_listen_addr = 'outer: loop {
let ev = a_transport.next().await.unwrap();
let listen_addr = ev.into_new_address().unwrap();
for proto in listen_addr.iter() {
if let Protocol::Ip4(ip4) = proto {
if ip4.is_loopback() {
break 'outer listen_addr;
}
}
}
};
// If we do not poll until the end, `NewAddress` events may be `Ready` and `connect` function
// below will panic due to an unexpected event.
poll_fn(|cx| {
let mut pinned = Pin::new(&mut a_transport);
while pinned.as_mut().poll(cx).is_ready() {}
Poll::Ready(())
})
.await;

let b_addr = start_listening(&mut b_transport, "/ip4/127.0.0.1/udp/0/quic-v1").await;
let (_, send_back_addr, _) = connect(&mut b_transport, &mut a_transport, b_addr).await.0;
assert_eq!(send_back_addr, a_listen_addr);
}

async fn smoke<P: Provider>() {
let _ = env_logger::try_init();

Expand Down