diff --git a/iroh/bench/src/iroh.rs b/iroh/bench/src/iroh.rs index b01811a8bb..792481a1de 100644 --- a/iroh/bench/src/iroh.rs +++ b/iroh/bench/src/iroh.rs @@ -7,6 +7,7 @@ use anyhow::{Context, Result}; use bytes::Bytes; use iroh::{ endpoint::{Connection, ConnectionError, RecvStream, SendStream, TransportConfig}, + watcher::Watcher as _, Endpoint, NodeAddr, RelayMap, RelayMode, RelayUrl, }; use tracing::{trace, warn}; diff --git a/iroh/examples/connect-unreliable.rs b/iroh/examples/connect-unreliable.rs index ea375ba0dd..b361b7e268 100644 --- a/iroh/examples/connect-unreliable.rs +++ b/iroh/examples/connect-unreliable.rs @@ -8,7 +8,7 @@ use std::net::SocketAddr; use clap::Parser; -use iroh::{Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; +use iroh::{watcher::Watcher as _, Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; use tracing::info; // An example ALPN that we are using to communicate over the `Endpoint` @@ -50,7 +50,7 @@ async fn main() -> anyhow::Result<()> { .bind() .await?; - let node_addr = endpoint.node_addr().await?; + let node_addr = endpoint.node_addr().initialized().await?; let me = node_addr.node_id; println!("node id: {me}"); println!("node listening addresses:"); diff --git a/iroh/examples/connect.rs b/iroh/examples/connect.rs index 5187c02a1d..767abd478e 100644 --- a/iroh/examples/connect.rs +++ b/iroh/examples/connect.rs @@ -9,7 +9,7 @@ use std::net::SocketAddr; use anyhow::Context; use clap::Parser; -use iroh::{Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; +use iroh::{watcher::Watcher as _, Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; use tracing::info; // An example ALPN that we are using to communicate over the `Endpoint` diff --git a/iroh/examples/echo.rs b/iroh/examples/echo.rs index e18aee8949..00e1d3ccf1 100644 --- a/iroh/examples/echo.rs +++ b/iroh/examples/echo.rs @@ -11,6 +11,7 @@ use futures_lite::future::Boxed as BoxedFuture; use iroh::{ endpoint::Connecting, protocol::{ProtocolHandler, Router}, + watcher::Watcher as _, Endpoint, NodeAddr, }; @@ -23,7 +24,7 @@ const ALPN: &[u8] = b"iroh-example/echo/0"; #[tokio::main] async fn main() -> Result<()> { let router = accept_side().await?; - let node_addr = router.endpoint().node_addr().await?; + let node_addr = router.endpoint().node_addr().initialized().await?; connect_side(node_addr).await?; diff --git a/iroh/examples/listen-unreliable.rs b/iroh/examples/listen-unreliable.rs index 06d24a76ab..d787ed6107 100644 --- a/iroh/examples/listen-unreliable.rs +++ b/iroh/examples/listen-unreliable.rs @@ -3,7 +3,7 @@ //! This example uses the default relay servers to attempt to holepunch, and will use that relay server to relay packets if the two devices cannot establish a direct UDP connection. //! run this example from the project root: //! $ cargo run --example listen-unreliable -use iroh::{Endpoint, RelayMode, SecretKey}; +use iroh::{watcher::Watcher as _, Endpoint, RelayMode, SecretKey}; use tracing::{info, warn}; // An example ALPN that we are using to communicate over the `Endpoint` @@ -35,7 +35,7 @@ async fn main() -> anyhow::Result<()> { println!("node id: {me}"); println!("node listening addresses:"); - let node_addr = endpoint.node_addr().await?; + let node_addr = endpoint.node_addr().initialized().await?; let local_addrs = node_addr .direct_addresses .into_iter() diff --git a/iroh/examples/listen.rs b/iroh/examples/listen.rs index 13413992dd..9b969f3bc3 100644 --- a/iroh/examples/listen.rs +++ b/iroh/examples/listen.rs @@ -5,7 +5,7 @@ //! $ cargo run --example listen use std::time::Duration; -use iroh::{endpoint::ConnectionError, Endpoint, RelayMode, SecretKey}; +use iroh::{endpoint::ConnectionError, watcher::Watcher as _, Endpoint, RelayMode, SecretKey}; use tracing::{debug, info, warn}; // An example ALPN that we are using to communicate over the `Endpoint` @@ -37,7 +37,7 @@ async fn main() -> anyhow::Result<()> { println!("node id: {me}"); println!("node listening addresses:"); - let node_addr = endpoint.node_addr().await?; + let node_addr = endpoint.node_addr().initialized().await?; let local_addrs = node_addr .direct_addresses .into_iter() diff --git a/iroh/examples/transfer.rs b/iroh/examples/transfer.rs index b5dee21393..159ca86d2d 100644 --- a/iroh/examples/transfer.rs +++ b/iroh/examples/transfer.rs @@ -8,7 +8,8 @@ use bytes::Bytes; use clap::{Parser, Subcommand}; use indicatif::HumanBytes; use iroh::{ - endpoint::ConnectionError, Endpoint, NodeAddr, RelayMap, RelayMode, RelayUrl, SecretKey, + endpoint::ConnectionError, watcher::Watcher as _, Endpoint, NodeAddr, RelayMap, RelayMode, + RelayUrl, SecretKey, }; use iroh_base::ticket::NodeTicket; use tracing::info; diff --git a/iroh/src/discovery.rs b/iroh/src/discovery.rs index a8ee7965f6..d9b8d3afc6 100644 --- a/iroh/src/discovery.rs +++ b/iroh/src/discovery.rs @@ -450,7 +450,7 @@ mod tests { use tokio_util::task::AbortOnDropHandle; use super::*; - use crate::RelayMode; + use crate::{watcher::Watcher as _, RelayMode}; type InfoStore = HashMap<NodeId, (Option<RelayUrl>, BTreeSet<SocketAddr>, u64)>; @@ -580,7 +580,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for our address to be updated and thus published at least once - ep1.node_addr().await?; + ep1.node_addr().initialized().await?; let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?; Ok(()) } @@ -606,7 +606,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for out address to be updated and thus published at least once - ep1.node_addr().await?; + ep1.node_addr().initialized().await?; let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?; Ok(()) } @@ -636,7 +636,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for out address to be updated and thus published at least once - ep1.node_addr().await?; + ep1.node_addr().initialized().await?; let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?; Ok(()) } @@ -659,7 +659,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for out address to be updated and thus published at least once - ep1.node_addr().await?; + ep1.node_addr().initialized().await?; let res = ep2.connect(ep1_addr, TEST_ALPN).await; assert!(res.is_err()); Ok(()) @@ -682,7 +682,7 @@ mod tests { new_endpoint(secret, disco).await }; // wait for out address to be updated and thus published at least once - ep1.node_addr().await?; + ep1.node_addr().initialized().await?; let ep1_wrong_addr = NodeAddr { node_id: ep1.node_id(), relay_url: None, diff --git a/iroh/src/discovery/local_swarm_discovery.rs b/iroh/src/discovery/local_swarm_discovery.rs index b4975807c4..aaf8dc0d4e 100644 --- a/iroh/src/discovery/local_swarm_discovery.rs +++ b/iroh/src/discovery/local_swarm_discovery.rs @@ -54,7 +54,7 @@ use tracing::{debug, error, info_span, trace, warn, Instrument}; use crate::{ discovery::{Discovery, DiscoveryItem}, - watchable::Watchable, + watcher::{Watchable, Watcher as _}, Endpoint, }; diff --git a/iroh/src/discovery/pkarr.rs b/iroh/src/discovery/pkarr.rs index a319094b68..a493c177f3 100644 --- a/iroh/src/discovery/pkarr.rs +++ b/iroh/src/discovery/pkarr.rs @@ -61,7 +61,7 @@ use crate::{ discovery::{Discovery, DiscoveryItem}, dns::node_info::NodeInfo, endpoint::force_staging_infra, - watchable::{Disconnected, Watchable, Watcher}, + watcher::{self, Disconnected, Watchable, Watcher as _}, Endpoint, }; @@ -221,7 +221,7 @@ struct PublisherService { secret_key: SecretKey, #[debug("PkarrClient")] pkarr_client: PkarrRelayClient, - watcher: Watcher<Option<NodeInfo>>, + watcher: watcher::Direct<Option<NodeInfo>>, ttl: u32, republish_interval: Duration, } diff --git a/iroh/src/endpoint.rs b/iroh/src/endpoint.rs index 8ce5a97bd5..70dd838016 100644 --- a/iroh/src/endpoint.rs +++ b/iroh/src/endpoint.rs @@ -38,7 +38,7 @@ use crate::{ dns::{default_resolver, DnsResolver}, magicsock::{self, Handle, QuicMappedAddr}, tls, - watchable::Watcher, + watcher::{self, Watcher}, }; mod rtt_actor; @@ -74,6 +74,20 @@ const DISCOVERY_WAIT_PERIOD: Duration = Duration::from_millis(500); type DiscoveryBuilder = Box<dyn FnOnce(&SecretKey) -> Option<Box<dyn Discovery>> + Send + Sync>; +/// A type alias for the return value of [`Endpoint::node_addr`]. +/// +/// This type implements [`Watcher`] with `Value` being an optional [`NodeAddr`]. +/// +/// We return a named type instead of `impl Watcher<Value = NodeAddr>`, as this allows +/// you to e.g. store the watcher in a struct. +pub type NodeAddrWatcher = watcher::Map< + ( + watcher::Direct<Option<BTreeSet<DirectAddr>>>, + watcher::Direct<Option<RelayUrl>>, + ), + Option<NodeAddr>, +>; + /// Builder for [`Endpoint`]. /// /// By default the endpoint will generate a new random [`SecretKey`], which will result in a @@ -756,19 +770,47 @@ impl Endpoint { self.static_config.secret_key.public() } - /// Returns the current [`NodeAddr`] for this endpoint. + /// Returns a [`Watcher`] for the current [`NodeAddr`] for this endpoint. /// - /// The returned [`NodeAddr`] will have the current [`RelayUrl`] and direct addresses - /// as they would be returned by [`Endpoint::home_relay`] and - /// [`Endpoint::direct_addresses`]. - pub async fn node_addr(&self) -> Result<NodeAddr> { - let addrs = self.direct_addresses().initialized().await?; - let relay = self.home_relay().get()?; - Ok(NodeAddr::from_parts( - self.node_id(), - relay, - addrs.into_iter().map(|x| x.addr), - )) + /// The observed [`NodeAddr`] will have the current [`RelayUrl`] and direct addresses + /// as they would be returned by [`Endpoint::home_relay`] and [`Endpoint::direct_addresses`]. + /// + /// Use [`Watcher::initialized`] to wait for a [`NodeAddr`] that is ready to be connected to: + /// + /// ```no_run + /// # async fn wrapper() -> testresult::TestResult { + /// use iroh::{watcher::Watcher, Endpoint}; + /// + /// let endpoint = Endpoint::builder() + /// .alpns(vec![b"my-alpn".to_vec()]) + /// .bind() + /// .await?; + /// let node_addr = endpoint.node_addr().initialized().await?; + /// # let _ = node_addr; + /// # Ok(()) + /// # } + /// ``` + pub fn node_addr(&self) -> NodeAddrWatcher { + let watch_addrs = self.direct_addresses(); + let watch_relay = self.home_relay(); + let node_id = self.node_id(); + + watch_addrs + .or(watch_relay) + .map(move |(addrs, relay)| match (addrs, relay) { + (Some(addrs), relay) => Some(NodeAddr::from_parts( + node_id, + relay, + addrs.into_iter().map(|x| x.addr), + )), + (None, Some(relay)) => Some(NodeAddr::from_parts( + node_id, + Some(relay), + std::iter::empty(), + )), + (None, None) => None, + }) + .expect("watchable is alive - cannot be disconnected yet") } /// Returns a [`Watcher`] for the [`RelayUrl`] of the Relay server used as home relay. @@ -788,7 +830,7 @@ impl Endpoint { /// To wait for a home relay connection to be established, use [`Watcher::initialized`]: /// ```no_run /// use futures_lite::StreamExt; - /// use iroh::Endpoint; + /// use iroh::{watcher::Watcher, Endpoint}; /// /// # let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); /// # rt.block_on(async move { @@ -796,7 +838,7 @@ impl Endpoint { /// let _relay_url = mep.home_relay().initialized().await.unwrap(); /// # }); /// ``` - pub fn home_relay(&self) -> Watcher<Option<RelayUrl>> { + pub fn home_relay(&self) -> watcher::Direct<Option<RelayUrl>> { self.msock.home_relay() } @@ -824,7 +866,7 @@ impl Endpoint { /// To get the first set of direct addresses use [`Watcher::initialized`]: /// ```no_run /// use futures_lite::StreamExt; - /// use iroh::Endpoint; + /// use iroh::{watcher::Watcher, Endpoint}; /// /// # let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); /// # rt.block_on(async move { @@ -834,7 +876,7 @@ impl Endpoint { /// ``` /// /// [STUN]: https://en.wikipedia.org/wiki/STUN - pub fn direct_addresses(&self) -> Watcher<Option<BTreeSet<DirectAddr>>> { + pub fn direct_addresses(&self) -> watcher::Direct<Option<BTreeSet<DirectAddr>>> { self.msock.direct_addresses() } @@ -906,7 +948,7 @@ impl Endpoint { /// # Errors /// /// Will error if we do not have any address information for the given `node_id`. - pub fn conn_type(&self, node_id: NodeId) -> Result<Watcher<ConnectionType>> { + pub fn conn_type(&self, node_id: NodeId) -> Result<watcher::Direct<ConnectionType>> { self.msock.conn_type(node_id) } @@ -1419,6 +1461,8 @@ mod tests { use futures_lite::StreamExt; use iroh_test::CallOnDrop; use rand::SeedableRng; + use testresult::TestResult; + use tokio_util::task::AbortOnDropHandle; use tracing::{error_span, info, info_span, Instrument}; use super::*; @@ -1434,7 +1478,7 @@ mod tests { .bind() .await .unwrap(); - let my_addr = ep.node_addr().await.unwrap(); + let my_addr = ep.node_addr().initialized().await.unwrap(); let res = ep.connect(my_addr.clone(), TEST_ALPN).await; assert!(res.is_err()); let err = res.err().unwrap(); @@ -1716,8 +1760,8 @@ mod tests { .bind() .await .unwrap(); - let ep1_nodeaddr = ep1.node_addr().await.unwrap(); - let ep2_nodeaddr = ep2.node_addr().await.unwrap(); + let ep1_nodeaddr = ep1.node_addr().initialized().await.unwrap(); + let ep2_nodeaddr = ep2.node_addr().initialized().await.unwrap(); ep1.add_node_addr(ep2_nodeaddr.clone()).unwrap(); ep2.add_node_addr(ep1_nodeaddr.clone()).unwrap(); let ep1_nodeid = ep1.node_id(); @@ -1792,10 +1836,10 @@ mod tests { } #[tokio::test] - async fn endpoint_conn_type_stream() { + async fn endpoint_conn_type_becomes_direct() -> TestResult { const TIMEOUT: Duration = std::time::Duration::from_secs(15); let _logging_guard = iroh_test::logging::setup(); - let (relay_map, _relay_url, _relay_guard) = run_relay_server().await.unwrap(); + let (relay_map, _relay_url, _relay_guard) = run_relay_server().await?; let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42); let ep1_secret_key = SecretKey::generate(&mut rng); let ep2_secret_key = SecretKey::generate(&mut rng); @@ -1805,18 +1849,16 @@ mod tests { .alpns(vec![TEST_ALPN.to_vec()]) .relay_mode(RelayMode::Custom(relay_map.clone())) .bind() - .await - .unwrap(); + .await?; let ep2 = Endpoint::builder() .secret_key(ep2_secret_key) .insecure_skip_relay_cert_verify(true) .alpns(vec![TEST_ALPN.to_vec()]) .relay_mode(RelayMode::Custom(relay_map)) .bind() - .await - .unwrap(); + .await?; - async fn handle_direct_conn(ep: &Endpoint, node_id: PublicKey) -> Result<()> { + async fn wait_for_conn_type_direct(ep: &Endpoint, node_id: PublicKey) -> TestResult { let mut stream = ep.conn_type(node_id)?.stream(); let src = ep.node_id().fmt_short(); let dst = node_id.fmt_short(); @@ -1826,53 +1868,56 @@ mod tests { return Ok(()); } } - anyhow::bail!("conn_type stream ended before `ConnectionType::Direct`"); + panic!("conn_type stream ended before `ConnectionType::Direct`"); } - async fn accept(ep: &Endpoint) -> NodeId { - let incoming = ep.accept().await.unwrap(); - let conn = incoming.await.unwrap(); - let node_id = get_remote_node_id(&conn).unwrap(); + async fn accept(ep: &Endpoint) -> TestResult<Connection> { + let incoming = ep.accept().await.expect("ep closed"); + let conn = incoming.await?; + let node_id = get_remote_node_id(&conn)?; tracing::info!(node_id=%node_id.fmt_short(), "accepted connection"); - node_id + Ok(conn) } let ep1_nodeid = ep1.node_id(); let ep2_nodeid = ep2.node_id(); - let ep1_nodeaddr = ep1.node_addr().await.unwrap(); + let ep1_nodeaddr = ep1.node_addr().initialized().await?; tracing::info!( "node id 1 {ep1_nodeid}, relay URL {:?}", ep1_nodeaddr.relay_url() ); tracing::info!("node id 2 {ep2_nodeid}"); - let ep1_side = async move { - accept(&ep1).await; - handle_direct_conn(&ep1, ep2_nodeid).await - }; - - let ep2_side = async move { - ep2.connect(ep1_nodeaddr, TEST_ALPN).await.unwrap(); - handle_direct_conn(&ep2, ep1_nodeid).await - }; - - let res_ep1 = tokio::spawn(tokio::time::timeout(TIMEOUT, ep1_side)); - - let ep1_abort_handle = res_ep1.abort_handle(); - let _ep1_guard = CallOnDrop::new(move || { - ep1_abort_handle.abort(); + let ep1_side = tokio::time::timeout(TIMEOUT, async move { + let conn = accept(&ep1).await?; + let mut send = conn.open_uni().await?; + wait_for_conn_type_direct(&ep1, ep2_nodeid).await?; + send.write_all(b"Conn is direct").await?; + send.finish()?; + conn.closed().await; + TestResult::Ok(()) }); - let res_ep2 = tokio::spawn(tokio::time::timeout(TIMEOUT, ep2_side)); - let ep2_abort_handle = res_ep2.abort_handle(); - let _ep2_guard = CallOnDrop::new(move || { - ep2_abort_handle.abort(); + let ep2_side = tokio::time::timeout(TIMEOUT, async move { + let conn = ep2.connect(ep1_nodeaddr, TEST_ALPN).await?; + let mut recv = conn.accept_uni().await?; + wait_for_conn_type_direct(&ep2, ep1_nodeid).await?; + let read = recv.read_to_end(100).await?; + assert_eq!(read, b"Conn is direct".to_vec()); + conn.close(0u32.into(), b"done"); + conn.closed().await; + TestResult::Ok(()) }); - let (r1, r2) = tokio::try_join!(res_ep1, res_ep2).unwrap(); - r1.expect("ep1 timeout").unwrap(); - r2.expect("ep2 timeout").unwrap(); + let res_ep1 = AbortOnDropHandle::new(tokio::spawn(ep1_side)); + let res_ep2 = AbortOnDropHandle::new(tokio::spawn(ep2_side)); + + let (r1, r2) = tokio::try_join!(res_ep1, res_ep2)?; + r1??; + r2??; + + Ok(()) } #[tokio::test] diff --git a/iroh/src/endpoint/rtt_actor.rs b/iroh/src/endpoint/rtt_actor.rs index c1a9748ef5..951760f120 100644 --- a/iroh/src/endpoint/rtt_actor.rs +++ b/iroh/src/endpoint/rtt_actor.rs @@ -13,7 +13,7 @@ use tokio::{ use tokio_util::task::AbortOnDropHandle; use tracing::{debug, error, info_span, trace, Instrument}; -use crate::{magicsock::ConnectionType, metrics::MagicsockMetrics, watchable::WatcherStream}; +use crate::{magicsock::ConnectionType, metrics::MagicsockMetrics, watcher}; #[derive(Debug)] pub(super) struct RttHandle { @@ -51,7 +51,7 @@ pub(super) enum RttMessage { /// The connection. connection: quinn::WeakConnectionHandle, /// Path changes for this connection from the magic socket. - conn_type_changes: WatcherStream<ConnectionType>, + conn_type_changes: watcher::Stream<watcher::Direct<ConnectionType>>, /// For reporting-only, the Node ID of this connection. node_id: NodeId, }, @@ -64,7 +64,7 @@ pub(super) enum RttMessage { #[derive(Debug)] struct RttActor { /// Stream of connection type changes. - connection_events: stream_group::Keyed<WatcherStream<ConnectionType>>, + connection_events: stream_group::Keyed<watcher::Stream<watcher::Direct<ConnectionType>>>, /// References to the connections. /// /// These are weak references so not to keep the connections alive. The key allows @@ -121,7 +121,7 @@ impl RttActor { fn handle_new_connection( &mut self, connection: quinn::WeakConnectionHandle, - conn_type_changes: WatcherStream<ConnectionType>, + conn_type_changes: watcher::Stream<watcher::Direct<ConnectionType>>, node_id: NodeId, ) { let key = self.connection_events.insert(conn_type_changes); diff --git a/iroh/src/lib.rs b/iroh/src/lib.rs index 988b7697a4..659925d97c 100644 --- a/iroh/src/lib.rs +++ b/iroh/src/lib.rs @@ -246,7 +246,7 @@ pub mod endpoint; pub mod metrics; pub mod protocol; mod tls; -pub mod watchable; +pub mod watcher; pub use endpoint::{Endpoint, RelayMode}; pub use iroh_base::{ diff --git a/iroh/src/magicsock.rs b/iroh/src/magicsock.rs index 49336e0d20..41b9c9e89e 100644 --- a/iroh/src/magicsock.rs +++ b/iroh/src/magicsock.rs @@ -66,7 +66,7 @@ use crate::{ discovery::{Discovery, DiscoveryItem}, dns::DnsResolver, key::{public_ed_box, secret_ed_box, DecryptionError, SharedSecret}, - watchable::{Watchable, Watcher}, + watcher::{self, Watchable}, }; mod metrics; @@ -323,7 +323,10 @@ impl MagicSock { /// store [`Some`] set of addresses. /// /// To get the current direct addresses, use [`Watcher::initialized`]. - pub(crate) fn direct_addresses(&self) -> Watcher<Option<BTreeSet<DirectAddr>>> { + /// + /// [`Watcher`]: crate::watcher::Watcher + /// [`Watcher::initialized`]: crate::watcher::Watcher::initialized + pub(crate) fn direct_addresses(&self) -> watcher::Direct<Option<BTreeSet<DirectAddr>>> { self.direct_addrs.addrs.watch() } @@ -331,7 +334,10 @@ impl MagicSock { /// /// Note that this can be used to wait for the initial home relay to be known using /// [`Watcher::initialized`]. - pub(crate) fn home_relay(&self) -> Watcher<Option<RelayUrl>> { + /// + /// [`Watcher`]: crate::watcher::Watcher + /// [`Watcher::initialized`]: crate::watcher::Watcher::initialized + pub(crate) fn home_relay(&self) -> watcher::Direct<Option<RelayUrl>> { self.my_relay.watch() } @@ -345,7 +351,9 @@ impl MagicSock { /// /// Will return an error if there is no address information known about the /// given `node_id`. - pub(crate) fn conn_type(&self, node_id: NodeId) -> Result<Watcher<ConnectionType>> { + /// + /// [`Watcher`]: crate::watcher::Watcher + pub(crate) fn conn_type(&self, node_id: NodeId) -> Result<watcher::Direct<ConnectionType>> { self.node_map.conn_type(node_id) } @@ -2853,7 +2861,9 @@ mod tests { use super::*; use crate::{ defaults::staging::{self, EU_RELAY_HOSTNAME}, - tls, Endpoint, RelayMode, + tls, + watcher::Watcher as _, + Endpoint, RelayMode, }; const ALPN: &[u8] = b"n0/test/1"; @@ -3198,7 +3208,7 @@ mod tests { println!("first conn!"); let conn = m1 .endpoint - .connect(m2.endpoint.node_addr().await?, ALPN) + .connect(m2.endpoint.node_addr().initialized().await?, ALPN) .await?; println!("Closing first conn"); conn.close(0u32.into(), b"bye lolz"); diff --git a/iroh/src/magicsock/node_map.rs b/iroh/src/magicsock/node_map.rs index d25fbf84fe..8ceaf1d706 100644 --- a/iroh/src/magicsock/node_map.rs +++ b/iroh/src/magicsock/node_map.rs @@ -21,7 +21,7 @@ use super::{ }; use crate::{ disco::{CallMeMaybe, Pong, SendAddr}, - watchable::Watcher, + watcher, }; mod best_addr; @@ -291,7 +291,12 @@ impl NodeMap { /// /// Will return an error if there is not an entry in the [`NodeMap`] for /// the `node_id` - pub(super) fn conn_type(&self, node_id: NodeId) -> anyhow::Result<Watcher<ConnectionType>> { + /// + /// [`Watcher`]: crate::watcher::Watcher + pub(super) fn conn_type( + &self, + node_id: NodeId, + ) -> anyhow::Result<watcher::Direct<ConnectionType>> { self.inner.lock().expect("poisoned").conn_type(node_id) } @@ -459,7 +464,7 @@ impl NodeMapInner { /// /// Will return an error if there is not an entry in the [`NodeMap`] for /// the `public_key` - fn conn_type(&self, node_id: NodeId) -> anyhow::Result<Watcher<ConnectionType>> { + fn conn_type(&self, node_id: NodeId) -> anyhow::Result<watcher::Direct<ConnectionType>> { match self.get(NodeStateKey::NodeId(node_id)) { Some(ep) => Ok(ep.conn_type()), None => anyhow::bail!("No endpoint for {node_id:?} found"), diff --git a/iroh/src/magicsock/node_map/node_state.rs b/iroh/src/magicsock/node_map/node_state.rs index d62d4c294e..1f457812ed 100644 --- a/iroh/src/magicsock/node_map/node_state.rs +++ b/iroh/src/magicsock/node_map/node_state.rs @@ -24,7 +24,7 @@ use crate::{ disco::{self, SendAddr}, magicsock::{ActorMessage, MagicsockMetrics, QuicMappedAddr, Timer, HEARTBEAT_INTERVAL}, util::relay_only_mode, - watchable::{Watchable, Watcher}, + watcher::{self, Watchable}, }; /// Number of addresses that are not active that we keep around per node. @@ -191,7 +191,7 @@ impl NodeState { self.id } - pub(super) fn conn_type(&self) -> Watcher<ConnectionType> { + pub(super) fn conn_type(&self) -> watcher::Direct<ConnectionType> { self.conn_type.watch() } diff --git a/iroh/src/watchable.rs b/iroh/src/watcher.rs similarity index 66% rename from iroh/src/watchable.rs rename to iroh/src/watcher.rs index 9ed6cd9056..f0297152bf 100644 --- a/iroh/src/watchable.rs +++ b/iroh/src/watcher.rs @@ -7,6 +7,11 @@ //! In that way, a [`Watchable`] is like a [`tokio::sync::broadcast::Sender`] (and a //! [`Watcher`] is like a [`tokio::sync::broadcast::Receiver`]), except that there's no risk //! of the channel filling up, but instead you might miss items. +//! +//! This module is meant to be imported like this (if you use all of these things): +//! ```ignore +//! use iroh::watcher::{self, Watchable, Watcher as _}; +//! ``` #[cfg(not(iroh_loom))] use std::sync; @@ -18,14 +23,13 @@ use std::{ task::{self, Poll, Waker}, }; -use futures_lite::stream::Stream; #[cfg(iroh_loom)] use loom::sync; use sync::{Mutex, RwLock}; /// A wrapper around a value that notifies [`Watcher`]s when the value is modified. /// -/// Only the most recent value is available to any observer, but but observer is guaranteed +/// Only the most recent value is available to any observer, but the observer is guaranteed /// to be notified of the most recent value. #[derive(Debug, Default)] pub struct Watchable<T> { @@ -87,9 +91,9 @@ impl<T: Clone + Eq> Watchable<T> { ret } - /// Creates a [`Watcher`] allowing the value to be observed, but not modified. - pub fn watch(&self) -> Watcher<T> { - Watcher { + /// Creates a [`Direct`] [`Watcher`], allowing the value to be observed, but not modified. + pub fn watch(&self) -> Direct<T> { + Direct { epoch: self.shared.state.read().expect("poisoned").epoch, shared: Arc::downgrade(&self.shared), } @@ -101,29 +105,45 @@ impl<T: Clone + Eq> Watchable<T> { } } -/// An observer for a value. +/// A handle to a value that's represented by one or more underlying [`Watchable`]s. /// -/// The [`Watcher`] can get the current value, and will be notified when the value changes. -/// Only the most recent value is accessible, and if the thread with the [`Watchable`] -/// changes the value faster than the thread with the [`Watcher`] can keep up with, then +/// A [`Watcher`] can get the current value, and will be notified when the value changes. +/// Only the most recent value is accessible, and if the threads with the underlying [`Watchable`]s +/// change the value faster than the threads with the [`Watcher`] can keep up with, then /// it'll miss in-between values. /// When the thread changing the [`Watchable`] pauses updating, the [`Watcher`] will always /// end up reporting the most recent state eventually. -#[derive(Debug, Clone)] -pub struct Watcher<T> { - epoch: u64, - shared: Weak<Shared<T>>, -} - -impl<T: Clone + Eq> Watcher<T> { - /// Returns the currently held value. +/// +/// Watchers can be modified via [`Watcher::map`] to observe a value derived from the original +/// value via a function. +/// +/// Watchers can be combined via [`Watcher::or`] to allow observing multiple values at once and +/// getting an update in case any of the values updates. +/// +/// One of the underlying [`Watchable`]s might already be dropped. In that case, +/// the watcher will be "disconnected" and return [`Err(Disconnected)`](Disconnected) +/// on some function calls or, when turned into a stream, that stream will end. +pub trait Watcher: Clone { + /// The type of value that can change. /// - /// Returns [`Err(Disconnected)`](Disconnected) if the original - /// [`Watchable`] was dropped. - pub fn get(&self) -> Result<T, Disconnected> { - let shared = self.shared.upgrade().ok_or(Disconnected)?; - Ok(shared.get()) - } + /// We require `Clone`, because we need to be able to make + /// the values have a lifetime that's detached from the original [`Watchable`]'s + /// lifetime. + /// + /// We require `Eq`, to be able to check whether the value actually changed or + /// not, so we can notify or not notify accordingly. + type Value: Clone + Eq; + + /// Returns the current state of the underlying value, or errors out with + /// [`Disconnected`], if one of the underlying [`Watchable`]s has been dropped. + fn get(&self) -> Result<Self::Value, Disconnected>; + + /// Polls for the next value, or returns [`Disconnected`] if one of the underlying + /// [`Watchable`]s has been dropped. + fn poll_updated( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Self::Value, Disconnected>>; /// Returns a future completing with `Ok(value)` once a new value is set, or with /// [`Err(Disconnected)`](Disconnected) if the connected [`Watchable`] was dropped. @@ -131,8 +151,32 @@ impl<T: Clone + Eq> Watcher<T> { /// # Cancel Safety /// /// The returned future is cancel-safe. - pub fn updated(&mut self) -> WatchNextFut<T> { - WatchNextFut { watcher: self } + fn updated(&mut self) -> NextFut<Self> { + NextFut { watcher: self } + } + + /// Returns a future completing once the value is set to [`Some`] value. + /// + /// If the current value is [`Some`] value, this future will resolve immediately. + /// + /// This is a utility for the common case of storing an [`Option`] inside a + /// [`Watchable`]. + /// + /// # Cancel Safety + /// + /// The returned future is cancel-safe. + fn initialized<T>(&mut self) -> InitializedFut<T, Self> + where + Self: Watcher<Value = Option<T>>, + { + InitializedFut { + initial: match self.get() { + Ok(Some(value)) => Some(Ok(value)), + Ok(None) => None, + Err(Disconnected) => Some(Err(Disconnected)), + }, + watcher: self, + } } /// Returns a stream which will yield the most recent values as items. @@ -148,10 +192,14 @@ impl<T: Clone + Eq> Watcher<T> { /// # Cancel Safety /// /// The returned stream is cancel-safe. - pub fn stream(mut self) -> WatcherStream<T> { - debug_assert!(self.epoch > 0); - self.epoch -= 1; - WatcherStream { watcher: self } + fn stream(self) -> Stream<Self> + where + Self: Unpin, + { + Stream { + initial: self.get().ok(), + watcher: self, + } } /// Returns a stream which will yield the most recent values as items, starting from @@ -168,21 +216,126 @@ impl<T: Clone + Eq> Watcher<T> { /// # Cancel Safety /// /// The returned stream is cancel-safe. - pub fn stream_updates_only(self) -> WatcherStream<T> { - WatcherStream { watcher: self } + fn stream_updates_only(self) -> Stream<Self> + where + Self: Unpin, + { + Stream { + initial: None, + watcher: self, + } } -} -impl<T: Clone + Eq> Watcher<Option<T>> { - /// Returns a future completing once the value is set to [`Some`] value. + /// Maps this watcher with a function that transforms the observed values. /// - /// If the current value is [`Some`] value, this future will resolve immediately. - /// - /// This is a utility for the common case of storing an [`Option`] inside a - /// [`Watchable`]. - pub fn initialized(&mut self) -> WatchInitializedFut<T> { - self.epoch = PRE_INITIAL_EPOCH; - WatchInitializedFut { watcher: self } + /// The returned watcher will only register updates, when the *mapped* value + /// observably changes. For this, it needs to store a clone of `T` in the watcher. + fn map<T: Clone + Eq>( + self, + map: impl Fn(Self::Value) -> T + 'static, + ) -> Result<Map<Self, T>, Disconnected> { + Ok(Map { + current: (map)(self.get()?), + map: Arc::new(map), + watcher: self, + }) + } + + /// Returns a watcher that updates every time this or the other watcher + /// updates, and yields both watcher's items together when that happens. + fn or<W: Watcher>(self, other: W) -> (Self, W) { + (self, other) + } +} + +/// The immediate, direct observer of a [`Watchable`] value. +/// +/// This type is mainly used via the [`Watcher`] interface. +#[derive(Debug, Clone)] +pub struct Direct<T> { + epoch: u64, + shared: Weak<Shared<T>>, +} + +impl<T: Clone + Eq> Watcher for Direct<T> { + type Value = T; + + fn get(&self) -> Result<Self::Value, Disconnected> { + let shared = self.shared.upgrade().ok_or(Disconnected)?; + Ok(shared.get()) + } + + fn poll_updated( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Self::Value, Disconnected>> { + let Some(shared) = self.shared.upgrade() else { + return Poll::Ready(Err(Disconnected)); + }; + match shared.poll_updated(cx, self.epoch) { + Poll::Pending => Poll::Pending, + Poll::Ready((current_epoch, value)) => { + self.epoch = current_epoch; + Poll::Ready(Ok(value)) + } + } + } +} + +impl<S: Watcher, T: Watcher> Watcher for (S, T) { + type Value = (S::Value, T::Value); + + fn get(&self) -> Result<Self::Value, Disconnected> { + Ok((self.0.get()?, self.1.get()?)) + } + + fn poll_updated( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Self::Value, Disconnected>> { + let poll_0 = self.0.poll_updated(cx)?; + let poll_1 = self.1.poll_updated(cx)?; + match (poll_0, poll_1) { + (Poll::Ready(s), Poll::Ready(t)) => Poll::Ready(Ok((s, t))), + (Poll::Ready(s), Poll::Pending) => Poll::Ready(self.1.get().map(move |t| (s, t))), + (Poll::Pending, Poll::Ready(t)) => Poll::Ready(self.0.get().map(move |s| (s, t))), + (Poll::Pending, Poll::Pending) => Poll::Pending, + } + } +} + +/// Wraps a [`Watcher`] to allow observing a derived value. +/// +/// See [`Watcher::map`]. +#[derive(derive_more::Debug, Clone)] +pub struct Map<W: Watcher, T: Clone + Eq> { + #[debug("Arc<dyn Fn(W::Value) -> T + 'static>")] + map: Arc<dyn Fn(W::Value) -> T + 'static>, + watcher: W, + current: T, +} + +impl<W: Watcher, T: Clone + Eq> Watcher for Map<W, T> { + type Value = T; + + fn get(&self) -> Result<Self::Value, Disconnected> { + Ok((self.map)(self.watcher.get()?)) + } + + fn poll_updated( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Self::Value, Disconnected>> { + loop { + let value = futures_lite::ready!(self.watcher.poll_updated(cx)?); + let mapped = (self.map)(value); + if mapped != self.current { + self.current = mapped.clone(); + return Poll::Ready(Ok(mapped)); + } else { + self.current = mapped; + } + } } } @@ -194,24 +347,15 @@ impl<T: Clone + Eq> Watcher<Option<T>> { /// /// This future is cancel-safe. #[derive(Debug)] -pub struct WatchNextFut<'a, T> { - watcher: &'a mut Watcher<T>, +pub struct NextFut<'a, W: Watcher> { + watcher: &'a mut W, } -impl<T: Clone + Eq> Future for WatchNextFut<'_, T> { - type Output = Result<T, Disconnected>; +impl<W: Watcher> Future for NextFut<'_, W> { + type Output = Result<W::Value, Disconnected>; fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { - let Some(shared) = self.watcher.shared.upgrade() else { - return Poll::Ready(Err(Disconnected)); - }; - match shared.poll_next(cx, self.watcher.epoch) { - Poll::Pending => Poll::Pending, - Poll::Ready((current_epoch, value)) => { - self.watcher.epoch = current_epoch; - Poll::Ready(Ok(value)) - } - } + self.watcher.poll_updated(cx) } } @@ -224,22 +368,22 @@ impl<T: Clone + Eq> Future for WatchNextFut<'_, T> { /// /// This Future is cancel-safe. #[derive(Debug)] -pub struct WatchInitializedFut<'a, T> { - watcher: &'a mut Watcher<Option<T>>, +pub struct InitializedFut<'a, T, W: Watcher<Value = Option<T>>> { + initial: Option<Result<T, Disconnected>>, + watcher: &'a mut W, } -impl<T: Clone + Eq> Future for WatchInitializedFut<'_, T> { +impl<T: Clone + Eq + Unpin, W: Watcher<Value = Option<T>> + Unpin> Future + for InitializedFut<'_, T, W> +{ type Output = Result<T, Disconnected>; fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + if let Some(value) = self.as_mut().initial.take() { + return Poll::Ready(value); + } loop { - let Some(shared) = self.watcher.shared.upgrade() else { - return Poll::Ready(Err(Disconnected)); - }; - let (epoch, value) = futures_lite::ready!(shared.poll_next(cx, self.watcher.epoch)); - self.watcher.epoch = epoch; - - if let Some(value) = value { + if let Some(value) = futures_lite::ready!(self.as_mut().watcher.poll_updated(cx)?) { return Poll::Ready(Ok(value)); } } @@ -254,23 +398,25 @@ impl<T: Clone + Eq> Future for WatchInitializedFut<'_, T> { /// /// This stream is cancel-safe. #[derive(Debug, Clone)] -pub struct WatcherStream<T> { - watcher: Watcher<T>, +pub struct Stream<W: Watcher + Unpin> { + initial: Option<W::Value>, + watcher: W, } -impl<T: Clone + Eq> Stream for WatcherStream<T> { - type Item = T; +impl<W: Watcher + Unpin> futures_lite::stream::Stream for Stream<W> +where + W::Value: Unpin, +{ + type Item = W::Value; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> { - let Some(shared) = self.watcher.shared.upgrade() else { - return Poll::Ready(None); - }; - match shared.poll_next(cx, self.watcher.epoch) { + if let Some(value) = self.as_mut().initial.take() { + return Poll::Ready(Some(value)); + } + match self.as_mut().watcher.poll_updated(cx) { + Poll::Ready(Ok(value)) => Poll::Ready(Some(value)), + Poll::Ready(Err(Disconnected)) => Poll::Ready(None), Poll::Pending => Poll::Pending, - Poll::Ready((epoch, value)) => { - self.watcher.epoch = epoch; - Poll::Ready(Some(value)) - } } } } @@ -278,13 +424,12 @@ impl<T: Clone + Eq> Stream for WatcherStream<T> { /// The error for when a [`Watcher`] is disconnected from its underlying /// [`Watchable`] value, because of that watchable having been dropped. #[derive(thiserror::Error, Debug)] -#[error("Watch lost connection to underlying Watchable, it was dropped")] +#[error("Watcher lost connection to underlying Watchable, it was dropped")] pub struct Disconnected; // Private: const INITIAL_EPOCH: u64 = 1; -const PRE_INITIAL_EPOCH: u64 = 0; /// The shared state for a [`Watchable`]. #[derive(Debug, Default)] @@ -315,7 +460,7 @@ impl<T: Clone> Shared<T> { self.state.read().expect("poisoned").value.clone() } - fn poll_next(&self, cx: &mut task::Context<'_>, last_epoch: u64) -> Poll<(u64, T)> { + fn poll_updated(&self, cx: &mut task::Context<'_>, last_epoch: u64) -> Poll<(u64, T)> { { let state = self.state.read().expect("poisoned"); let epoch = state.epoch;