From a021be7e50207d7415db47477881a1464851e307 Mon Sep 17 00:00:00 2001 From: Noah Kennedy Date: Tue, 21 Feb 2023 14:29:43 -0600 Subject: [PATCH] rt: add new api for oneshot op submission and creation (#244) This change adds a new API for creating and submitting oneshot operations. This new API is intended to obsolete the existing system, however transferring over has been left out of this PR to keep the change small. It is intended that a similar API for multishot operations be landed in a followup PR as well. This was also left out of this PR to keep the change small. This refactoring paves the way for further work on SQE linking, multishot operations, and other improvements and additions. The goal of this change is to allow us to split up the oneshot and multishot logic to allow for a cleaner and more extensible system, allow for user-provided operations, and allow users to control when and how their ops get submitted to the squeue. --- examples/mix.rs | 2 +- examples/tcp_stream.rs | 2 +- examples/unix_listener.rs | 2 +- examples/unix_stream.rs | 2 +- examples/wrk-bench.rs | 2 +- src/fs/file.rs | 16 ++--- src/io/mod.rs | 2 +- src/io/socket.rs | 9 +-- src/io/write.rs | 79 ++++++++++++---------- src/lib.rs | 2 + src/net/tcp/listener.rs | 2 +- src/net/tcp/stream.rs | 7 +- src/net/udp.rs | 7 +- src/net/unix/listener.rs | 2 +- src/net/unix/stream.rs | 7 +- src/runtime/driver/handle.rs | 14 +++- src/runtime/driver/mod.rs | 127 ++++++++++++++++++++++++++--------- src/runtime/driver/op/mod.rs | 124 +++++++++++++++++++++++++++++++--- tests/driver.rs | 6 +- tests/fs_file.rs | 6 +- 20 files changed, 312 insertions(+), 108 deletions(-) diff --git a/examples/mix.rs b/examples/mix.rs index e55a5247..4e094019 100644 --- a/examples/mix.rs +++ b/examples/mix.rs @@ -41,7 +41,7 @@ fn main() { break; } - let (res, b) = socket.write(b).await; + let (res, b) = socket.write(b).submit().await; pos += res.unwrap() as u64; buf = b; diff --git a/examples/tcp_stream.rs b/examples/tcp_stream.rs index 7c56057f..4983ee4c 100644 --- a/examples/tcp_stream.rs +++ b/examples/tcp_stream.rs @@ -15,7 +15,7 @@ fn main() { let stream = TcpStream::connect(socket_addr).await.unwrap(); let buf = vec![1u8; 128]; - let (result, buf) = stream.write(buf).await; + let (result, buf) = stream.write(buf).submit().await; println!("written: {}", result.unwrap()); let (result, buf) = stream.read(buf).await; diff --git a/examples/unix_listener.rs b/examples/unix_listener.rs index e3916070..9e10496d 100644 --- a/examples/unix_listener.rs +++ b/examples/unix_listener.rs @@ -20,7 +20,7 @@ fn main() { tokio_uring::spawn(async move { let buf = vec![1u8; 128]; - let (result, buf) = stream.write(buf).await; + let (result, buf) = stream.write(buf).submit().await; println!("written to {}: {}", &socket_addr, result.unwrap()); let (result, buf) = stream.read(buf).await; diff --git a/examples/unix_stream.rs b/examples/unix_stream.rs index 5e48951a..7caf06f9 100644 --- a/examples/unix_stream.rs +++ b/examples/unix_stream.rs @@ -15,7 +15,7 @@ fn main() { let stream = UnixStream::connect(socket_addr).await.unwrap(); let buf = vec![1u8; 128]; - let (result, buf) = stream.write(buf).await; + let (result, buf) = stream.write(buf).submit().await; println!("written: {}", result.unwrap()); let (result, buf) = stream.read(buf).await; diff --git a/examples/wrk-bench.rs b/examples/wrk-bench.rs index 4a76ed62..222df76a 100644 --- a/examples/wrk-bench.rs +++ b/examples/wrk-bench.rs @@ -21,7 +21,7 @@ fn main() -> io::Result<()> { let (stream, _) = listener.accept().await?; tokio_uring::spawn(async move { - let (result, _) = stream.write(RESPONSE).await; + let (result, _) = stream.write(RESPONSE).submit().await; if let Err(err) = result { eprintln!("Client connection failed: {}", err); diff --git a/src/fs/file.rs b/src/fs/file.rs index ca7d7a2e..9cd47f21 100644 --- a/src/fs/file.rs +++ b/src/fs/file.rs @@ -4,6 +4,7 @@ use crate::fs::OpenOptions; use crate::io::SharedFd; use crate::runtime::driver::op::Op; +use crate::{UnsubmittedOneshot, UnsubmittedWrite}; use std::fmt; use std::io; use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; @@ -38,7 +39,7 @@ use std::path::Path; /// let file = File::create("hello.txt").await?; /// /// // Write some data -/// let (res, buf) = file.write_at(&b"hello world"[..], 0).await; +/// let (res, buf) = file.write_at(&b"hello world"[..], 0).submit().await; /// let n = res?; /// /// println!("wrote {} bytes", n); @@ -525,7 +526,7 @@ impl File { /// let file = File::create("foo.txt").await?; /// /// // Writes some prefix of the byte string, not necessarily all of it. - /// let (res, _) = file.write_at(&b"some bytes"[..], 0).await; + /// let (res, _) = file.write_at(&b"some bytes"[..], 0).submit().await; /// let n = res?; /// /// println!("wrote {} bytes", n); @@ -538,9 +539,8 @@ impl File { /// ``` /// /// [`Ok(n)`]: Ok - pub async fn write_at(&self, buf: T, pos: u64) -> crate::BufResult { - let op = Op::write_at(&self.fd, buf, pos).unwrap(); - op.await + pub fn write_at(&self, buf: T, pos: u64) -> UnsubmittedWrite { + UnsubmittedOneshot::write_at(&self.fd, buf, pos) } /// Attempts to write an entire buffer into this file at the specified offset. @@ -609,7 +609,7 @@ impl File { } while buf.bytes_init() != 0 { - let (res, slice) = self.write_at(buf, pos).await; + let (res, slice) = self.write_at(buf, pos).submit().await; match res { Ok(0) => { return ( @@ -773,7 +773,7 @@ impl File { /// fn main() -> Result<(), Box> { /// tokio_uring::start(async { /// let f = File::create("foo.txt").await?; - /// let (res, buf) = f.write_at(&b"Hello, world!"[..], 0).await; + /// let (res, buf) = f.write_at(&b"Hello, world!"[..], 0).submit().await; /// let n = res?; /// /// f.sync_all().await?; @@ -810,7 +810,7 @@ impl File { /// fn main() -> Result<(), Box> { /// tokio_uring::start(async { /// let f = File::create("foo.txt").await?; - /// let (res, buf) = f.write_at(&b"Hello, world!"[..], 0).await; + /// let (res, buf) = f.write_at(&b"Hello, world!"[..], 0).submit().await; /// let n = res?; /// /// f.sync_data().await?; diff --git a/src/io/mod.rs b/src/io/mod.rs index ae1242be..578c3418 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -44,7 +44,7 @@ mod unlink_at; mod util; pub(crate) use util::cstr; -mod write; +pub(crate) mod write; mod write_fixed; diff --git a/src/io/socket.rs b/src/io/socket.rs index ff183ac2..0b467ff6 100644 --- a/src/io/socket.rs +++ b/src/io/socket.rs @@ -1,8 +1,10 @@ +use crate::io::write::UnsubmittedWrite; use crate::runtime::driver::op::Op; use crate::{ buf::fixed::FixedBuf, buf::{BoundedBuf, BoundedBufMut, IoBuf, Slice}, io::SharedFd, + UnsubmittedOneshot, }; use std::{ io, @@ -41,9 +43,8 @@ impl Socket { Ok(Socket { fd }) } - pub(crate) async fn write(&self, buf: T) -> crate::BufResult { - let op = Op::write_at(&self.fd, buf, 0).unwrap(); - op.await + pub(crate) fn write(&self, buf: T) -> UnsubmittedWrite { + UnsubmittedOneshot::write_at(&self.fd, buf, 0) } pub async fn write_all(&self, buf: T) -> crate::BufResult<(), T> { @@ -54,7 +55,7 @@ impl Socket { async fn write_all_slice(&self, mut buf: Slice) -> crate::BufResult<(), T> { while buf.bytes_init() != 0 { - let res = self.write(buf).await; + let res = self.write(buf).submit().await; match res { (Ok(0), slice) => { return ( diff --git a/src/io/write.rs b/src/io/write.rs index 9775f4fe..ddb0408e 100644 --- a/src/io/write.rs +++ b/src/io/write.rs @@ -1,50 +1,59 @@ -use crate::runtime::driver::op::{Completable, CqeResult, Op}; -use crate::runtime::CONTEXT; -use crate::{buf::BoundedBuf, io::SharedFd, BufResult}; +use crate::{buf::BoundedBuf, io::SharedFd, BufResult, OneshotOutputTransform, UnsubmittedOneshot}; +use io_uring::cqueue::Entry; use std::io; +use std::marker::PhantomData; -pub(crate) struct Write { +/// An unsubmitted write operation. +pub type UnsubmittedWrite = UnsubmittedOneshot, WriteTransform>; + +#[allow(missing_docs)] +pub struct WriteData { /// Holds a strong ref to the FD, preventing the file from being closed /// while the operation is in-flight. - #[allow(dead_code)] - fd: SharedFd, + _fd: SharedFd, buf: T, } -impl Op> { - pub(crate) fn write_at(fd: &SharedFd, buf: T, offset: u64) -> io::Result>> { - use io_uring::{opcode, types}; - - CONTEXT.with(|x| { - x.handle().expect("Not in a runtime context").submit_op( - Write { - fd: fd.clone(), - buf, - }, - |write| { - // Get raw buffer info - let ptr = write.buf.stable_ptr(); - let len = write.buf.bytes_init(); - - opcode::Write::new(types::Fd(fd.raw_fd()), ptr, len as _) - .offset(offset as _) - .build() - }, - ) - }) - } +#[allow(missing_docs)] +pub struct WriteTransform { + _phantom: PhantomData, } -impl Completable for Write { +impl OneshotOutputTransform for WriteTransform { type Output = BufResult; + type StoredData = WriteData; + + fn transform_oneshot_output(self, data: Self::StoredData, cqe: Entry) -> Self::Output { + let res = if cqe.result() >= 0 { + Ok(cqe.result() as usize) + } else { + Err(io::Error::from_raw_os_error(cqe.result())) + }; - fn complete(self, cqe: CqeResult) -> Self::Output { - // Convert the operation result to `usize` - let res = cqe.result.map(|v| v as usize); - // Recover the buffer - let buf = self.buf; + (res, data.buf) + } +} + +impl UnsubmittedWrite { + pub(crate) fn write_at(fd: &SharedFd, buf: T, offset: u64) -> Self { + use io_uring::{opcode, types}; - (res, buf) + // Get raw buffer info + let ptr = buf.stable_ptr(); + let len = buf.bytes_init(); + + Self::new( + WriteData { + _fd: fd.clone(), + buf, + }, + WriteTransform { + _phantom: PhantomData::default(), + }, + opcode::Write::new(types::Fd(fd.raw_fd()), ptr, len as _) + .offset(offset as _) + .build(), + ) } } diff --git a/src/lib.rs b/src/lib.rs index 39348138..d1cc6e02 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,6 +78,8 @@ pub mod buf; pub mod fs; pub mod net; +pub use io::write::*; +pub use runtime::driver::op::{InFlightOneshot, OneshotOutputTransform, UnsubmittedOneshot}; pub use runtime::spawn; pub use runtime::Runtime; diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs index 365373d6..98ca8fdd 100644 --- a/src/net/tcp/listener.rs +++ b/src/net/tcp/listener.rs @@ -29,7 +29,7 @@ use std::{io, net::SocketAddr}; /// let tx = TcpStream::connect("127.0.0.1:2345".parse().unwrap()).await.unwrap(); /// let rx = rx_ch.await.expect("The spawned task expected to send a TcpStream"); /// -/// tx.write(b"test" as &'static [u8]).await.0.unwrap(); +/// tx.write(b"test" as &'static [u8]).submit().await.0.unwrap(); /// /// let (_, buf) = rx.read(vec![0; 4]).await; /// diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index bc81bc8e..2450dcb9 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -8,6 +8,7 @@ use crate::{ buf::fixed::FixedBuf, buf::{BoundedBuf, BoundedBufMut}, io::{SharedFd, Socket}, + UnsubmittedWrite, }; /// A TCP stream between a local and a remote socket. @@ -27,7 +28,7 @@ use crate::{ /// let mut stream = TcpStream::connect("127.0.0.1:8080".parse().unwrap()).await?; /// /// // Write some data. -/// let (result, _) = stream.write(b"hello world!".as_slice()).await; +/// let (result, _) = stream.write(b"hello world!".as_slice()).submit().await; /// result.unwrap(); /// /// Ok(()) @@ -100,8 +101,8 @@ impl TcpStream { /// Write some data to the stream from the buffer. /// /// Returns the original buffer and quantity of data written. - pub async fn write(&self, buf: T) -> crate::BufResult { - self.inner.write(buf).await + pub fn write(&self, buf: T) -> UnsubmittedWrite { + self.inner.write(buf) } /// Attempts to write an entire buffer to the stream. diff --git a/src/net/udp.rs b/src/net/udp.rs index 13510a1b..42ba2456 100644 --- a/src/net/udp.rs +++ b/src/net/udp.rs @@ -2,6 +2,7 @@ use crate::{ buf::fixed::FixedBuf, buf::{BoundedBuf, BoundedBufMut}, io::{SharedFd, Socket}, + UnsubmittedWrite, }; use socket2::SockAddr; use std::{ @@ -42,7 +43,7 @@ use std::{ /// let buf = vec![0; 32]; /// /// // write data -/// let (result, _) = socket.write(b"hello world".as_slice()).await; +/// let (result, _) = socket.write(b"hello world".as_slice()).submit().await; /// result.unwrap(); /// /// // read data @@ -312,8 +313,8 @@ impl UdpSocket { /// Writes data into the socket from the specified buffer. /// /// Returns the original buffer and quantity of data written. - pub async fn write(&self, buf: T) -> crate::BufResult { - self.inner.write(buf).await + pub fn write(&self, buf: T) -> UnsubmittedWrite { + self.inner.write(buf) } /// Writes data into the socket from a registered buffer. diff --git a/src/net/unix/listener.rs b/src/net/unix/listener.rs index 70ed6089..ffabb5d2 100644 --- a/src/net/unix/listener.rs +++ b/src/net/unix/listener.rs @@ -30,7 +30,7 @@ use std::{io, path::Path}; /// let tx = UnixStream::connect(&sock_file).await.unwrap(); /// let rx = rx_ch.await.expect("The spawned task expected to send a UnixStream"); /// -/// tx.write(b"test" as &'static [u8]).await.0.unwrap(); +/// tx.write(b"test" as &'static [u8]).submit().await.0.unwrap(); /// /// let (_, buf) = rx.read(vec![0; 4]).await; /// diff --git a/src/net/unix/stream.rs b/src/net/unix/stream.rs index e93a6904..40e7ddc5 100644 --- a/src/net/unix/stream.rs +++ b/src/net/unix/stream.rs @@ -2,6 +2,7 @@ use crate::{ buf::fixed::FixedBuf, buf::{BoundedBuf, BoundedBufMut}, io::{SharedFd, Socket}, + UnsubmittedWrite, }; use socket2::SockAddr; use std::{ @@ -27,7 +28,7 @@ use std::{ /// let mut stream = UnixStream::connect("/tmp/tokio-uring-unix-test.sock").await?; /// /// // Write some data. -/// let (result, _) = stream.write(b"hello world!".as_slice()).await; +/// let (result, _) = stream.write(b"hello world!".as_slice()).submit().await; /// result.unwrap(); /// /// Ok(()) @@ -98,8 +99,8 @@ impl UnixStream { /// Write some data to the stream from the buffer, returning the original buffer and /// quantity of data written. - pub async fn write(&self, buf: T) -> crate::BufResult { - self.inner.write(buf).await + pub fn write(&self, buf: T) -> UnsubmittedWrite { + self.inner.write(buf) } /// Attempts to write an entire buffer to the stream. diff --git a/src/runtime/driver/handle.rs b/src/runtime/driver/handle.rs index ab3dcc51..115f780d 100644 --- a/src/runtime/driver/handle.rs +++ b/src/runtime/driver/handle.rs @@ -12,7 +12,7 @@ //! The weak handle should be used by anything which is stored in the driver or does not need to //! keep the driver alive for it's duration. -use io_uring::squeue; +use io_uring::{cqueue, squeue}; use std::cell::RefCell; use std::io; use std::ops::Deref; @@ -63,6 +63,10 @@ impl Handle { self.inner.borrow_mut().unregister_buffers(buffers) } + pub(crate) fn submit_op_2(&self, sqe: squeue::Entry) -> usize { + self.inner.borrow_mut().submit_op_2(sqe) + } + pub(crate) fn submit_op(&self, data: T, f: F) -> io::Result> where T: Completable, @@ -78,6 +82,10 @@ impl Handle { self.inner.borrow_mut().poll_op(op, cx) } + pub(crate) fn poll_op_2(&self, index: usize, cx: &mut Context<'_>) -> Poll { + self.inner.borrow_mut().poll_op_2(index, cx) + } + pub(crate) fn poll_multishot_op( &self, op: &mut Op, @@ -92,6 +100,10 @@ impl Handle { pub(crate) fn remove_op(&self, op: &mut Op) { self.inner.borrow_mut().remove_op(op) } + + pub(crate) fn remove_op_2(&self, index: usize, data: T) { + self.inner.borrow_mut().remove_op_2(index, data) + } } impl WeakHandle { diff --git a/src/runtime/driver/mod.rs b/src/runtime/driver/mod.rs index ab80624b..21d7de0b 100644 --- a/src/runtime/driver/mod.rs +++ b/src/runtime/driver/mod.rs @@ -4,10 +4,10 @@ use io_uring::opcode::AsyncCancel; use io_uring::{cqueue, squeue, IoUring}; use slab::Slab; use std::cell::RefCell; -use std::io; use std::os::unix::io::{AsRawFd, RawFd}; use std::rc::Rc; use std::task::{Context, Poll}; +use std::{io, mem}; pub(crate) use handle::*; @@ -89,7 +89,7 @@ impl Driver { let index = cqe.user_data() as _; - self.ops.complete(index, cqe.into()); + self.ops.complete(index, cqe); } } @@ -122,6 +122,21 @@ impl Driver { )) } + pub(crate) fn submit_op_2(&mut self, sqe: squeue::Entry) -> usize { + let index = self.ops.insert(); + + // Configure the SQE + let sqe = sqe.user_data(index as _); + + // Push the new operation + while unsafe { self.uring.submission().push(&sqe).is_err() } { + // If the submission queue is full, flush it to the kernel + self.submit().expect("Internal error, failed to submit ops"); + } + + index + } + pub(crate) fn submit_op( &mut self, mut data: T, @@ -150,8 +165,6 @@ impl Driver { } pub(crate) fn remove_op(&mut self, op: &mut Op) { - use std::mem; - // Get the Op Lifecycle state from the driver let (lifecycle, completions) = match self.ops.get_mut(op.index()) { Some(val) => val, @@ -186,12 +199,72 @@ impl Driver { } } + pub(crate) fn remove_op_2(&mut self, index: usize, data: T) { + // Get the Op Lifecycle state from the driver + let (lifecycle, completions) = match self.ops.get_mut(index) { + Some(val) => val, + None => { + // Op dropped after the driver + return; + } + }; + + match mem::replace(lifecycle, Lifecycle::Submitted) { + Lifecycle::Submitted | Lifecycle::Waiting(_) => { + *lifecycle = Lifecycle::Ignored(Box::new(data)); + } + Lifecycle::Completed(..) => { + self.ops.remove(index); + } + Lifecycle::CompletionList(indices) => { + // Deallocate list entries, recording if more CQE's are expected + let more = { + let mut list = indices.into_list(completions); + cqueue::more(list.peek_end().unwrap().flags) + // Dropping list deallocates the list entries + }; + if more { + // If more are expected, we have to keep the op around + *lifecycle = Lifecycle::Ignored(Box::new(data)); + } else { + self.ops.remove(index); + } + } + Lifecycle::Ignored(..) => unreachable!(), + } + } + + pub(crate) fn poll_op_2(&mut self, index: usize, cx: &mut Context<'_>) -> Poll { + let (lifecycle, _) = self.ops.get_mut(index).expect("invalid internal state"); + + match mem::replace(lifecycle, Lifecycle::Submitted) { + Lifecycle::Submitted => { + *lifecycle = Lifecycle::Waiting(cx.waker().clone()); + Poll::Pending + } + Lifecycle::Waiting(waker) if !waker.will_wake(cx.waker()) => { + *lifecycle = Lifecycle::Waiting(cx.waker().clone()); + Poll::Pending + } + Lifecycle::Waiting(waker) => { + *lifecycle = Lifecycle::Waiting(waker); + Poll::Pending + } + Lifecycle::Ignored(..) => unreachable!(), + Lifecycle::Completed(cqe) => { + self.ops.remove(index); + Poll::Ready(cqe) + } + Lifecycle::CompletionList(..) => { + unreachable!("No `more` flag set for SingleCQE") + } + } + } + pub(crate) fn poll_op(&mut self, op: &mut Op, cx: &mut Context<'_>) -> Poll where T: Unpin + 'static + Completable, { - use std::mem; - let (lifecycle, _) = self .ops .get_mut(op.index()) @@ -213,7 +286,7 @@ impl Driver { Lifecycle::Ignored(..) => unreachable!(), Lifecycle::Completed(cqe) => { self.ops.remove(op.index()); - Poll::Ready(op.take_data().unwrap().complete(cqe)) + Poll::Ready(op.take_data().unwrap().complete(cqe.into())) } Lifecycle::CompletionList(..) => { unreachable!("No `more` flag set for SingleCQE") @@ -229,8 +302,6 @@ impl Driver { where T: Unpin + 'static + Completable + Updateable, { - use std::mem; - let (lifecycle, completions) = self .ops .get_mut(op.index()) @@ -254,7 +325,7 @@ impl Driver { // This is possible. We may have previously polled a CompletionList, // and the final CQE registered as Completed self.ops.remove(op.index()); - Poll::Ready(op.take_data().unwrap().complete(cqe)) + Poll::Ready(op.take_data().unwrap().complete(cqe.into())) } Lifecycle::CompletionList(indices) => { let mut data = op.take_data().unwrap(); @@ -322,10 +393,9 @@ impl Drop for Driver { let mut list = indices.clone().into_list(&mut self.ops.completions); if !io_uring::cqueue::more(list.peek_end().unwrap().flags) { // This op is complete. Replace with a null Completed entry - *cycle = Lifecycle::Completed(op::CqeResult { - result: Ok(0), - flags: 0, - }); + // safety: zeroed memory is entirely valid with this underlying + // representation + *cycle = Lifecycle::Completed(unsafe { mem::zeroed() }); } } @@ -414,7 +484,7 @@ impl Ops { self.lifecycle.remove(index); } - fn complete(&mut self, index: usize, cqe: op::CqeResult) { + fn complete(&mut self, index: usize, cqe: cqueue::Entry) { let completions = &mut self.completions; if self.lifecycle[index].complete(completions, cqe) { self.lifecycle.remove(index); @@ -478,7 +548,7 @@ mod test { assert_pending!(op.poll()); assert_eq!(2, Rc::strong_count(&data)); - complete(&op, Ok(1)); + complete(&op); assert_eq!(1, num_operations()); assert_eq!(2, Rc::strong_count(&data)); @@ -489,7 +559,7 @@ mod test { data: d, } = assert_ready!(op.poll()); assert_eq!(2, Rc::strong_count(&data)); - assert_eq!(1, result.unwrap()); + assert_eq!(0, result.unwrap()); assert_eq!(0, flags); drop(d); @@ -509,11 +579,11 @@ mod test { assert_pending!(op.poll()); assert_pending!(op.poll()); - complete(&op, Ok(1)); + complete(&op); assert!(op.is_woken()); let Completion { result, flags, .. } = assert_ready!(op.poll()); - assert_eq!(1, result.unwrap()); + assert_eq!(0, result.unwrap()); assert_eq!(0, flags); } @@ -531,11 +601,11 @@ mod test { let mut op = task::spawn(op); assert_pending!(op.poll()); - complete(&op, Ok(1)); + complete(&op); assert!(op.is_woken()); let Completion { result, flags, .. } = assert_ready!(op.poll()); - assert_eq!(1, result.unwrap()); + assert_eq!(0, result.unwrap()); assert_eq!(0, flags); } @@ -546,12 +616,12 @@ mod test { fn complete_before_poll() { let (op, data) = init(); let mut op = task::spawn(op); - complete(&op, Ok(1)); + complete(&op); assert_eq!(1, num_operations()); assert_eq!(2, Rc::strong_count(&data)); let Completion { result, flags, .. } = assert_ready!(op.poll()); - assert_eq!(1, result.unwrap()); + assert_eq!(0, result.unwrap()); assert_eq!(0, flags); drop(op); @@ -570,18 +640,13 @@ mod test { assert_eq!(1, num_operations()); - let cqe = CqeResult { - result: Ok(1), - flags: 0, - }; - CONTEXT.with(|cx| { cx.handle() .unwrap() .inner .borrow_mut() .ops - .complete(index, cqe) + .complete(index, unsafe { mem::zeroed() }) }); assert_eq!(1, Rc::strong_count(&data)); @@ -611,8 +676,8 @@ mod test { CONTEXT.with(|cx| cx.handle().unwrap().inner.borrow().num_operations()) } - fn complete(op: &Op>, result: io::Result) { - let cqe = CqeResult { result, flags: 0 }; + fn complete(op: &Op>) { + let cqe = unsafe { mem::zeroed() }; CONTEXT.with(|cx| { let driver = cx.handle().unwrap(); diff --git a/src/runtime/driver/op/mod.rs b/src/runtime/driver/op/mod.rs index 5758a29d..32ba3e7a 100644 --- a/src/runtime/driver/op/mod.rs +++ b/src/runtime/driver/op/mod.rs @@ -4,14 +4,14 @@ use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll, Waker}; -use io_uring::cqueue; +use io_uring::{cqueue, squeue}; mod slab_list; use slab::Slab; use slab_list::{SlabListEntry, SlabListIndices}; -use crate::runtime::driver; +use crate::runtime::{driver, CONTEXT}; /// A SlabList is used to hold unserved completions. /// @@ -20,6 +20,110 @@ use crate::runtime::driver; /// captured before completion. pub(crate) type Completion = SlabListEntry; +/// An unsubmitted oneshot operation. +pub struct UnsubmittedOneshot> { + stable_data: D, + post_op: T, + sqe: squeue::Entry, +} + +impl> UnsubmittedOneshot { + /// Construct a new operation for later submission. + pub fn new(stable_data: D, post_op: T, sqe: squeue::Entry) -> Self { + Self { + stable_data, + post_op, + sqe, + } + } + + /// Submit an operation to the driver for batched entry to the kernel. + pub fn submit(self) -> InFlightOneshot { + let handle = CONTEXT + .with(|x| x.handle()) + .expect("Could not submit op; not in runtime context"); + + self.submit_with_driver(&handle) + } + + fn submit_with_driver(self, driver: &driver::Handle) -> InFlightOneshot { + let index = driver.submit_op_2(self.sqe); + + let driver = driver.into(); + + let inner = InFlightOneshotInner { + index, + driver, + stable_data: self.stable_data, + post_op: self.post_op, + }; + + InFlightOneshot { inner: Some(inner) } + } +} + +/// An in-progress oneshot operation which can be polled for completion. +pub struct InFlightOneshot> { + inner: Option>, +} + +struct InFlightOneshotInner> { + driver: driver::WeakHandle, + index: usize, + stable_data: D, + post_op: T, +} + +impl + Unpin> Future for InFlightOneshot { + type Output = T::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + let inner = this + .inner + .as_mut() + .expect("Cannot poll already-completed operation"); + + let index = inner.index; + + let upgraded = inner + .driver + .upgrade() + .expect("Failed to poll op: driver no longer exists"); + + let cqe = ready!(upgraded.poll_op_2(index, cx)); + + let inner = this.inner.take().unwrap(); + + Poll::Ready( + inner + .post_op + .transform_oneshot_output(inner.stable_data, cqe), + ) + } +} + +impl> Drop for InFlightOneshot { + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + if let Some(driver) = inner.driver.upgrade() { + driver.remove_op_2(inner.index, inner.stable_data) + } + } + } +} + +/// Transforms the output of a oneshot operation into a more user-friendly format. +pub trait OneshotOutputTransform { + /// The final output after the transformation. + type Output; + /// The stored data within the op. + type StoredData; + /// Transform the stored data and the cqe into the final output. + fn transform_oneshot_output(self, data: Self::StoredData, cqe: cqueue::Entry) -> Self::Output; +} + /// In-flight operation pub(crate) struct Op { driver: driver::WeakHandle, @@ -64,7 +168,7 @@ pub(crate) enum Lifecycle { Ignored(Box), /// The operation has completed with a single cqe result - Completed(CqeResult), + Completed(cqueue::Entry), /// One or more completion results have been recieved /// This holds the indices uniquely identifying the list within the slab @@ -156,14 +260,18 @@ impl Drop for Op { } impl Lifecycle { - pub(crate) fn complete(&mut self, completions: &mut Slab, cqe: CqeResult) -> bool { + pub(crate) fn complete( + &mut self, + completions: &mut Slab, + cqe: cqueue::Entry, + ) -> bool { use std::mem; match mem::replace(self, Lifecycle::Submitted) { x @ Lifecycle::Submitted | x @ Lifecycle::Waiting(..) => { - if io_uring::cqueue::more(cqe.flags) { + if io_uring::cqueue::more(cqe.flags()) { let mut list = SlabListIndices::new().into_list(completions); - list.push(cqe); + list.push(cqe.into()); *self = Lifecycle::CompletionList(list.into_indices()); } else { *self = Lifecycle::Completed(cqe); @@ -177,7 +285,7 @@ impl Lifecycle { } lifecycle @ Lifecycle::Ignored(..) => { - if io_uring::cqueue::more(cqe.flags) { + if io_uring::cqueue::more(cqe.flags()) { // Not yet complete. The Op has been dropped, so we can drop the CQE // but we must keep the lifecycle alive until no more CQE's expected *self = lifecycle; @@ -200,7 +308,7 @@ impl Lifecycle { // A completion list may contain CQE's with and without `more` flag set. // Only the final one may have `more` unset, although we don't check. let mut list = indices.into_list(completions); - list.push(cqe); + list.push(cqe.into()); *self = Lifecycle::CompletionList(list.into_indices()); false } diff --git a/tests/driver.rs b/tests/driver.rs index b9bda473..f4381dd5 100644 --- a/tests/driver.rs +++ b/tests/driver.rs @@ -83,7 +83,11 @@ fn too_many_submissions() { let file = File::create(tempfile.path()).await.unwrap(); for _ in 0..600 { poll_once(async { - file.write_at(b"hello world".to_vec(), 0).await.0.unwrap(); + file.write_at(b"hello world".to_vec(), 0) + .submit() + .await + .0 + .unwrap(); }) .await; } diff --git a/tests/fs_file.rs b/tests/fs_file.rs index 739fea56..6ec14d43 100644 --- a/tests/fs_file.rs +++ b/tests/fs_file.rs @@ -60,7 +60,7 @@ fn basic_write() { let file = File::create(tempfile.path()).await.unwrap(); - file.write_at(HELLO, 0).await.0.unwrap(); + file.write_at(HELLO, 0).submit().await.0.unwrap(); let file = std::fs::read(tempfile.path()).unwrap(); assert_eq!(file, HELLO); @@ -155,7 +155,7 @@ fn drop_open() { // Do something else let file = File::create(tempfile.path()).await.unwrap(); - file.write_at(HELLO, 0).await.0.unwrap(); + file.write_at(HELLO, 0).submit().await.0.unwrap(); let file = std::fs::read(tempfile.path()).unwrap(); assert_eq!(file, HELLO); @@ -183,7 +183,7 @@ fn sync_doesnt_kill_anything() { let file = File::create(tempfile.path()).await.unwrap(); file.sync_all().await.unwrap(); file.sync_data().await.unwrap(); - file.write_at(&b"foo"[..], 0).await.0.unwrap(); + file.write_at(&b"foo"[..], 0).submit().await.0.unwrap(); file.sync_all().await.unwrap(); file.sync_data().await.unwrap(); });