diff --git a/skunk-cli/Cargo.toml b/skunk-cli/Cargo.toml index cc184f9..3f5d600 100644 --- a/skunk-cli/Cargo.toml +++ b/skunk-cli/Cargo.toml @@ -49,6 +49,7 @@ semver-macro = "0.1.0" serde = { version = "1.0.201", features = ["derive"] } thiserror = "1.0.61" tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread", "signal"] } +tokio-util = "0.7.11" toml = "0.8.12" toml_edit = { version = "0.22.15", features = ["serde"] } tower-http = { version = "0.5.2", features = ["fs"] } diff --git a/skunk-cli/src/proxy.rs b/skunk-cli/src/proxy.rs index 293005e..38cf027 100644 --- a/skunk-cli/src/proxy.rs +++ b/skunk-cli/src/proxy.rs @@ -26,15 +26,13 @@ use skunk::{ Passthrough, Proxy, }, - util::{ - error::ResultExt, - CancellationToken, - }, }; +use skunk_util::error::ResultExt; use tokio::{ net::TcpStream, task::JoinSet, }; +use tokio_util::sync::CancellationToken; use tracing::Instrument; use crate::{ @@ -151,20 +149,24 @@ pub async fn run(environment: Environment, args: ProxyArgs) -> Result<(), Error> let shutdown = shutdown.clone(); let interface = interface.clone(); async move { - if args.pcap.enabled { + let _hostapd = if args.pcap.ap { let country_code = std::env::var("HOSTAPD_CC") .expect("Environment variable `HOSTAPD_CC` not set. You need to set this variable to your country code."); tracing::info!("Starting hostapd"); let mut hostapd = pcap::ap::Builder::new(&interface, &country_code) .with_channel(11) - .with_graceful_shutdown(shutdown.clone()) .start()?; tracing::info!("Waiting for hostapd to configure the interface..."); hostapd.ready().await?; tracing::info!("hostapd ready"); + + Some(hostapd) } + else { + None + }; let _network = VirtualNetwork::new(&interface)?; shutdown.cancelled().await; diff --git a/skunk-cli/src/util/shutdown.rs b/skunk-cli/src/util/shutdown.rs index fc322db..8ae1abd 100644 --- a/skunk-cli/src/util/shutdown.rs +++ b/skunk-cli/src/util/shutdown.rs @@ -1,4 +1,4 @@ -use skunk::util::CancellationToken; +use tokio_util::sync::CancellationToken; /// Resolves when the application receives SIGTERM on unix systems, or never on /// other systems. diff --git a/skunk-util/Cargo.toml b/skunk-util/Cargo.toml index efe01a8..68848fe 100644 --- a/skunk-util/Cargo.toml +++ b/skunk-util/Cargo.toml @@ -5,8 +5,12 @@ edition = "2021" [features] default = [] -full = ["trigger"] trigger = ["dep:tokio"] +ordered-multimap = ["dep:ahash", "dep:hashbrown"] +error = [] [dependencies] +ahash = { version = "0.8.11", optional = true } +hashbrown = { version = "0.14.5", optional = true } tokio = { version = "1.37.0", default-features = false, features = ["sync"], optional = true } +tracing = "0.1.40" diff --git a/skunk/src/util/error.rs b/skunk-util/src/error.rs similarity index 100% rename from skunk/src/util/error.rs rename to skunk-util/src/error.rs diff --git a/skunk-util/src/lib.rs b/skunk-util/src/lib.rs index 49be85f..e1ca23a 100644 --- a/skunk-util/src/lib.rs +++ b/skunk-util/src/lib.rs @@ -1,2 +1,8 @@ #[cfg(feature = "trigger")] pub mod trigger; + +#[cfg(feature = "error")] +pub mod error; + +#[cfg(feature = "ordered-multimap")] +pub mod ordered_multimap; diff --git a/skunk/src/util/ordered_multimap.rs b/skunk-util/src/ordered_multimap.rs similarity index 98% rename from skunk/src/util/ordered_multimap.rs rename to skunk-util/src/ordered_multimap.rs index 9f40321..4dcd94f 100644 --- a/skunk/src/util/ordered_multimap.rs +++ b/skunk-util/src/ordered_multimap.rs @@ -503,14 +503,12 @@ impl<'a, K, V, H> OccupiedEntryMut<'a, K, V, H> { pub fn append(&mut self, key: K, value: V) -> &mut V { let tail = self.list().tail; - let bucket = self.map.buckets.get_index(&self.bucket); let index = self.map.pairs.push(Pair { key, value, next: None, prev: Some(tail), - bucket, }); self.map.pairs.get_mut(tail).next = Some(index); @@ -582,14 +580,12 @@ impl<'a, K, V, H> VacantEntryMut<'a, K, V, H> { count: 1, }, ); - let bucket_index = self.map.buckets.get_index(&bucket); let index2 = self.map.pairs.push(Pair { key, value, next: None, prev: None, - bucket: bucket_index, }); assert_eq!(index, index2); @@ -878,16 +874,12 @@ impl ExactSizeIterator for IntoIter {} #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] struct PairIndex(usize); -#[derive(Clone, Copy)] -struct BucketIndex(usize); - #[derive(Clone)] struct Pair { key: K, value: V, next: Option, prev: Option, - bucket: BucketIndex, } #[derive(Clone, Copy)] @@ -978,10 +970,6 @@ impl Buckets { } impl Buckets { - pub fn get_index(&self, bucket: &Bucket) -> BucketIndex { - BucketIndex(unsafe { self.inner.bucket_index(bucket) }) - } - pub fn remove(&mut self, bucket: Bucket) -> (List, InsertSlot) { unsafe { self.inner.remove(bucket) } } diff --git a/skunk/Cargo.toml b/skunk/Cargo.toml index ae5b59f..90f278c 100644 --- a/skunk/Cargo.toml +++ b/skunk/Cargo.toml @@ -49,8 +49,11 @@ path = "../byst" [dependencies.skunk-macros] path = "../skunk-macros" +[dependencies.skunk-util] +path = "../skunk-util" +features = ["error", "ordered-multimap"] + [dependencies] -ahash = "0.8.11" bitflags = "2.5.0" bytes = "1.6.0" crc = "3.2.1" @@ -83,6 +86,6 @@ tempfile = "3.10.1" thiserror = "1.0.60" tokio = { version = "1.37.0", features = ["macros", "net", "io-util", "process"] } tokio-rustls = { version = "0.26.0", optional = true } -tokio-util = "0.7.11" +#tokio-util = "0.7.11" tracing = "0.1.40" url = { version = "2.5.0", features = ["serde"] } diff --git a/skunk/src/protocol/inet/dhcp.rs b/skunk/src/protocol/inet/dhcp.rs index ebfde72..6fd2bf3 100644 --- a/skunk/src/protocol/inet/dhcp.rs +++ b/skunk/src/protocol/inet/dhcp.rs @@ -15,11 +15,9 @@ use byst::{ }, util::for_tuple, }; +use skunk_util::ordered_multimap::OrderedMultiMap; -use crate::util::{ - network_enum, - ordered_multimap::OrderedMultiMap, -}; +use crate::util::network_enum; /// A [DHCP message][1] /// diff --git a/skunk/src/proxy/pcap/ap.rs b/skunk/src/proxy/pcap/ap.rs index a55799e..bde1284 100644 --- a/skunk/src/proxy/pcap/ap.rs +++ b/skunk/src/proxy/pcap/ap.rs @@ -8,14 +8,15 @@ use tempfile::NamedTempFile; use tokio::{ io::{ AsyncBufReadExt, - AsyncRead, BufReader, }, process::Command, - sync::watch, + sync::{ + oneshot, + watch, + }, task::JoinHandle, }; -use tokio_util::sync::CancellationToken; use tracing::Instrument; use super::interface::Interface; @@ -65,7 +66,6 @@ pub struct Builder<'a> { hw_mode: HwMode, channel: Option, password: Option<&'a str>, - shutdown: CancellationToken, ready: Option>, } @@ -80,7 +80,6 @@ impl<'a> Builder<'a> { hw_mode: Default::default(), channel: None, password: None, - shutdown: Default::default(), ready: None, } } @@ -115,11 +114,6 @@ impl<'a> Builder<'a> { self } - pub fn with_graceful_shutdown(mut self, shutdown: CancellationToken) -> Self { - self.shutdown = shutdown; - self - } - pub fn write_config(&self, mut writer: impl Write) -> Result<(), Error> { writeln!(writer, "interface={}", self.interface.name())?; writeln!(writer, "driver={}", <&'static str>::from(self.driver))?; @@ -156,33 +150,49 @@ impl<'a> Builder<'a> { tracing::debug!(parent: &span, "spawning hostapd"); let (ready_tx, ready_rx) = watch::channel(false); - let shutdown = CancellationToken::new(); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); - let join_handle = tokio::spawn({ - let shutdown = shutdown.clone(); + let join_handle = tokio::spawn( async move { // move temp config file here, so it is only deleted once the process // terminates. let _cfg_file = cfg_file; - tokio::select! { - result = handle_stdout(&mut process.stdout, ready_tx) => { - result?; - }, - _ = self.shutdown.cancelled() => {}, - _ = shutdown.cancelled() => {}, - }; + let reader = BufReader::new(process.stdout.as_mut().expect("no stdout")); + let mut lines = reader.lines(); + + loop { + tokio::select! { + result = lines.next_line() => { + if let Some(line) = result? { + if line.contains("AP-ENABLED") { + let _ = ready_tx.send(true); + } + tracing::debug!("{}", line); + } + else { + // EOF on stdout + break; + } + }, + _ = &mut shutdown_rx => { + // either the user sent a shutdown Signal through [`HostApd::stop`], or the sender was dropped. + // either case, we're done. + break; + }, + }; + } tracing::debug!("killing hostapd"); process.kill().await?; Ok::<(), Error>(()) } - .instrument(span) - }); + .instrument(span), + ); Ok(HostApd { join_handle, - shutdown, + shutdown_tx, ready_rx, }) } @@ -190,7 +200,7 @@ impl<'a> Builder<'a> { pub struct HostApd { join_handle: JoinHandle>, - shutdown: CancellationToken, + shutdown_tx: oneshot::Sender<()>, ready_rx: watch::Receiver, } @@ -206,31 +216,9 @@ impl HostApd { }) } - pub async fn wait(self) -> Result<(), Error> { + pub async fn stop(self) -> Result<(), Error> { + let _ = self.shutdown_tx.send(()); self.join_handle.await.ok().transpose()?; Ok(()) } - - pub async fn stop(self) -> Result<(), Error> { - self.shutdown.cancel(); - self.wait().await - } -} - -async fn handle_stdout( - stream_opt: &mut Option, - ready_tx: watch::Sender, -) -> Result<(), Error> { - if let Some(stream) = stream_opt { - let stream = BufReader::new(stream); - let mut lines = stream.lines(); - while let Some(line) = lines.next_line().await? { - let line = line.trim_end(); - if line.ends_with("AP-ENABLED") { - let _ = ready_tx.send(true); - } - tracing::debug!("{}", line); - } - } - Ok(()) } diff --git a/skunk/src/proxy/pcap/arp.rs b/skunk/src/proxy/pcap/arp.rs index edfc59a..7d79e6f 100644 --- a/skunk/src/proxy/pcap/arp.rs +++ b/skunk/src/proxy/pcap/arp.rs @@ -12,6 +12,7 @@ use std::{ }; use futures::TryFutureExt; +use skunk_util::error::ResultExt; use smallvec::SmallVec; use tokio::{ sync::{ @@ -29,12 +30,9 @@ use super::{ SendError, }; pub use crate::protocol::inet::arp::Packet; -use crate::{ - protocol::inet::{ - arp::Operation, - MacAddress, - }, - util::error::ResultExt, +use crate::protocol::inet::{ + arp::Operation, + MacAddress, }; #[derive(Debug)] diff --git a/skunk/src/proxy/pcap/mod.rs b/skunk/src/proxy/pcap/mod.rs index 6b4dd27..53407ab 100644 --- a/skunk/src/proxy/pcap/mod.rs +++ b/skunk/src/proxy/pcap/mod.rs @@ -29,6 +29,7 @@ use skunk_macros::{ ipv4_address, ipv4_network, }; +use skunk_util::error::ResultExt; use tokio::sync::mpsc; use tracing::Instrument; @@ -39,13 +40,10 @@ use self::{ ReceiveError, }, }; -use crate::{ - protocol::inet::{ - ethernet, - ipv4, - MacAddress, - }, - util::error::ResultExt, +use crate::protocol::inet::{ + ethernet, + ipv4, + MacAddress, }; #[derive(Debug)] diff --git a/skunk/src/util/mod.rs b/skunk/src/util/mod.rs index 6b0a758..e331bd8 100644 --- a/skunk/src/util/mod.rs +++ b/skunk/src/util/mod.rs @@ -2,9 +2,7 @@ pub(crate) mod boolean; pub mod crc; -pub mod error; pub mod io; -pub mod ordered_multimap; use std::{ fmt::{ @@ -21,9 +19,7 @@ use std::{ }, }; -pub use byst::util::for_tuple; use parking_lot::Mutex; -pub use tokio_util::sync::CancellationToken; /// [`Oncelock`](std::sync::OnceLock::get_or_try_init) is not stabilized yet, so /// we implement it ourselves. Also we inclose the `Arc`, because why not.