Skip to content

Commit

Permalink
Read to uninitialized buffer
Browse files Browse the repository at this point in the history
Currently, when using an uninitialized buffer for `read`, as would be
typical in C or required for `Read::read_buf`, it is casted from `*mut
u8` at the FFI boundary in `sys_read`/`sys_readv` to a `&[u8]`. I think
this is unsound.

Instead, use `&[MaybeUninit<u8>]` internally. I use this instead of
`core::io::BorrowedCursor<'_>`, because there are currently no cases
where the initialized portion needs to be separately tracked.

This enables implementing `std::io::Read::read_buf` for `std::fs::File`
and `std::io::Stdin` on Hermit. That effort is tracked in
rust-lang/rust#136756.
  • Loading branch information
thaliaarchi committed Feb 19, 2025
1 parent 9123cda commit 9951a8d
Show file tree
Hide file tree
Showing 12 changed files with 39 additions and 29 deletions.
9 changes: 4 additions & 5 deletions src/fd/eventfd.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use alloc::boxed::Box;
use alloc::collections::vec_deque::VecDeque;
use core::future::{self, Future};
use core::mem;
use core::mem::{self, MaybeUninit};
use core::task::{Poll, Waker, ready};

use async_lock::Mutex;
Expand Down Expand Up @@ -45,7 +45,7 @@ impl EventFd {

#[async_trait]
impl ObjectInterface for EventFd {
async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let len = mem::size_of::<u64>();

if buf.len() < len {
Expand All @@ -58,8 +58,7 @@ impl ObjectInterface for EventFd {
let mut guard = ready!(pinned.as_mut().poll(cx));
if guard.counter > 0 {
guard.counter -= 1;
let tmp = u64::to_ne_bytes(1);
buf[..len].copy_from_slice(&tmp);
buf[..len].write_copy_of_slice(&u64::to_ne_bytes(1));
if let Some(cx) = guard.write_queue.pop_front() {
cx.wake_by_ref();
}
Expand All @@ -74,7 +73,7 @@ impl ObjectInterface for EventFd {
let tmp = guard.counter;
if tmp > 0 {
guard.counter = 0;
buf[..len].copy_from_slice(&u64::to_ne_bytes(tmp));
buf[..len].write_copy_of_slice(&u64::to_ne_bytes(tmp));
if let Some(cx) = guard.read_queue.pop_front() {
cx.wake_by_ref();
}
Expand Down
5 changes: 3 additions & 2 deletions src/fd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::future::{self, Future};
use core::mem::MaybeUninit;
use core::task::Poll::{Pending, Ready};
use core::time::Duration;

Expand Down Expand Up @@ -152,7 +153,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {

/// `async_read` attempts to read `len` bytes from the object references
/// by the descriptor
async fn read(&self, _buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, _buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
Err(io::Error::ENOSYS)
}

Expand Down Expand Up @@ -264,7 +265,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {
}
}

pub(crate) fn read(fd: FileDescriptor, buf: &mut [u8]) -> io::Result<usize> {
pub(crate) fn read(fd: FileDescriptor, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let obj = get_object(fd)?;

if buf.is_empty() {
Expand Down
7 changes: 4 additions & 3 deletions src/fd/socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use alloc::boxed::Box;
use alloc::collections::BTreeSet;
use alloc::sync::Arc;
use core::future;
use core::mem::MaybeUninit;
use core::sync::atomic::{AtomicU16, Ordering};
use core::task::Poll;

Expand Down Expand Up @@ -171,7 +172,7 @@ impl Socket {
.await
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
future::poll_fn(|cx| {
self.with(|socket| {
let state = socket.state();
Expand All @@ -187,7 +188,7 @@ impl Socket {
socket
.recv(|data| {
let len = core::cmp::min(buffer.len(), data.len());
buffer[..len].copy_from_slice(&data[..len]);
buffer[..len].write_copy_of_slice(&data[..len]);
(len, len)
})
.map_err(|_| io::Error::EIO),
Expand Down Expand Up @@ -468,7 +469,7 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
self.read().await.poll(event).await
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.read().await.read(buffer).await
}

Expand Down
9 changes: 5 additions & 4 deletions src/fd/socket/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::future;
use core::mem::MaybeUninit;
use core::task::Poll;

use async_trait::async_trait;
Expand Down Expand Up @@ -312,7 +313,7 @@ impl Socket {
}
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let port = self.port;
future::poll_fn(|cx| {
let mut guard = VSOCK_MAP.lock();
Expand All @@ -331,7 +332,7 @@ impl Socket {
}
} else {
let tmp: Vec<_> = raw.buffer.drain(..len).collect();
buffer[..len].copy_from_slice(tmp.as_slice());
buffer[..len].write_copy_of_slice(tmp.as_slice());

Poll::Ready(Ok(len))
}
Expand All @@ -343,7 +344,7 @@ impl Socket {
Poll::Ready(Ok(0))
} else {
let tmp: Vec<_> = raw.buffer.drain(..len).collect();
buffer[..len].copy_from_slice(tmp.as_slice());
buffer[..len].write_copy_of_slice(tmp.as_slice());

Poll::Ready(Ok(len))
}
Expand Down Expand Up @@ -424,7 +425,7 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
self.read().await.poll(event).await
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.read().await.read(buffer).await
}

Expand Down
5 changes: 3 additions & 2 deletions src/fd/stdio.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use alloc::boxed::Box;
use core::future;
use core::mem::MaybeUninit;
use core::task::Poll;

use async_trait::async_trait;
Expand Down Expand Up @@ -27,7 +28,7 @@ impl ObjectInterface for GenericStdin {
Ok(event & available)
}

async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
future::poll_fn(|cx| {
let mut read_bytes = 0;
let mut guard = CONSOLE.lock();
Expand All @@ -36,7 +37,7 @@ impl ObjectInterface for GenericStdin {
let c = unsafe { char::from_u32_unchecked(byte.into()) };
guard.write(c.as_bytes());

buf[read_bytes] = byte;
buf[read_bytes].write(byte);
read_bytes += 1;

if read_bytes >= buf.len() {
Expand Down
7 changes: 4 additions & 3 deletions src/fs/fuse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use alloc::ffi::CString;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::mem::MaybeUninit;
use core::sync::atomic::{AtomicU64, Ordering};
use core::task::Poll;
use core::{future, mem};
Expand Down Expand Up @@ -629,7 +630,7 @@ impl FuseFileHandleInner {
.await
}

fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
fn read(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let mut len = buf.len();
if len > MAX_READ_LEN {
debug!("Reading longer than max_read_len: {}", len);
Expand All @@ -651,7 +652,7 @@ impl FuseFileHandleInner {
};
self.offset += len;

buf[..len].copy_from_slice(&rsp.payload.unwrap()[..len]);
buf[..len].write_copy_of_slice(&rsp.payload.unwrap()[..len]);

Ok(len)
} else {
Expand Down Expand Up @@ -767,7 +768,7 @@ impl ObjectInterface for FuseFileHandle {
self.0.lock().await.poll(event).await
}

async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.0.lock().await.read(buf)
}

Expand Down
11 changes: 6 additions & 5 deletions src/fs/mem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use alloc::collections::BTreeMap;
use alloc::string::{String, ToString};
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::mem::MaybeUninit;

use async_lock::{Mutex, RwLock};
use async_trait::async_trait;
Expand Down Expand Up @@ -59,7 +60,7 @@ impl ObjectInterface for RomFileInterface {
Ok(ret)
}

async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
{
let microseconds = arch::kernel::systemtime::now_micros();
let t = timespec::from_usec(microseconds as i64);
Expand All @@ -81,7 +82,7 @@ impl ObjectInterface for RomFileInterface {
buf.len()
};

buf[0..len].clone_from_slice(&vec[pos..pos + len]);
buf[..len].write_copy_of_slice(&vec[pos..pos + len]);
*pos_guard = pos + len;

Ok(len)
Expand Down Expand Up @@ -170,7 +171,7 @@ impl ObjectInterface for RamFileInterface {
Ok(event & available)
}

async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
{
let microseconds = arch::kernel::systemtime::now_micros();
let t = timespec::from_usec(microseconds as i64);
Expand All @@ -192,7 +193,7 @@ impl ObjectInterface for RamFileInterface {
buf.len()
};

buf[0..len].clone_from_slice(&guard.data[pos..pos + len]);
buf[..len].write_copy_of_slice(&guard.data[pos..pos + len]);
*pos_guard = pos + len;

Ok(len)
Expand All @@ -214,7 +215,7 @@ impl ObjectInterface for RamFileInterface {
guard.attr.st_mtim = t;
guard.attr.st_ctim = t;

guard.data[pos..pos + buf.len()].clone_from_slice(buf);
guard.data[pos..pos + buf.len()].copy_from_slice(buf);
*pos_guard = pos + buf.len();

Ok(buf.len())
Expand Down
1 change: 1 addition & 0 deletions src/fs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ impl File {

impl crate::io::Read for File {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let buf = unsafe { core::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
fd::read(self.fd, buf)
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/fs/uhyve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use alloc::ffi::CString;
use alloc::string::{String, ToString};
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::mem::MaybeUninit;

use async_lock::Mutex;
use async_trait::async_trait;
Expand All @@ -29,7 +30,7 @@ impl UhyveFileHandleInner {
Self(fd)
}

fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
fn read(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let mut read_params = ReadParams {
fd: self.0,
buf: GuestVirtAddr::new(buf.as_mut_ptr() as u64),
Expand Down Expand Up @@ -94,7 +95,7 @@ impl UhyveFileHandle {

#[async_trait]
impl ObjectInterface for UhyveFileHandle {
async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.0.lock().await.read(buf)
}

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#![feature(map_try_insert)]
#![feature(maybe_uninit_as_bytes)]
#![feature(maybe_uninit_slice)]
#![feature(maybe_uninit_write_slice)]
#![feature(naked_functions)]
#![feature(never_type)]
#![feature(slice_from_ptr_range)]
Expand Down
6 changes: 4 additions & 2 deletions src/syscalls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ pub extern "C" fn sys_close(fd: FileDescriptor) -> i32 {
#[hermit_macro::system]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sys_read(fd: FileDescriptor, buf: *mut u8, len: usize) -> isize {
let slice = unsafe { core::slice::from_raw_parts_mut(buf, len) };
let slice = unsafe { core::slice::from_raw_parts_mut(buf.cast(), len) };
crate::fd::read(fd, slice).map_or_else(
|e| -num::ToPrimitive::to_isize(&e).unwrap(),
|v| v.try_into().unwrap(),
Expand Down Expand Up @@ -420,7 +420,9 @@ pub unsafe extern "C" fn sys_readv(fd: i32, iov: *const iovec, iovcnt: usize) ->
let iovec_buffers = unsafe { core::slice::from_raw_parts(iov, iovcnt) };

for iovec_buf in iovec_buffers {
let buf = unsafe { core::slice::from_raw_parts_mut(iovec_buf.iov_base, iovec_buf.iov_len) };
let buf = unsafe {
core::slice::from_raw_parts_mut(iovec_buf.iov_base.cast(), iovec_buf.iov_len)
};

let len = crate::fd::read(fd, buf).map_or_else(
|e| -num::ToPrimitive::to_isize(&e).unwrap(),
Expand Down
2 changes: 1 addition & 1 deletion src/syscalls/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ pub extern "C" fn sys_shutdown_socket(fd: i32, how: i32) -> i32 {
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sys_recv(fd: i32, buf: *mut u8, len: usize, flags: i32) -> isize {
if flags == 0 {
let slice = unsafe { core::slice::from_raw_parts_mut(buf, len) };
let slice = unsafe { core::slice::from_raw_parts_mut(buf.cast(), len) };
crate::fd::read(fd, slice).map_or_else(
|e| -num::ToPrimitive::to_isize(&e).unwrap(),
|v| v.try_into().unwrap(),
Expand Down

0 comments on commit 9951a8d

Please sign in to comment.