Skip to content

Commit

Permalink
Merge pull request #13 from kolapapa/fix_repeat_bug
Browse files Browse the repository at this point in the history
Support for duplicate address ping
  • Loading branch information
kolapapa authored Feb 21, 2022
2 parents beb5d93 + 7278fef commit 1f08361
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 79 deletions.
18 changes: 9 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "surge-ping"
version = "0.4.1"
version = "0.4.2"
authors = ["kolapapa <[email protected]>"]
edition = "2018"
license = "MIT"
Expand All @@ -11,20 +11,20 @@ keywords = ["tokio", "icmp", "ping"]
categories = ["network-programming", "asynchronous"]

[dependencies]
log = "0.4.14"
parking_lot = "0.11.2"
parking_lot = "0.12.0"
pnet_packet = "0.29.0"
rand = "0.8.4"
socket2 = { version = "0.4.3", features = ["all"] }
rand = "0.8.5"
socket2 = { version = "0.4.4", features = ["all"] }
thiserror = "1.0.30"
tokio = { version = "1.15.0", features = ["time", "macros"] }
tokio = { version = "1.17.0", features = ["time", "macros"] }
tracing = "0.1.31"
uuid = { version = "0.8.2", features = ["v4"] }

[dev-dependencies]
log = "0.4.14"
structopt = "0.3.26"
pretty_env_logger = "0.4.0"
tokio = { version = "1.15.0", features = ["full"] }
futures = "0.3.19"
tokio = { version = "1.17.0", features = ["full"] }
futures = "0.3.21"

[[example]]
name = "simple"
Expand Down
20 changes: 14 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ rust ping libray based on `tokio` + `socket2` + `pnet_packet`.
```rust
use std::time::Duration;

use surge_ping::Pinger;
use surge_ping::IcmpPacket;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -24,11 +24,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// let mut pinger = client.pinger("114.114.114.114".parse()?);
pinger.timeout(Duration::from_secs(1));
for seq_cnt in 0..10 {
let (reply, dur) = pinger.ping(seq_cnt).await?;
println!(
"{} bytes from {}: icmp_seq={} ttl={:?} time={:?}",
reply.size, reply.source, reply.sequence, reply.ttl, dur
);
match pinger.ping(seq_cnt).await? {
(IcmpPacket::V4(packet), dur) => {
println!(
"{} bytes from {}: icmp_seq={} ttl={:?} time={:?}",
packet.get_size(),
packet.get_source(),
packet.get_sequence(),
packet.get_ttl(),
dur
);
}
(IcmpPacket::V6(_), dur) => unreachable!(),
}
}
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion examples/cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async fn main() {
}
let config = config_builder.build();

let client = Client::new(&config).unwrap();
let client = Client::new(&config).await.unwrap();
let mut pinger = client.pinger(ip).await;
pinger.timeout(Duration::from_secs(opt.timeout));

Expand Down
6 changes: 4 additions & 2 deletions examples/multi_ping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@ use tokio::time;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// test same url 114.114.114.114
let ips = [
"114.114.114.114",
"8.8.8.8",
"39.156.69.79",
"172.217.26.142",
"240c::6666",
"2a02:930::ff76",
"114.114.114.114",
];
let client_v4 = Client::new(&Config::default())?;
let client_v6 = Client::new(&Config::builder().kind(ICMP::V6).build())?;
let client_v4 = Client::new(&Config::default()).await?;
let client_v6 = Client::new(&Config::builder().kind(ICMP::V6).build()).await?;
let mut tasks = Vec::new();
for ip in &ips {
match ip.parse() {
Expand Down
2 changes: 1 addition & 1 deletion examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async fn main() {
}
let config = config_builder.build();

let client = Client::new(&config).unwrap();
let client = Client::new(&config).await.unwrap();
let mut pinger = client.pinger(ip).await;
pinger
.ident(111)
Expand Down
78 changes: 45 additions & 33 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ use std::{
time::Instant,
};

use log::trace;
use pnet_packet::{icmp, icmpv6, ipv4, ipv6, Packet};
use socket2::{Domain, Protocol, Socket, Type};
use tokio::{
net::UdpSocket,
sync::{broadcast, mpsc, Mutex},
sync::{mpsc, Mutex},
task,
};
use tracing::warn;
use uuid::Uuid;

use crate::{config::Config, Pinger, ICMP};

Expand Down Expand Up @@ -86,59 +88,69 @@ impl AsyncSocket {
#[derive(Clone)]
pub struct Client {
socket: AsyncSocket,
mapping: Arc<Mutex<HashMap<IpAddr, mpsc::Sender<Message>>>>,
mapping: Arc<Mutex<HashMap<Uuid, mpsc::Sender<Message>>>>,
}

impl Client {
/// A client is generated according to the configuration. In fact, a `AsyncSocket` is wrapped inside,
/// and you can clone to any `task` at will.
pub fn new(config: &Config) -> io::Result<Self> {
pub async fn new(config: &Config) -> io::Result<Self> {
let socket = AsyncSocket::new(config)?;
Ok(Self {
socket,
mapping: Arc::new(Mutex::new(HashMap::new())),
})
let mapping = Arc::new(Mutex::new(HashMap::new()));
task::spawn(recv_task(socket.clone(), mapping.clone()));
Ok(Self { socket, mapping })
}

/// Create a `Pinger` instance, you can make special configuration for this instance. Such as `timeout`, `size` etc.
pub async fn pinger(&self, host: IpAddr) -> Pinger {
let (shutdown_tx, _) = broadcast::channel(1);
let (tx, rx) = mpsc::channel(10);
let key = Uuid::new_v4();
{
self.mapping.lock().await.insert(host, tx);
self.mapping.lock().await.insert(key, tx);
}
task::spawn(recv_task(
self.socket.clone(),
self.mapping.clone(),
shutdown_tx.subscribe(),
));
Pinger::new(host, self.socket.clone(), rx, shutdown_tx)
Pinger::new(host, self.socket.clone(), rx, key)
}
}

async fn recv_task(
socket: AsyncSocket,
mapping: Arc<Mutex<HashMap<IpAddr, mpsc::Sender<Message>>>>,
mut shutdown_rx: broadcast::Receiver<()>,
) {
async fn recv_task(socket: AsyncSocket, mapping: Arc<Mutex<HashMap<Uuid, mpsc::Sender<Message>>>>) {
let mut buf = [0; 2048];
loop {
tokio::select! {
answer = socket.recv_from(&mut buf) => {
if let Ok((sz, addr)) = answer {
let instant = Instant::now();
let mut w = mapping.lock().await;
if let Some(tx) = (*w).get(&addr.ip()) {
if tx.send(Message::new(instant, buf[0..sz].to_vec())).await.is_err() {
trace!("send message error");
(*w).remove(&addr.ip());
}
if let Ok((sz, addr)) = socket.recv_from(&mut buf).await {
let datas = buf[0..sz].to_vec();
if let Some(uuid) = gen_uuid_with_payload(addr.ip(), datas.as_slice()) {
let instant = Instant::now();
let mut w = mapping.lock().await;
if let Some(tx) = (*w).get(&uuid) {
if tx.send(Message::new(instant, datas)).await.is_err() {
warn!("Pinger({}) already closed.", addr);
(*w).remove(&uuid);
}
}
}
_ = shutdown_rx.recv() => {
break
}
}
}

fn gen_uuid_with_payload(addr: IpAddr, datas: &[u8]) -> Option<Uuid> {
match addr {
IpAddr::V4(_) => {
if let Some(ip_packet) = ipv4::Ipv4Packet::new(datas) {
if let Some(icmp_packet) = icmp::IcmpPacket::new(ip_packet.payload()) {
let payload = icmp_packet.payload();
let uuid = &payload[4..20];
return Uuid::from_slice(uuid).ok();
}
}
}
IpAddr::V6(_) => {
if let Some(ipv6_packet) = ipv6::Ipv6Packet::new(datas) {
if let Some(icmpv6_packet) = icmpv6::Icmpv6Packet::new(ipv6_packet.payload()) {
let payload = icmpv6_packet.payload();
let uuid = &payload[4..20];
return Uuid::from_slice(uuid).ok();
}
}
}
}
None
}
8 changes: 7 additions & 1 deletion src/icmp/icmpv4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ use pnet_packet::{ipv4, PacketSize};

use crate::error::{MalformedPacketError, Result, SurgeError};

pub fn make_icmpv4_echo_packet(ident: u16, seq_cnt: u16, size: usize) -> Result<Vec<u8>> {
pub fn make_icmpv4_echo_packet(
ident: u16,
seq_cnt: u16,
size: usize,
key: &[u8],
) -> Result<Vec<u8>> {
let mut buf = vec![0; 8 + size]; // 8 bytes of header, then payload
let mut packet = icmp::echo_request::MutableEchoRequestPacket::new(&mut buf[..])
.ok_or(SurgeError::IncorrectBufferSize)?;
packet.set_icmp_type(icmp::IcmpTypes::EchoRequest);
packet.set_identifier(ident);
packet.set_sequence_number(seq_cnt);
packet.set_payload(key);

// Calculate and set the checksum
let icmp_packet =
Expand Down
8 changes: 7 additions & 1 deletion src/icmp/icmpv6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@ use pnet_packet::PacketSize;
use crate::error::{MalformedPacketError, Result, SurgeError};

#[allow(dead_code)]
pub fn make_icmpv6_echo_packet(ident: u16, seq_cnt: u16, size: usize) -> Result<Vec<u8>> {
pub fn make_icmpv6_echo_packet(
ident: u16,
seq_cnt: u16,
size: usize,
key: &[u8],
) -> Result<Vec<u8>> {
let mut buf = vec![0; 8 + size]; // 8 bytes of header, then payload
let mut packet = icmpv6::echo_request::MutableEchoRequestPacket::new(&mut buf[..])
.ok_or(SurgeError::IncorrectBufferSize)?;
packet.set_icmpv6_type(icmpv6::Icmpv6Types::EchoRequest);
packet.set_identifier(ident);
packet.set_sequence_number(seq_cnt);
packet.set_payload(key);

// Per https://tools.ietf.org/html/rfc3542#section-3.1 the checksum is
// omitted, the kernel will insert it.
Expand Down
4 changes: 2 additions & 2 deletions src/icmp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ impl IcmpPacket {
destination.eq(&IpAddr::V4(packet.get_real_dest()))
&& packet.get_sequence() == seq_cnt
&& packet.get_identifier() == identifier
},
}
IcmpPacket::V6(packet) => {
packet.get_sequence() == seq_cnt && packet.get_identifier() == identifier
packet.get_sequence() == seq_cnt && packet.get_identifier() == identifier
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub async fn pinger(host: IpAddr) -> Result<Pinger, SurgeError> {
IpAddr::V4(_) => Config::default(),
IpAddr::V6(_) => Config::builder().kind(ICMP::V6).build(),
};
let client = Client::new(&config)?;
let client = Client::new(&config).await?;
let pinger = client.pinger(host).await;
Ok(pinger)
}
43 changes: 21 additions & 22 deletions src/ping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@ use std::{
time::{Duration, Instant},
};

use log::trace;
use parking_lot::Mutex;
use rand::random;
use tokio::{
sync::{broadcast, mpsc},
task,
time::timeout,
};
use tokio::{sync::mpsc, task, time::timeout};
use tracing::error;
use uuid::Uuid;

use crate::client::{AsyncSocket, Message};
use crate::error::{Result, SurgeError};
Expand Down Expand Up @@ -50,23 +47,15 @@ pub struct Pinger {
socket: AsyncSocket,
rx: mpsc::Receiver<Message>,
cache: Cache,
shutdown_notify: broadcast::Sender<()>,
}

impl Drop for Pinger {
fn drop(&mut self) {
if self.shutdown_notify.send(()).is_err() {
trace!("notify shutdown error");
}
}
key: Uuid,
}

impl Pinger {
pub(crate) fn new(
host: IpAddr,
socket: AsyncSocket,
rx: mpsc::Receiver<Message>,
shutdown_notify: broadcast::Sender<()>,
key: Uuid,
) -> Pinger {
Pinger {
destination: host,
Expand All @@ -76,7 +65,7 @@ impl Pinger {
socket,
rx,
cache: Cache::new(),
shutdown_notify,
key,
}
}

Expand All @@ -86,9 +75,9 @@ impl Pinger {
self
}

/// Set the packet size.(default: 56)
/// Set the packet payload size, minimal is 16. (default: 56)
pub fn size(&mut self, size: usize) -> &mut Pinger {
self.size = size;
self.size = if size < 16 { 16 } else { size };
self
}

Expand Down Expand Up @@ -125,16 +114,26 @@ impl Pinger {
pub async fn ping(&mut self, seq_cnt: u16) -> Result<(IcmpPacket, Duration)> {
let sender = self.socket.clone();
let mut packet = match self.destination {
IpAddr::V4(_) => icmpv4::make_icmpv4_echo_packet(self.ident, seq_cnt, self.size)?,
IpAddr::V6(_) => icmpv6::make_icmpv6_echo_packet(self.ident, seq_cnt, self.size)?,
IpAddr::V4(_) => icmpv4::make_icmpv4_echo_packet(
self.ident,
seq_cnt,
self.size,
self.key.as_bytes(),
)?,
IpAddr::V6(_) => icmpv6::make_icmpv6_echo_packet(
self.ident,
seq_cnt,
self.size,
self.key.as_bytes(),
)?,
};
// let mut packet = EchoRequest::new(self.host, self.ident, seq_cnt, self.size).encode()?;
let sock_addr = SocketAddr::new(self.destination, 0);
let ident = self.ident;
let cache = self.cache.clone();
task::spawn(async move {
if let Err(e) = sender.send_to(&mut packet, &sock_addr).await {
trace!("socket send packet error: {}", e)
error!("socket send packet error: {}", e)
}
cache.insert(ident, seq_cnt, Instant::now());
});
Expand Down

0 comments on commit 1f08361

Please sign in to comment.