From 3d621186b1ff9d4989b2c9a8c00d7743a58f4afe Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 10 Apr 2024 11:48:21 +0300 Subject: [PATCH] Add ability to read fixed size array without allocating also rename len to size --- src/http.rs | 2 +- src/lib.rs | 99 +++++++++++++++++++++++++++++++++++++------------ src/mem.rs | 14 ++++++- src/stats.rs | 18 +++++---- src/tokio_io.rs | 24 ++++++++++-- 5 files changed, 120 insertions(+), 37 deletions(-) diff --git a/src/http.rs b/src/http.rs index d726e14..da9defc 100644 --- a/src/http.rs +++ b/src/http.rs @@ -146,7 +146,7 @@ pub mod http_adapter { Ok(res.freeze()) } - async fn len(&mut self) -> io::Result { + async fn size(&mut self) -> io::Result { let io_err = |text: &str| io::Error::new(io::ErrorKind::Other, text); let head_response = self .head_request() diff --git a/src/lib.rs b/src/lib.rs index 8de4435..a700631 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,7 +37,7 @@ //! an allocation. #![deny(missing_docs, rustdoc::broken_intra_doc_links)] -use bytes::{Bytes, BytesMut}; +use bytes::{Buf, Bytes, BytesMut}; use std::future::Future; use std::io::{self, Cursor}; @@ -64,7 +64,7 @@ pub trait AsyncSliceReader { /// Get the length of the resource #[must_use = "io futures must be polled to completion"] - fn len(&mut self) -> impl Future>; + fn size(&mut self) -> impl Future>; } impl<'b, T: AsyncSliceReader> AsyncSliceReader for &'b mut T { @@ -72,8 +72,8 @@ impl<'b, T: AsyncSliceReader> AsyncSliceReader for &'b mut T { (**self).read_at(offset, len).await } - async fn len(&mut self) -> io::Result { - (**self).len().await + async fn size(&mut self) -> io::Result { + (**self).size().await } } @@ -82,8 +82,8 @@ impl AsyncSliceReader for Box { (**self).read_at(offset, len).await } - async fn len(&mut self) -> io::Result { - (**self).len().await + async fn size(&mut self) -> io::Result { + (**self).size().await } } @@ -172,45 +172,94 @@ pub trait AsyncStreamReader { /// Read at most `len` bytes. To read to the end, pass u64::MAX. /// /// returns an empty buffer to indicate EOF. - fn read(&mut self, len: usize) -> impl Future>; + fn read_bytes(&mut self, len: usize) -> impl Future>; + + /// Read a fixed size buffer. + /// + /// If there are less than L bytes available, an io::ErrorKind::UnexpectedEof error is returned. + fn read(&mut self) -> impl Future>; } impl AsyncStreamReader for &mut T { - async fn read(&mut self, len: usize) -> io::Result { - (**self).read(len).await + async fn read_bytes(&mut self, len: usize) -> io::Result { + (**self).read_bytes(len).await + } + + async fn read(&mut self) -> io::Result<[u8; L]> { + (**self).read().await } } impl AsyncStreamReader for Bytes { - async fn read(&mut self, len: usize) -> io::Result { + async fn read_bytes(&mut self, len: usize) -> io::Result { let res = self.split_to(len.min(Bytes::len(self))); Ok(res) } + + async fn read(&mut self) -> io::Result<[u8; L]> { + if Bytes::len(self) < L { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + let mut res = [0u8; L]; + self.split_to(L).copy_to_slice(&mut res); + Ok(res) + } +} + +impl AsyncStreamReader for BytesMut { + async fn read_bytes(&mut self, len: usize) -> io::Result { + let res = self.split_to(len.min(BytesMut::len(self))); + Ok(res.freeze()) + } + + async fn read(&mut self) -> io::Result<[u8; L]> { + if BytesMut::len(self) < L { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + let mut res = [0u8; L]; + self.split_to(L).copy_to_slice(&mut res); + Ok(res) + } } impl AsyncStreamReader for &[u8] { - async fn read(&mut self, len: usize) -> io::Result { + async fn read_bytes(&mut self, len: usize) -> io::Result { let len = len.min(self.len()); let res = Bytes::copy_from_slice(&self[..len]); *self = &self[len..]; Ok(res) } -} -impl AsyncStreamReader for BytesMut { - async fn read(&mut self, len: usize) -> io::Result { - let res = self.split_to(len.min(BytesMut::len(self))); - Ok(res.freeze()) + async fn read(&mut self) -> io::Result<[u8; L]> { + if self.len() < L { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + let mut res = [0u8; L]; + res.copy_from_slice(&self[..L]); + *self = &self[L..]; + Ok(res) } } impl AsyncStreamReader for Cursor { - async fn read(&mut self, len: usize) -> io::Result { + async fn read_bytes(&mut self, len: usize) -> io::Result { let offset = self.position(); let res = self.get_mut().read_at(offset, len).await?; self.set_position(offset + res.len() as u64); Ok(res) } + + async fn read(&mut self) -> io::Result<[u8; L]> { + let offset = self.position(); + let res = self.get_mut().read_at(offset, L).await?; + if res.len() < L { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + self.set_position(offset + res.len() as u64); + let mut buf = [0u8; L]; + buf.copy_from_slice(&res); + Ok(buf) + } } /// A non seekable writer, e.g. a network socket. @@ -304,10 +353,10 @@ where } } - async fn len(&mut self) -> io::Result { + async fn size(&mut self) -> io::Result { match self { - Self::Left(l) => l.len().await, - Self::Right(r) => r.len().await, + Self::Left(l) => l.size().await, + Self::Right(r) => r.size().await, } } } @@ -452,7 +501,7 @@ mod tests { let res = file.read_at(0, usize::MAX).await?; assert_eq!(res, expected); - let res = file.len().await?; + let res = file.size().await?; assert_eq!(res, 100); // read 3 bytes at offset 0 @@ -687,7 +736,7 @@ mod tests { current = offset.checked_add(len as u64).unwrap(); } ReadOp::Len => { - let len = AsyncSliceReader::len(&mut file).await?; + let len = AsyncSliceReader::size(&mut file).await?; assert_eq!(len, actual.len() as u64); } } @@ -717,7 +766,7 @@ mod tests { let url = reqwest::Url::parse(&url).unwrap(); let server = tokio::spawn(server); let mut reader = HttpAdapter::new(url); - let len = reader.len().await.unwrap(); + let len = reader.size().await.unwrap(); assert_eq!(len, 11); println!("len: {:?}", reader); let part = reader.read_at(0, 11).await.unwrap(); @@ -747,7 +796,9 @@ mod tests { #[test] fn bytes_read(data in proptest::collection::vec(any::(), 0..1024), ops in random_read_ops(1024, 1024, 2)) { - async_test(read_op_test(ops, Bytes::from(data.clone()), &data)).unwrap(); + async_test(read_op_test(ops.clone(), Bytes::from(data.clone()), &data)).unwrap(); + async_test(read_op_test(ops.clone(), BytesMut::from(data.as_slice()), &data)).unwrap(); + async_test(read_op_test(ops, data.as_slice(), &data)).unwrap(); } #[cfg(feature = "tokio-io")] diff --git a/src/mem.rs b/src/mem.rs index 90338f1..298a3c2 100644 --- a/src/mem.rs +++ b/src/mem.rs @@ -7,7 +7,7 @@ impl AsyncSliceReader for bytes::Bytes { Ok(get_limited_slice(self, offset, len)) } - async fn len(&mut self) -> io::Result { + async fn size(&mut self) -> io::Result { Ok(Bytes::len(self) as u64) } } @@ -17,11 +17,21 @@ impl AsyncSliceReader for bytes::BytesMut { Ok(copy_limited_slice(self, offset, len)) } - async fn len(&mut self) -> io::Result { + async fn size(&mut self) -> io::Result { Ok(BytesMut::len(self) as u64) } } +impl AsyncSliceReader for &[u8] { + async fn read_at(&mut self, offset: u64, len: usize) -> io::Result { + Ok(copy_limited_slice(self, offset, len)) + } + + async fn size(&mut self) -> io::Result { + Ok(self.len() as u64) + } +} + impl AsyncSliceWriter for bytes::BytesMut { async fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> io::Result<()> { write_extend(self, offset, &data) diff --git a/src/stats.rs b/src/stats.rs index 8fec29e..e0147de 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -222,8 +222,12 @@ impl TrackingStreamReader { } impl AsyncStreamReader for TrackingStreamReader { - async fn read(&mut self, len: usize) -> io::Result { - AggregateSizeAndStats::new(self.inner.read(len), &mut self.stats.read).await + async fn read_bytes(&mut self, len: usize) -> io::Result { + AggregateSizeAndStats::new(self.inner.read_bytes(len), &mut self.stats.read).await + } + + async fn read(&mut self) -> io::Result<[u8; L]> { + AggregateSizeAndStats::new(self.inner.read(), &mut self.stats.read).await } } @@ -288,8 +292,8 @@ impl AsyncSliceReader for TrackingSliceReader { AggregateSizeAndStats::new(self.inner.read_at(offset, len), &mut self.stats.read_at).await } - async fn len(&mut self) -> io::Result { - AggregateStats::new(self.inner.len(), &mut self.stats.len).await + async fn size(&mut self) -> io::Result { + AggregateStats::new(self.inner.size(), &mut self.stats.len).await } } @@ -507,8 +511,8 @@ mod tests { #[tokio::test] async fn tracking_stream_reader() { let mut writer = TrackingStreamReader::new(Bytes::from(vec![0, 1, 2, 3])); - writer.read(2).await.unwrap(); - writer.read(3).await.unwrap(); + writer.read_bytes(2).await.unwrap(); + writer.read_bytes(3).await.unwrap(); assert_eq!(writer.stats().read.size, 4); // not 5, because the last read was only 2 bytes assert_eq!(writer.stats().read.stats.count, 2); } @@ -537,7 +541,7 @@ mod tests { let mut reader = TrackingSliceReader::new(Bytes::from(vec![1u8, 2, 3])); let _ = reader.read_at(0, 1).await.unwrap(); let _ = reader.read_at(10, 1).await.unwrap(); - let _ = reader.len().await.unwrap(); + let _ = reader.size().await.unwrap(); assert_eq!(reader.stats().read_at.size, 1); assert_eq!(reader.stats().read_at.stats.count, 2); assert_eq!(reader.stats().len.count, 1); diff --git a/src/tokio_io.rs b/src/tokio_io.rs index c9e2d9d..7d47032 100644 --- a/src/tokio_io.rs +++ b/src/tokio_io.rs @@ -67,7 +67,7 @@ pub mod file { Asyncify::from(self.0.take().map(|t| (t.read_at(offset, len), &mut self.0))).await } - async fn len(&mut self) -> io::Result { + async fn size(&mut self) -> io::Result { Asyncify::from(self.0.take().map(|t| (t.len(), &mut self.0))).await } } @@ -289,12 +289,30 @@ impl AsyncStreamWriter for TokioStreamWriter(T); +pub struct TokioStreamReader(pub T); + +impl TokioStreamReader { + /// Create a new `TokioStreamReader` from an inner reader + pub fn new(inner: T) -> Self { + Self(inner) + } + + /// Return the inner reader + pub fn into_inner(self) -> T { + self.0 + } +} impl AsyncStreamReader for TokioStreamReader { - async fn read(&mut self, len: usize) -> io::Result { + async fn read_bytes(&mut self, len: usize) -> io::Result { let mut buf = Vec::with_capacity(len.min(MAX_PREALLOC)); (&mut self.0).take(len as u64).read_to_end(&mut buf).await?; Ok(buf.into()) } + + async fn read(&mut self) -> io::Result<[u8; L]> { + let mut buf = [0; L]; + self.0.read_exact(&mut buf).await?; + Ok(buf) + } }