diff --git a/.circleci/config.yml b/.circleci/config.yml
index 1eada56..97a9003 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -7,5 +7,4 @@ workflows:
- rust/lint-test-build:
clippy_arguments: '--all-targets --all-features -- --deny warnings'
release: true
- version: 1.71.1
-
+ version: 1.81.0
diff --git a/src/cid.rs b/src/cid.rs
index b4b0dc5..1db52be 100644
--- a/src/cid.rs
+++ b/src/cid.rs
@@ -1,15 +1,6 @@
-use std::fmt::Debug;
-use std::hash::Hash;
-use std::net::SocketAddr;
-
-/// A remote peer.
-pub trait ConnectionPeer: Clone + Debug + Eq + Hash + PartialEq + Send + Sync {}
-
-impl ConnectionPeer for SocketAddr {}
-
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub struct ConnectionId
{
pub send: u16,
pub recv: u16,
- pub peer: P,
+ pub peer_id: P,
}
diff --git a/src/conn.rs b/src/conn.rs
index e8dbcc3..d2336e0 100644
--- a/src/conn.rs
+++ b/src/conn.rs
@@ -8,10 +8,12 @@ use delay_map::HashMapDelay;
use futures::StreamExt;
use tokio::sync::{mpsc, oneshot, Notify};
-use crate::cid::{ConnectionId, ConnectionPeer};
+use crate::cid::ConnectionId;
use crate::congestion;
use crate::event::{SocketEvent, StreamEvent};
use crate::packet::{Packet, PacketBuilder, PacketType, SelectiveAck};
+use crate::peer::ConnectionPeer;
+use crate::peer::Peer;
use crate::recv::ReceiveBuffer;
use crate::send::SendBuffer;
use crate::sent::{SentPackets, SentPacketsError};
@@ -167,9 +169,10 @@ impl From for congestion::Config {
}
}
-pub struct Connection {
+pub struct Connection {
state: State,
- cid: ConnectionId,
+ cid: ConnectionId,
+ peer: Peer,
config: ConnectionConfig,
endpoint: Endpoint,
peer_ts_diff: Duration,
@@ -185,7 +188,8 @@ pub struct Connection {
impl Connection {
pub fn new(
- cid: ConnectionId,
+ cid: ConnectionId,
+ peer: Peer,
config: ConnectionConfig,
syn: Option,
connected: oneshot::Sender>,
@@ -212,6 +216,7 @@ impl Connection {
Self {
state: State::Connecting(Some(connected)),
cid,
+ peer,
config,
endpoint,
peer_ts_diff,
@@ -232,7 +237,7 @@ impl Connection {
mut writes: mpsc::UnboundedReceiver,
mut shutdown: oneshot::Receiver<()>,
) -> io::Result<()> {
- tracing::debug!("uTP conn starting... {:?}", self.cid.peer);
+ tracing::debug!("uTP conn starting... {:?}", self.peer);
// If we are the initiating endpoint, then send the SYN. If we are the accepting endpoint,
// then send the SYN-ACK.
@@ -240,7 +245,7 @@ impl Connection {
Endpoint::Initiator((syn_seq_num, ..)) => {
let syn = self.syn_packet(syn_seq_num);
self.socket_events
- .send(SocketEvent::Outgoing((syn.clone(), self.cid.peer.clone())))
+ .send(SocketEvent::Outgoing((syn.clone(), self.peer.clone())))
.unwrap();
self.unacked
.insert_at(syn_seq_num, syn, self.config.initial_timeout);
@@ -250,7 +255,7 @@ impl Connection {
Endpoint::Acceptor((syn, syn_ack)) => {
let state = self.state_packet().unwrap();
self.socket_events
- .send(SocketEvent::Outgoing((state, self.cid.peer.clone())))
+ .send(SocketEvent::Outgoing((state, self.peer.clone())))
.unwrap();
let recv_buf = ReceiveBuffer::new(syn);
@@ -409,7 +414,7 @@ impl Connection {
&mut self.unacked,
&mut self.socket_events,
fin,
- &self.cid.peer,
+ &self.peer,
Instant::now(),
);
}
@@ -441,7 +446,7 @@ impl Connection {
&mut self.unacked,
&mut self.socket_events,
fin,
- &self.cid.peer,
+ &self.peer,
Instant::now(),
);
}
@@ -542,7 +547,7 @@ impl Connection {
&mut self.unacked,
&mut self.socket_events,
packet,
- &self.cid.peer,
+ &self.peer,
now,
);
seq_num = seq_num.wrapping_add(1);
@@ -680,7 +685,7 @@ impl Connection {
let packet = self.syn_packet(seq);
let _ = self
.socket_events
- .send(SocketEvent::Outgoing((packet, self.cid.peer.clone())));
+ .send(SocketEvent::Outgoing((packet, self.peer.clone())));
}
}
Endpoint::Acceptor(..) => {}
@@ -728,7 +733,7 @@ impl Connection {
&mut self.unacked,
&mut self.socket_events,
packet,
- &self.cid.peer,
+ &self.peer,
now,
);
}
@@ -784,7 +789,7 @@ impl Connection {
match packet.packet_type() {
PacketType::Syn | PacketType::Fin | PacketType::Data => {
if let Some(state) = self.state_packet() {
- let event = SocketEvent::Outgoing((state, self.cid.peer.clone()));
+ let event = SocketEvent::Outgoing((state, self.peer.clone()));
if self.socket_events.send(event).is_err() {
tracing::warn!("Cannot transmit state packet: socket closed channel");
return;
@@ -1156,7 +1161,7 @@ impl Connection {
&mut self.unacked,
&mut self.socket_events,
packet,
- &self.cid.peer,
+ &self.peer,
now,
);
}
@@ -1167,7 +1172,7 @@ impl Connection {
unacked: &mut HashMapDelay,
socket_events: &mut mpsc::UnboundedSender>,
packet: Packet,
- dest: &P,
+ peer: &Peer,
now: Instant,
) {
let (payload, len) = if packet.payload().is_empty() {
@@ -1189,7 +1194,7 @@ impl Connection {
sent_packets.on_transmit(packet.seq_num(), packet.packet_type(), payload, len, now);
unacked.insert_at(packet.seq_num(), packet.clone(), sent_packets.timeout());
- let outbound = SocketEvent::Outgoing((packet, dest.clone()));
+ let outbound = SocketEvent::Outgoing((packet, peer.clone()));
if socket_events.send(outbound).is_err() {
tracing::warn!("Cannot transmit packet: socket closed channel");
}
@@ -1214,12 +1219,13 @@ mod test {
let cid = ConnectionId {
send: 101,
recv: 100,
- peer,
+ peer_id: peer,
};
Connection {
state: State::Connecting(Some(connected)),
cid,
+ peer: Peer::new(peer),
config: ConnectionConfig::default(),
endpoint,
peer_ts_diff: Duration::from_millis(100),
diff --git a/src/event.rs b/src/event.rs
index 8399329..43c6b6c 100644
--- a/src/event.rs
+++ b/src/event.rs
@@ -1,5 +1,6 @@
use crate::cid::ConnectionId;
use crate::packet::Packet;
+use crate::peer::{ConnectionPeer, Peer};
#[derive(Clone, Debug)]
pub enum StreamEvent {
@@ -8,7 +9,7 @@ pub enum StreamEvent {
}
#[derive(Clone, Debug)]
-pub enum SocketEvent {
- Outgoing((Packet, P)),
- Shutdown(ConnectionId
),
+pub enum SocketEvent {
+ Outgoing((Packet, Peer)),
+ Shutdown(ConnectionId),
}
diff --git a/src/lib.rs b/src/lib.rs
index ca83ee2..4006266 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,7 @@ pub mod congestion;
pub mod conn;
pub mod event;
pub mod packet;
+pub mod peer;
pub mod recv;
pub mod send;
pub mod sent;
diff --git a/src/peer.rs b/src/peer.rs
new file mode 100644
index 0000000..f668276
--- /dev/null
+++ b/src/peer.rs
@@ -0,0 +1,96 @@
+use std::fmt::Debug;
+use std::hash::Hash;
+use std::net::SocketAddr;
+
+/// A trait that describes remote peer
+pub trait ConnectionPeer: Debug + Clone + Send + Sync {
+ type Id: Debug + Clone + PartialEq + Eq + Hash + Send + Sync;
+
+ /// Returns peer's id
+ fn id(&self) -> Self::Id;
+
+ /// Consolidates two peers into one.
+ ///
+ /// It's possible that we have two instances that represent the same peer (equal `peer_id`),
+ /// and we need to consolidate them into one. This can happen when [Peer]-s passed with
+ /// [UtpSocket::accept_with_cid](crate::socket::UtpSocket::accept_with_cid) or
+ /// [UtpSocket::connect_with_cid](crate::socket::UtpSocket::connect_with_cid), and returned by
+ /// [AsyncUdpSocket::recv_from](crate::udp::AsyncUdpSocket::recv_from) contain peers (not just
+ /// `peer_id`).
+ ///
+ /// The structure implementing this trait can decide on the exact behavior. Some examples:
+ /// - If structure is simple (i.e. two peers are the same iff all fields are the same), return
+ /// either (see implementation for `SocketAddr`)
+ /// - If we can determine which peer is newer (e.g. using timestamp or version field), return
+ /// newer peer
+ /// - If structure behaves more like a key-value map whose values don't change over time,
+ /// merge key-value pairs from both instances into one
+ ///
+ /// Should panic if ids are not matching.
+ fn consolidate(a: Self, b: Self) -> Self;
+}
+
+impl ConnectionPeer for SocketAddr {
+ type Id = Self;
+
+ fn id(&self) -> Self::Id {
+ *self
+ }
+
+ fn consolidate(a: Self, b: Self) -> Self {
+ assert!(a == b, "Consolidating non-equal peers");
+ a
+ }
+}
+
+/// Structure that stores peer's id, and maybe peer as well.
+#[derive(Debug, Clone)]
+pub struct Peer {
+ id: P::Id,
+ peer: Option,
+}
+
+impl Peer {
+ /// Creates new instance that stores peer
+ pub fn new(peer: P) -> Self {
+ Self {
+ id: peer.id(),
+ peer: Some(peer),
+ }
+ }
+
+ /// Creates new instance that only stores peer's id
+ pub fn new_id(peer_id: P::Id) -> Self {
+ Self {
+ id: peer_id,
+ peer: None,
+ }
+ }
+
+ /// Returns peer's id
+ pub fn id(&self) -> &P::Id {
+ &self.id
+ }
+
+ /// Returns optional reference to peer
+ pub fn peer(&self) -> Option<&P> {
+ self.peer.as_ref()
+ }
+
+ /// Consolidates given peer into `Self` whilst consuming it.
+ ///
+ /// See [ConnectionPeer::consolidate] for details.
+ ///
+ /// Panics if ids are not matching.
+ pub fn consolidate(&mut self, other: Self) {
+ assert!(self.id == other.id, "Consolidating with non-equal peer");
+ let Some(other_peer) = other.peer else {
+ return;
+ };
+
+ self.peer = match self.peer.take() {
+ Some(peer) => Some(P::consolidate(peer, other_peer)),
+ None => Some(other_peer),
+ };
+ }
+}
diff --git a/src/socket.rs b/src/socket.rs
index 06a654d..4ea0766 100644
--- a/src/socket.rs
+++ b/src/socket.rs
@@ -11,20 +11,27 @@ use tokio::net::UdpSocket;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::{mpsc, oneshot};
-use crate::cid::{ConnectionId, ConnectionPeer};
+use crate::cid::ConnectionId;
use crate::conn::ConnectionConfig;
use crate::event::{SocketEvent, StreamEvent};
use crate::packet::{Packet, PacketBuilder, PacketType};
+use crate::peer::{ConnectionPeer, Peer};
use crate::stream::UtpStream;
use crate::udp::AsyncUdpSocket;
type ConnChannel = UnboundedSender;
-struct Accept {
+struct Accept {
stream: oneshot::Sender>>,
config: ConnectionConfig,
}
+struct AcceptWithCidPeer {
+ cid: ConnectionId,
+ peer: Peer,
+ accept: Accept
,
+}
+
const MAX_UDP_PAYLOAD_SIZE: usize = u16::MAX as usize;
const CID_GENERATION_TRY_WARNING_COUNT: usize = 10;
@@ -36,10 +43,10 @@ const CID_GENERATION_TRY_WARNING_COUNT: usize = 10;
/// but thee uTP config refactor is currently very low priority.
const AWAITING_CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
-pub struct UtpSocket
{
- conns: Arc, ConnChannel>>>,
+pub struct UtpSocket {
+ conns: Arc, ConnChannel>>>,
accepts: UnboundedSender>,
- accepts_with_cid: UnboundedSender<(Accept, ConnectionId
)>,
+ accepts_with_cid: UnboundedSender>,
socket_events: UnboundedSender>,
}
@@ -53,7 +60,7 @@ impl UtpSocket {
impl UtpSocket
where
- P: ConnectionPeer + Unpin + 'static,
+ P: ConnectionPeer + Unpin + 'static,
{
pub fn with_socket(mut socket: S) -> Self
where
@@ -62,10 +69,10 @@ where
let conns = HashMap::new();
let conns = Arc::new(RwLock::new(conns));
- let mut awaiting: HashMapDelay, Accept> =
+ let mut awaiting: HashMapDelay, AcceptWithCidPeer> =
HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT);
- let mut incoming_conns: HashMapDelay, Packet> =
+ let mut incoming_conns: HashMapDelay, (Peer, Packet)> =
HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT);
let (socket_event_tx, mut socket_event_rx) = mpsc::unbounded_channel();
@@ -84,18 +91,19 @@ where
loop {
tokio::select! {
biased;
- Ok((n, src)) = socket.recv_from(&mut buf) => {
+ Ok((n, mut peer)) = socket.recv_from(&mut buf) => {
+ let peer_id = peer.id();
let packet = match Packet::decode(&buf[..n]) {
Ok(pkt) => pkt,
Err(..) => {
- tracing::warn!(?src, "unable to decode uTP packet");
+ tracing::warn!(?peer, "unable to decode uTP packet");
continue;
}
};
- let peer_init_cid = cid_from_packet(&packet, &src, IdType::SendIdPeerInitiated);
- let we_init_cid = cid_from_packet(&packet, &src, IdType::SendIdWeInitiated);
- let acc_cid = cid_from_packet(&packet, &src, IdType::RecvId);
+ let peer_init_cid = cid_from_packet::
(&packet, peer_id, IdType::SendIdPeerInitiated);
+ let we_init_cid = cid_from_packet::
(&packet, peer_id, IdType::SendIdWeInitiated);
+ let acc_cid = cid_from_packet::
(&packet, peer_id, IdType::RecvId);
let mut conns = conns.write().unwrap();
let conn = conns
.get(&acc_cid)
@@ -107,12 +115,14 @@ where
}
None => {
if std::matches!(packet.packet_type(), PacketType::Syn) {
- let cid = cid_from_packet(&packet, &src, IdType::RecvId);
+ let cid = acc_cid;
// If there was an awaiting connection with the CID, then
// create a new stream for that connection. Otherwise, add the
// connection to the incoming connections.
- if let Some(accept) = awaiting.remove(&cid) {
+ if let Some(accept_with_cid) = awaiting.remove(&cid) {
+ peer.consolidate(accept_with_cid.peer);
+
let (connected_tx, connected_rx) = oneshot::channel();
let (events_tx, events_rx) = mpsc::unbounded_channel();
@@ -120,7 +130,8 @@ where
let stream = UtpStream::new(
cid,
- accept.config,
+ peer,
+ accept_with_cid.accept.config,
Some(packet),
socket_event_tx.clone(),
events_rx,
@@ -128,10 +139,10 @@ where
);
tokio::spawn(async move {
- Self::await_connected(stream, accept, connected_rx).await
+ Self::await_connected(stream, accept_with_cid.accept.stream, connected_rx).await
});
} else {
- incoming_conns.insert(cid, packet);
+ incoming_conns.insert(cid, (peer, packet));
}
} else {
tracing::debug!(
@@ -151,7 +162,7 @@ where
let reset_packet =
PacketBuilder::new(PacketType::Reset, packet.conn_id(), crate::time::now_micros(), 100_000, random_seq_num)
.build();
- let event = SocketEvent::Outgoing((reset_packet, src.clone()));
+ let event = SocketEvent::Outgoing((reset_packet, peer));
if socket_event_tx.send(event).is_err() {
tracing::warn!("Cannot transmit reset packet: socket closed channel");
return;
@@ -161,18 +172,19 @@ where
},
}
}
- Some((accept, cid)) = accepts_with_cid_rx.recv() => {
- let Some(syn) = incoming_conns.remove(&cid) else {
- awaiting.insert(cid, accept);
+ Some(accept_with_cid) = accepts_with_cid_rx.recv() => {
+ let Some((mut peer, syn)) = incoming_conns.remove(&accept_with_cid.cid) else {
+ awaiting.insert(accept_with_cid.cid.clone(), accept_with_cid);
continue;
};
- Self::select_accept_helper(cid, syn, conns.clone(), accept, socket_event_tx.clone());
+ peer.consolidate(accept_with_cid.peer);
+ Self::select_accept_helper(accept_with_cid.cid, peer, syn, conns.clone(), accept_with_cid.accept, socket_event_tx.clone());
}
Some(accept) = accepts_rx.recv(), if !incoming_conns.is_empty() => {
- let (cid, _) = incoming_conns.iter().next().expect("at least one incoming connection");
+ let cid = incoming_conns.keys().next().expect("at least one incoming connection");
let cid = cid.clone();
- let packet = incoming_conns.remove(&cid).expect("to delete incoming connection");
- Self::select_accept_helper(cid, packet, conns.clone(), accept, socket_event_tx.clone());
+ let (peer, packet) = incoming_conns.remove(&cid).expect("to delete incoming connection");
+ Self::select_accept_helper(cid, peer, packet, conns.clone(), accept, socket_event_tx.clone());
}
Some(event) = socket_event_rx.recv() => {
match event {
@@ -195,11 +207,11 @@ where
}
}
}
- Some(Ok((cid, accept))) = awaiting.next() => {
+ Some(Ok((cid, accept_with_cid))) = awaiting.next() => {
// accept_with_cid didn't receive an inbound connection within the timeout period
// log it and return a timeout error
tracing::debug!(%cid.send, %cid.recv, "accept_with_cid timed out");
- let _ = accept
+ let _ = accept_with_cid.accept
.stream
.send(Err(io::Error::from(io::ErrorKind::TimedOut)));
}
@@ -218,14 +230,14 @@ where
/// Internal cid generation
fn generate_cid(
&self,
- peer: P,
+ peer_id: P::Id,
is_initiator: bool,
event_tx: Option>,
- ) -> ConnectionId {
+ ) -> ConnectionId {
let mut cid = ConnectionId {
send: 0,
recv: 0,
- peer,
+ peer_id,
};
let mut generation_attempt_count = 0;
loop {
@@ -251,8 +263,8 @@ where
}
}
- pub fn cid(&self, peer: P, is_initiator: bool) -> ConnectionId {
- self.generate_cid(peer, is_initiator, None)
+ pub fn cid(&self, peer_id: P::Id, is_initiator: bool) -> ConnectionId {
+ self.generate_cid(peer_id, is_initiator, None)
}
/// Returns the number of connections currently open, both inbound and outbound.
@@ -281,16 +293,21 @@ where
/// they aren't compatible to use interchangeably in a program
pub async fn accept_with_cid(
&self,
- cid: ConnectionId,
+ cid: ConnectionId,
+ peer: Peer,
config: ConnectionConfig,
) -> io::Result> {
let (stream_tx, stream_rx) = oneshot::channel();
- let accept = Accept {
- stream: stream_tx,
- config,
+ let accept = AcceptWithCidPeer {
+ cid,
+ peer,
+ accept: Accept {
+ stream: stream_tx,
+ config,
+ },
};
self.accepts_with_cid
- .send((accept, cid))
+ .send(accept)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
match stream_rx.await {
Ok(stream) => Ok(stream?),
@@ -298,13 +315,18 @@ where
}
}
- pub async fn connect(&self, peer: P, config: ConnectionConfig) -> io::Result> {
+ pub async fn connect(
+ &self,
+ peer: Peer,
+ config: ConnectionConfig,
+ ) -> io::Result> {
let (connected_tx, connected_rx) = oneshot::channel();
let (events_tx, events_rx) = mpsc::unbounded_channel();
- let cid = self.generate_cid(peer, true, Some(events_tx));
+ let cid = self.generate_cid(peer.id().clone(), true, Some(events_tx));
let stream = UtpStream::new(
cid,
+ peer,
config,
None,
self.socket_events.clone(),
@@ -321,7 +343,8 @@ where
pub async fn connect_with_cid(
&self,
- cid: ConnectionId,
+ cid: ConnectionId,
+ peer: Peer,
config: ConnectionConfig,
) -> io::Result> {
if self.conns.read().unwrap().contains_key(&cid) {
@@ -340,6 +363,7 @@ where
let stream = UtpStream::new(
cid.clone(),
+ peer,
config,
None,
self.socket_events.clone(),
@@ -362,28 +386,27 @@ where
async fn await_connected(
stream: UtpStream,
- accept: Accept
,
+ callback: oneshot::Sender>>,
connected: oneshot::Receiver>,
) {
match connected.await {
Ok(Ok(..)) => {
- let _ = accept.stream.send(Ok(stream));
+ let _ = callback.send(Ok(stream));
}
Ok(Err(err)) => {
- let _ = accept.stream.send(Err(err));
+ let _ = callback.send(Err(err));
}
Err(..) => {
- let _ = accept
- .stream
- .send(Err(io::Error::from(io::ErrorKind::ConnectionAborted)));
+ let _ = callback.send(Err(io::Error::from(io::ErrorKind::ConnectionAborted)));
}
}
}
fn select_accept_helper(
- cid: ConnectionId,
+ cid: ConnectionId,
+ peer: Peer,
syn: Packet,
- conns: Arc, UnboundedSender>>>,
+ conns: Arc, ConnChannel>>>,
accept: Accept,
socket_event_tx: UnboundedSender>,
) {
@@ -404,6 +427,7 @@ where
let stream = UtpStream::new(
cid,
+ peer,
accept.config,
Some(syn),
socket_event_tx,
@@ -411,7 +435,9 @@ where
connected_tx,
);
- tokio::spawn(async move { Self::await_connected(stream, accept, connected_rx).await });
+ tokio::spawn(
+ async move { Self::await_connected(stream, accept.stream, connected_rx).await },
+ );
}
}
@@ -424,9 +450,10 @@ enum IdType {
fn cid_from_packet(
packet: &Packet,
- src: &P,
+ peer_id: &P::Id,
id_type: IdType,
-) -> ConnectionId {
+) -> ConnectionId {
+ let peer_id = peer_id.clone();
match id_type {
IdType::RecvId => {
let (send, recv) = match packet.packet_type() {
@@ -438,7 +465,7 @@ fn cid_from_packet(
ConnectionId {
send,
recv,
- peer: src.clone(),
+ peer_id,
}
}
IdType::SendIdWeInitiated => {
@@ -446,7 +473,7 @@ fn cid_from_packet(
ConnectionId {
send,
recv,
- peer: src.clone(),
+ peer_id,
}
}
IdType::SendIdPeerInitiated => {
@@ -454,13 +481,13 @@ fn cid_from_packet(
ConnectionId {
send,
recv,
- peer: src.clone(),
+ peer_id,
}
}
}
}
-impl Drop for UtpSocket
{
+impl Drop for UtpSocket {
fn drop(&mut self) {
for conn in self.conns.read().unwrap().values() {
let _ = conn.send(StreamEvent::Shutdown);
diff --git a/src/stream.rs b/src/stream.rs
index 363f3a8..4311709 100644
--- a/src/stream.rs
+++ b/src/stream.rs
@@ -4,18 +4,19 @@ use tokio::sync::{mpsc, oneshot};
use tokio::task;
use tracing::Instrument;
-use crate::cid::{ConnectionId, ConnectionPeer};
+use crate::cid::ConnectionId;
use crate::congestion::DEFAULT_MAX_PACKET_SIZE_BYTES;
use crate::conn;
use crate::event::{SocketEvent, StreamEvent};
use crate::packet::Packet;
+use crate::peer::{ConnectionPeer, Peer};
/// The size of the send and receive buffers.
// TODO: Make the buffer size configurable.
const BUF: usize = 1024 * 1024;
-pub struct UtpStream
{
- cid: ConnectionId
,
+pub struct UtpStream {
+ cid: ConnectionId,
reads: mpsc::UnboundedReceiver,
writes: mpsc::UnboundedSender,
shutdown: Option>,
@@ -27,7 +28,8 @@ where
P: ConnectionPeer + 'static,
{
pub(crate) fn new(
- cid: ConnectionId,
+ cid: ConnectionId,
+ peer: Peer,
config: conn::ConnectionConfig,
syn: Option,
socket_events: mpsc::UnboundedSender>,
@@ -39,6 +41,7 @@ where
let (writes_tx, writes_rx) = mpsc::unbounded_channel();
let mut conn = conn::Connection::::new(
cid.clone(),
+ peer,
config,
syn,
connected,
@@ -60,7 +63,7 @@ where
}
}
- pub fn cid(&self) -> &ConnectionId {
+ pub fn cid(&self) -> &ConnectionId {
&self.cid
}
@@ -117,7 +120,7 @@ where
}
}
-impl UtpStream
{
+impl UtpStream {
// Send signal to the connection event loop to exit, after all outgoing writes have completed.
// Public callers should use close() instead.
fn shutdown(&mut self) -> io::Result<()> {
@@ -130,7 +133,7 @@ impl
UtpStream
{
}
}
-impl
Drop for UtpStream
{
+impl Drop for UtpStream {
fn drop(&mut self) {
let _ = self.shutdown();
}
diff --git a/src/testutils.rs b/src/testutils.rs
index 4372d0a..2ea0f00 100644
--- a/src/testutils.rs
+++ b/src/testutils.rs
@@ -5,7 +5,8 @@ use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::mpsc;
-use crate::cid::{ConnectionId, ConnectionPeer};
+use crate::cid::ConnectionId;
+use crate::peer::{ConnectionPeer, Peer};
use crate::udp::AsyncUdpSocket;
/// A mock socket that can be used to simulate a perfect link.
@@ -38,8 +39,8 @@ impl AsyncUdpSocket for MockUdpSocket {
///
/// Panics if `target` is not equal to `self.only_peer`. This socket is built to support
/// exactly two peers communicating with each other, so it will panic if used with more.
- async fn send_to(&mut self, buf: &[u8], target: &char) -> io::Result {
- if target != &self.only_peer {
+ async fn send_to(&mut self, buf: &[u8], peer: &Peer) -> io::Result {
+ if peer.id() != &self.only_peer {
panic!("MockUdpSocket only supports sending to one peer");
}
if !self.is_up() {
@@ -58,7 +59,7 @@ impl AsyncUdpSocket for MockUdpSocket {
/// # Panics
///
/// Panics if `buf` is smaller than the packet size.
- async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, char)> {
+ async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer)> {
let packet = self
.inbound
.recv()
@@ -69,11 +70,22 @@ impl AsyncUdpSocket for MockUdpSocket {
}
let packet_len = packet.len();
buf[..packet_len].copy_from_slice(&packet[..]);
- Ok((packet_len, self.only_peer))
+ Ok((packet_len, Peer::new(self.only_peer)))
}
}
-impl ConnectionPeer for char {}
+impl ConnectionPeer for char {
+ type Id = char;
+
+ fn id(&self) -> Self::Id {
+ *self
+ }
+
+ fn consolidate(a: Self, b: Self) -> Self {
+ assert!(a == b, "Consolidating non-equal peers");
+ a
+ }
+}
fn build_link_pair() -> (MockUdpSocket, MockUdpSocket) {
let (peer_a, peer_b): (char, char) = ('A', 'B');
@@ -110,12 +122,12 @@ fn build_connection_id_pair_starting_at(
let a_cid = ConnectionId {
send: higher_id,
recv: lower_id,
- peer: socket_a.only_peer,
+ peer_id: socket_a.only_peer,
};
let b_cid = ConnectionId {
send: lower_id,
recv: higher_id,
- peer: socket_b.only_peer,
+ peer_id: socket_b.only_peer,
};
(a_cid, b_cid)
}
diff --git a/src/udp.rs b/src/udp.rs
index 62d2bae..ced0e7a 100644
--- a/src/udp.rs
+++ b/src/udp.rs
@@ -4,25 +4,27 @@ use std::net::SocketAddr;
use async_trait::async_trait;
use tokio::net::UdpSocket;
-use crate::cid::ConnectionPeer;
+use crate::peer::{ConnectionPeer, Peer};
/// An abstract representation of an asynchronous UDP socket.
#[async_trait]
pub trait AsyncUdpSocket: Send + Sync {
- /// Attempts to send data on the socket to a given address.
+ /// Attempts to send data on the socket to a given peer.
/// Note that this should return nearly immediately, rather than awaiting something internally.
- async fn send_to(&mut self, buf: &[u8], target: &P) -> io::Result;
+ async fn send_to(&mut self, buf: &[u8], peer: &Peer) -> io::Result;
/// Attempts to receive a single datagram on the socket.
- async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, P)>;
+ async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer)>;
}
#[async_trait]
impl AsyncUdpSocket for UdpSocket {
- async fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result {
- UdpSocket::send_to(self, buf, target).await
+ async fn send_to(&mut self, buf: &[u8], peer: &Peer) -> io::Result {
+ UdpSocket::send_to(self, buf, peer.id()).await
}
- async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
- UdpSocket::recv_from(self, buf).await
+ async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer)> {
+ UdpSocket::recv_from(self, buf)
+ .await
+ .map(|(len, peer)| (len, Peer::new(peer)))
}
}
diff --git a/tests/socket.rs b/tests/socket.rs
index 026d80d..2f9382e 100644
--- a/tests/socket.rs
+++ b/tests/socket.rs
@@ -1,6 +1,7 @@
use futures::stream::{FuturesUnordered, StreamExt};
use std::net::SocketAddr;
use std::sync::Arc;
+use utp_rs::peer::Peer;
use tokio::task::JoinHandle;
use tokio::time::Instant;
@@ -115,16 +116,19 @@ async fn initiate_transfer(
let recv_cid = cid::ConnectionId {
send: initiator_cid,
recv: responder_cid,
- peer: send_addr,
+ peer_id: send_addr,
};
let send_cid = cid::ConnectionId {
send: responder_cid,
recv: initiator_cid,
- peer: recv_addr,
+ peer_id: recv_addr,
};
let recv_handle = tokio::spawn(async move {
- let mut stream = recv.accept_with_cid(recv_cid, conn_config).await.unwrap();
+ let mut stream = recv
+ .accept_with_cid(recv_cid, Peer::new(send_addr), conn_config)
+ .await
+ .unwrap();
let mut buf = vec![];
let n = match stream.read_to_eof(&mut buf).await {
Ok(num_bytes) => num_bytes,
@@ -141,7 +145,10 @@ async fn initiate_transfer(
});
let send_handle = tokio::spawn(async move {
- let mut stream = send.connect_with_cid(send_cid, conn_config).await.unwrap();
+ let mut stream = send
+ .connect_with_cid(send_cid, Peer::new(recv_addr), conn_config)
+ .await
+ .unwrap();
let n = stream.write(data).await.unwrap();
assert_eq!(n, data.len());
@@ -174,18 +181,18 @@ async fn test_socket_reports_two_connections() {
let recv_one_cid = cid::ConnectionId {
send: 100,
recv: 101,
- peer: send_addr,
+ peer_id: send_addr,
};
let send_one_cid = cid::ConnectionId {
send: 101,
recv: 100,
- peer: recv_addr,
+ peer_id: recv_addr,
};
let recv_one = Arc::clone(&recv);
let recv_one_handle = tokio::spawn(async move {
recv_one
- .accept_with_cid(recv_one_cid, conn_config)
+ .accept_with_cid(recv_one_cid, Peer::new(send_addr), conn_config)
.await
.unwrap()
});
@@ -193,7 +200,7 @@ async fn test_socket_reports_two_connections() {
let send_one = Arc::clone(&send);
let send_one_handle = tokio::spawn(async move {
send_one
- .connect_with_cid(send_one_cid, conn_config)
+ .connect_with_cid(send_one_cid, Peer::new(recv_addr), conn_config)
.await
.unwrap()
});
@@ -201,18 +208,18 @@ async fn test_socket_reports_two_connections() {
let recv_two_cid = cid::ConnectionId {
send: 200,
recv: 201,
- peer: send_addr,
+ peer_id: send_addr,
};
let send_two_cid = cid::ConnectionId {
send: 201,
recv: 200,
- peer: recv_addr,
+ peer_id: recv_addr,
};
let recv_two = Arc::clone(&recv);
let recv_two_handle = tokio::spawn(async move {
recv_two
- .accept_with_cid(recv_two_cid, conn_config)
+ .accept_with_cid(recv_two_cid, Peer::new(send_addr), conn_config)
.await
.unwrap()
});
@@ -220,7 +227,7 @@ async fn test_socket_reports_two_connections() {
let send_two = Arc::clone(&send);
let send_two_handle = tokio::spawn(async move {
send_two
- .connect_with_cid(send_two_cid, conn_config)
+ .connect_with_cid(send_two_cid, Peer::new(recv_addr), conn_config)
.await
.unwrap()
});
diff --git a/tests/stream.rs b/tests/stream.rs
index 188a00f..719c196 100644
--- a/tests/stream.rs
+++ b/tests/stream.rs
@@ -6,6 +6,7 @@ use std::time::Duration;
use tokio::time::timeout;
use utp_rs::conn::{ConnectionConfig, DEFAULT_MAX_IDLE_TIMEOUT};
+use utp_rs::peer::Peer;
use utp_rs::socket::UtpSocket;
use utp_rs::testutils;
@@ -29,7 +30,7 @@ async fn close_is_successful_when_write_completes() {
let recv_one = Arc::clone(&recv);
let recv_one_handle = tokio::spawn(async move {
recv_one
- .accept_with_cid(recv_cid, conn_config)
+ .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config)
.await
.unwrap()
});
@@ -39,7 +40,7 @@ async fn close_is_successful_when_write_completes() {
let send_one = Arc::clone(&send);
let send_one_handle = tokio::spawn(async move {
send_one
- .connect_with_cid(send_cid, conn_config)
+ .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config)
.await
.unwrap()
});
@@ -100,7 +101,7 @@ async fn close_errors_if_all_packets_dropped() {
let recv_one = Arc::clone(&recv);
let recv_one_handle = tokio::spawn(async move {
recv_one
- .accept_with_cid(recv_cid, conn_config)
+ .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config)
.await
.unwrap()
});
@@ -110,7 +111,7 @@ async fn close_errors_if_all_packets_dropped() {
let send_one = Arc::clone(&send);
let send_one_handle = tokio::spawn(async move {
send_one
- .connect_with_cid(send_cid, conn_config)
+ .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config)
.await
.unwrap()
});
@@ -178,7 +179,7 @@ async fn close_succeeds_if_only_fin_ack_dropped() {
let recv_one = Arc::clone(&recv);
let recv_one_handle = tokio::spawn(async move {
recv_one
- .accept_with_cid(recv_cid, conn_config)
+ .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config)
.await
.unwrap()
});
@@ -188,7 +189,7 @@ async fn close_succeeds_if_only_fin_ack_dropped() {
let send_one = Arc::clone(&send);
let send_one_handle = tokio::spawn(async move {
send_one
- .connect_with_cid(send_cid, conn_config)
+ .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config)
.await
.unwrap()
});