Skip to content

Commit

Permalink
Merge pull request #132 from carver/fix-large-transfers
Browse files Browse the repository at this point in the history
Fix large transfers
  • Loading branch information
carver authored Aug 2, 2024
2 parents 1dde178 + 16f9f87 commit 00e1b0a
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 17 deletions.
37 changes: 28 additions & 9 deletions src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ impl<const N: usize> fmt::Debug for State<N> {
}

pub type Write = (Vec<u8>, oneshot::Sender<io::Result<usize>>);
type QueuedWrite = (
// Remaining bytes to write
Vec<u8>,
// Number of bytes successfully written in previous partial writes.
// Sometimes the content is larger than the buffer and must be written in parts.
usize,
// oneshot sender to notify about the final result of the write operation
oneshot::Sender<io::Result<usize>>,
);
pub type Read = io::Result<Vec<u8>>;

#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -169,7 +178,7 @@ pub struct Connection<const N: usize, P> {
unacked: HashMapDelay<u16, Packet>,
reads: mpsc::UnboundedSender<Read>,
readable: Notify,
pending_writes: VecDeque<Write>,
pending_writes: VecDeque<QueuedWrite>,
writable: Notify,
latest_timeout: Option<Instant>,
}
Expand Down Expand Up @@ -488,11 +497,20 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
}

// Write as much data as possible into send buffer.
while let Some((data, ..)) = self.pending_writes.front() {
if data.len() <= send_buf.available() {
let (data, tx) = self.pending_writes.pop_front().unwrap();
send_buf.write(&data).unwrap();
let _ = tx.send(Ok(data.len()));
while self.pending_writes.front().is_some() {
let buf_space = send_buf.available();
if buf_space > 0 {
let (mut data, written, tx) = self.pending_writes.pop_front().unwrap();
if data.len() <= buf_space {
send_buf.write(&data).unwrap();
let _ = tx.send(Ok(data.len() + written));
} else {
let next_write = data.drain(..buf_space);
send_buf.write(next_write.as_slice()).unwrap();
drop(next_write);
// data was mutated by drain, so we only store the remaining data
self.pending_writes.push_front((data, buf_space + written, tx));
}
self.writable.notify_one();
} else {
break;
Expand Down Expand Up @@ -535,21 +553,22 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {

match &mut self.state {
State::Connecting(..) => {
self.pending_writes.push_back((data, tx));
// There are 0 bytes written so far
self.pending_writes.push_back((data, 0, tx));
}
State::Connected { closing, .. } => match closing {
Some(Closing {
local_fin,
remote_fin,
}) => {
if local_fin.is_none() && remote_fin.is_some() {
self.pending_writes.push_back((data, tx));
self.pending_writes.push_back((data, 0, tx));
} else {
let _ = tx.send(Ok(0));
}
}
None => {
self.pending_writes.push_back((data, tx));
self.pending_writes.push_back((data, 0, tx));
}
},
State::Closed { err, .. } => {
Expand Down
54 changes: 46 additions & 8 deletions tests/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use utp_rs::socket::UtpSocket;
const TEST_DATA: &[u8] = &[0xf0; 1_000_000];

#[tokio::test(flavor = "multi_thread", worker_threads = 16)]
async fn socket() {
tracing_subscriber::fmt::init();
async fn many_concurrent_transfers() {
let _ = tracing_subscriber::fmt::try_init();

tracing::info!("starting socket test");

Expand All @@ -31,7 +31,7 @@ async fn socket() {
for i in 0..num_transfers {
// step up cid by two to avoid collisions
let handle =
initiate_transfer(i * 2, recv_addr, recv.clone(), send_addr, send.clone()).await;
initiate_transfer(i * 2, recv_addr, recv.clone(), send_addr, send.clone(), TEST_DATA).await;
handles.push(handle.0);
handles.push(handle.1);
}
Expand All @@ -42,7 +42,44 @@ async fn socket() {
let elapsed = Instant::now() - start;
let megabits_sent = num_transfers as f64 * TEST_DATA.len() as f64 * 8.0 / 1_000_000.0;
let transfer_rate = megabits_sent / elapsed.as_secs_f64();
tracing::info!("finished real udp load test of {} simultaneous transfers, in {:?}, at a rate of {:.0} Mbps", num_transfers, elapsed, transfer_rate);
tracing::info!("finished high concurrency load test of {} simultaneous transfers, in {:?}, at a rate of {:.0} Mbps", num_transfers, elapsed, transfer_rate);
}

#[tokio::test]
/// Test that a socket can send and receive a large amount of data
async fn one_huge_data_transfer() {
// TODO: test 100MiB or more. Currently, it fails (perhaps due to a rollover at 2^16 packets)

// At the time of writing, 1024 * 1024 + 1 will hang, because it's bigger than the send buffer,
// and the sending logic pauses until the buffer is larger than the pending data.
const HUGE_DATA: &[u8] = &[0xf0; 1024 * 1024 * 50];

let _ = tracing_subscriber::fmt::try_init();

tracing::info!("starting single transfer of huge data test");

let recv_addr = SocketAddr::from(([127, 0, 0, 1], 3500));
let send_addr = SocketAddr::from(([127, 0, 0, 1], 3501));

let recv = UtpSocket::bind(recv_addr).await.unwrap();
let recv = Arc::new(recv);
let send = UtpSocket::bind(send_addr).await.unwrap();
let send = Arc::new(send);

let start = Instant::now();
let handle =
initiate_transfer(0, recv_addr, recv.clone(), send_addr, send.clone(), HUGE_DATA).await;

// Wait for the sending side of the transfer to complete
handle.0.await.unwrap();
// Wait for the receiving side of the transfer to complete
handle.1.await.unwrap();

let elapsed = Instant::now() - start;
let megabytes_sent = HUGE_DATA.len() as f64 / 1_000_000.0;
let megabits_sent = megabytes_sent * 8.0;
let transfer_rate = megabits_sent / elapsed.as_secs_f64();
tracing::info!("finished single large transfer test with {:.0} MB, in {:?}, at a rate of {:.1} Mbps", megabytes_sent, elapsed, transfer_rate);
}

async fn initiate_transfer(
Expand All @@ -51,6 +88,7 @@ async fn initiate_transfer(
recv: Arc<UtpSocket<SocketAddr>>,
send_addr: SocketAddr,
send: Arc<UtpSocket<SocketAddr>>,
data: &'static [u8],
) -> (JoinHandle<()>, JoinHandle<()>) {
let conn_config = ConnectionConfig::default();
let initiator_cid = 100 + i;
Expand Down Expand Up @@ -79,14 +117,14 @@ async fn initiate_transfer(
};
tracing::info!(cid.send = %recv_cid.send, cid.recv = %recv_cid.recv, "read {n} bytes from uTP stream");

assert_eq!(n, TEST_DATA.len());
assert_eq!(buf, TEST_DATA);
assert_eq!(n, data.len());
assert_eq!(buf, data);
});

let send_handle = tokio::spawn(async move {
let mut stream = send.connect_with_cid(send_cid, conn_config).await.unwrap();
let n = stream.write(TEST_DATA).await.unwrap();
assert_eq!(n, TEST_DATA.len());
let n = stream.write(data).await.unwrap();
assert_eq!(n, data.len());

stream.close().await.unwrap();
});
Expand Down

0 comments on commit 00e1b0a

Please sign in to comment.