Skip to content

Commit

Permalink
rt: add new api for oneshot op submission and creation (#244)
Browse files Browse the repository at this point in the history
This change adds a new API for creating and submitting oneshot
operations. This new API is intended to obsolete the existing system,
however transferring over has been left out of this PR to keep the
change small.

It is intended that a similar API for multishot operations be landed in
a followup PR as well. This was also left out of this PR to keep the
change small.

This refactoring paves the way for further work on SQE linking,
multishot operations, and other improvements and additions.

The goal of this change is to allow us to split up the oneshot and
multishot logic to allow for a cleaner and more extensible system, allow
for user-provided operations, and allow users to control when and how
their ops get submitted to the squeue.
  • Loading branch information
Noah-Kennedy authored Feb 21, 2023
1 parent 575d864 commit a021be7
Show file tree
Hide file tree
Showing 20 changed files with 312 additions and 108 deletions.
2 changes: 1 addition & 1 deletion examples/mix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn main() {
break;
}

let (res, b) = socket.write(b).await;
let (res, b) = socket.write(b).submit().await;
pos += res.unwrap() as u64;

buf = b;
Expand Down
2 changes: 1 addition & 1 deletion examples/tcp_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn main() {
let stream = TcpStream::connect(socket_addr).await.unwrap();
let buf = vec![1u8; 128];

let (result, buf) = stream.write(buf).await;
let (result, buf) = stream.write(buf).submit().await;
println!("written: {}", result.unwrap());

let (result, buf) = stream.read(buf).await;
Expand Down
2 changes: 1 addition & 1 deletion examples/unix_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn main() {
tokio_uring::spawn(async move {
let buf = vec![1u8; 128];

let (result, buf) = stream.write(buf).await;
let (result, buf) = stream.write(buf).submit().await;
println!("written to {}: {}", &socket_addr, result.unwrap());

let (result, buf) = stream.read(buf).await;
Expand Down
2 changes: 1 addition & 1 deletion examples/unix_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn main() {
let stream = UnixStream::connect(socket_addr).await.unwrap();
let buf = vec![1u8; 128];

let (result, buf) = stream.write(buf).await;
let (result, buf) = stream.write(buf).submit().await;
println!("written: {}", result.unwrap());

let (result, buf) = stream.read(buf).await;
Expand Down
2 changes: 1 addition & 1 deletion examples/wrk-bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn main() -> io::Result<()> {
let (stream, _) = listener.accept().await?;

tokio_uring::spawn(async move {
let (result, _) = stream.write(RESPONSE).await;
let (result, _) = stream.write(RESPONSE).submit().await;

if let Err(err) = result {
eprintln!("Client connection failed: {}", err);
Expand Down
16 changes: 8 additions & 8 deletions src/fs/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::fs::OpenOptions;
use crate::io::SharedFd;

use crate::runtime::driver::op::Op;
use crate::{UnsubmittedOneshot, UnsubmittedWrite};
use std::fmt;
use std::io;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
Expand Down Expand Up @@ -38,7 +39,7 @@ use std::path::Path;
/// let file = File::create("hello.txt").await?;
///
/// // Write some data
/// let (res, buf) = file.write_at(&b"hello world"[..], 0).await;
/// let (res, buf) = file.write_at(&b"hello world"[..], 0).submit().await;
/// let n = res?;
///
/// println!("wrote {} bytes", n);
Expand Down Expand Up @@ -525,7 +526,7 @@ impl File {
/// let file = File::create("foo.txt").await?;
///
/// // Writes some prefix of the byte string, not necessarily all of it.
/// let (res, _) = file.write_at(&b"some bytes"[..], 0).await;
/// let (res, _) = file.write_at(&b"some bytes"[..], 0).submit().await;
/// let n = res?;
///
/// println!("wrote {} bytes", n);
Expand All @@ -538,9 +539,8 @@ impl File {
/// ```
///
/// [`Ok(n)`]: Ok
pub async fn write_at<T: BoundedBuf>(&self, buf: T, pos: u64) -> crate::BufResult<usize, T> {
let op = Op::write_at(&self.fd, buf, pos).unwrap();
op.await
pub fn write_at<T: BoundedBuf>(&self, buf: T, pos: u64) -> UnsubmittedWrite<T> {
UnsubmittedOneshot::write_at(&self.fd, buf, pos)
}

/// Attempts to write an entire buffer into this file at the specified offset.
Expand Down Expand Up @@ -609,7 +609,7 @@ impl File {
}

while buf.bytes_init() != 0 {
let (res, slice) = self.write_at(buf, pos).await;
let (res, slice) = self.write_at(buf, pos).submit().await;
match res {
Ok(0) => {
return (
Expand Down Expand Up @@ -773,7 +773,7 @@ impl File {
/// fn main() -> Result<(), Box<dyn std::error::Error>> {
/// tokio_uring::start(async {
/// let f = File::create("foo.txt").await?;
/// let (res, buf) = f.write_at(&b"Hello, world!"[..], 0).await;
/// let (res, buf) = f.write_at(&b"Hello, world!"[..], 0).submit().await;
/// let n = res?;
///
/// f.sync_all().await?;
Expand Down Expand Up @@ -810,7 +810,7 @@ impl File {
/// fn main() -> Result<(), Box<dyn std::error::Error>> {
/// tokio_uring::start(async {
/// let f = File::create("foo.txt").await?;
/// let (res, buf) = f.write_at(&b"Hello, world!"[..], 0).await;
/// let (res, buf) = f.write_at(&b"Hello, world!"[..], 0).submit().await;
/// let n = res?;
///
/// f.sync_data().await?;
Expand Down
2 changes: 1 addition & 1 deletion src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ mod unlink_at;
mod util;
pub(crate) use util::cstr;

mod write;
pub(crate) mod write;

mod write_fixed;

Expand Down
9 changes: 5 additions & 4 deletions src/io/socket.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::io::write::UnsubmittedWrite;
use crate::runtime::driver::op::Op;
use crate::{
buf::fixed::FixedBuf,
buf::{BoundedBuf, BoundedBufMut, IoBuf, Slice},
io::SharedFd,
UnsubmittedOneshot,
};
use std::{
io,
Expand Down Expand Up @@ -41,9 +43,8 @@ impl Socket {
Ok(Socket { fd })
}

pub(crate) async fn write<T: BoundedBuf>(&self, buf: T) -> crate::BufResult<usize, T> {
let op = Op::write_at(&self.fd, buf, 0).unwrap();
op.await
pub(crate) fn write<T: BoundedBuf>(&self, buf: T) -> UnsubmittedWrite<T> {
UnsubmittedOneshot::write_at(&self.fd, buf, 0)
}

pub async fn write_all<T: BoundedBuf>(&self, buf: T) -> crate::BufResult<(), T> {
Expand All @@ -54,7 +55,7 @@ impl Socket {

async fn write_all_slice<T: IoBuf>(&self, mut buf: Slice<T>) -> crate::BufResult<(), T> {
while buf.bytes_init() != 0 {
let res = self.write(buf).await;
let res = self.write(buf).submit().await;
match res {
(Ok(0), slice) => {
return (
Expand Down
79 changes: 44 additions & 35 deletions src/io/write.rs
Original file line number Diff line number Diff line change
@@ -1,50 +1,59 @@
use crate::runtime::driver::op::{Completable, CqeResult, Op};
use crate::runtime::CONTEXT;
use crate::{buf::BoundedBuf, io::SharedFd, BufResult};
use crate::{buf::BoundedBuf, io::SharedFd, BufResult, OneshotOutputTransform, UnsubmittedOneshot};
use io_uring::cqueue::Entry;
use std::io;
use std::marker::PhantomData;

pub(crate) struct Write<T> {
/// An unsubmitted write operation.
pub type UnsubmittedWrite<T> = UnsubmittedOneshot<WriteData<T>, WriteTransform<T>>;

#[allow(missing_docs)]
pub struct WriteData<T> {
/// Holds a strong ref to the FD, preventing the file from being closed
/// while the operation is in-flight.
#[allow(dead_code)]
fd: SharedFd,
_fd: SharedFd,

buf: T,
}

impl<T: BoundedBuf> Op<Write<T>> {
pub(crate) fn write_at(fd: &SharedFd, buf: T, offset: u64) -> io::Result<Op<Write<T>>> {
use io_uring::{opcode, types};

CONTEXT.with(|x| {
x.handle().expect("Not in a runtime context").submit_op(
Write {
fd: fd.clone(),
buf,
},
|write| {
// Get raw buffer info
let ptr = write.buf.stable_ptr();
let len = write.buf.bytes_init();

opcode::Write::new(types::Fd(fd.raw_fd()), ptr, len as _)
.offset(offset as _)
.build()
},
)
})
}
#[allow(missing_docs)]
pub struct WriteTransform<T> {
_phantom: PhantomData<T>,
}

impl<T> Completable for Write<T> {
impl<T> OneshotOutputTransform for WriteTransform<T> {
type Output = BufResult<usize, T>;
type StoredData = WriteData<T>;

fn transform_oneshot_output(self, data: Self::StoredData, cqe: Entry) -> Self::Output {
let res = if cqe.result() >= 0 {
Ok(cqe.result() as usize)
} else {
Err(io::Error::from_raw_os_error(cqe.result()))
};

fn complete(self, cqe: CqeResult) -> Self::Output {
// Convert the operation result to `usize`
let res = cqe.result.map(|v| v as usize);
// Recover the buffer
let buf = self.buf;
(res, data.buf)
}
}

impl<T: BoundedBuf> UnsubmittedWrite<T> {
pub(crate) fn write_at(fd: &SharedFd, buf: T, offset: u64) -> Self {
use io_uring::{opcode, types};

(res, buf)
// Get raw buffer info
let ptr = buf.stable_ptr();
let len = buf.bytes_init();

Self::new(
WriteData {
_fd: fd.clone(),
buf,
},
WriteTransform {
_phantom: PhantomData::default(),
},
opcode::Write::new(types::Fd(fd.raw_fd()), ptr, len as _)
.offset(offset as _)
.build(),
)
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ pub mod buf;
pub mod fs;
pub mod net;

pub use io::write::*;
pub use runtime::driver::op::{InFlightOneshot, OneshotOutputTransform, UnsubmittedOneshot};
pub use runtime::spawn;
pub use runtime::Runtime;

Expand Down
2 changes: 1 addition & 1 deletion src/net/tcp/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use std::{io, net::SocketAddr};
/// let tx = TcpStream::connect("127.0.0.1:2345".parse().unwrap()).await.unwrap();
/// let rx = rx_ch.await.expect("The spawned task expected to send a TcpStream");
///
/// tx.write(b"test" as &'static [u8]).await.0.unwrap();
/// tx.write(b"test" as &'static [u8]).submit().await.0.unwrap();
///
/// let (_, buf) = rx.read(vec![0; 4]).await;
///
Expand Down
7 changes: 4 additions & 3 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
buf::fixed::FixedBuf,
buf::{BoundedBuf, BoundedBufMut},
io::{SharedFd, Socket},
UnsubmittedWrite,
};

/// A TCP stream between a local and a remote socket.
Expand All @@ -27,7 +28,7 @@ use crate::{
/// let mut stream = TcpStream::connect("127.0.0.1:8080".parse().unwrap()).await?;
///
/// // Write some data.
/// let (result, _) = stream.write(b"hello world!".as_slice()).await;
/// let (result, _) = stream.write(b"hello world!".as_slice()).submit().await;
/// result.unwrap();
///
/// Ok(())
Expand Down Expand Up @@ -100,8 +101,8 @@ impl TcpStream {
/// Write some data to the stream from the buffer.
///
/// Returns the original buffer and quantity of data written.
pub async fn write<T: BoundedBuf>(&self, buf: T) -> crate::BufResult<usize, T> {
self.inner.write(buf).await
pub fn write<T: BoundedBuf>(&self, buf: T) -> UnsubmittedWrite<T> {
self.inner.write(buf)
}

/// Attempts to write an entire buffer to the stream.
Expand Down
7 changes: 4 additions & 3 deletions src/net/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
buf::fixed::FixedBuf,
buf::{BoundedBuf, BoundedBufMut},
io::{SharedFd, Socket},
UnsubmittedWrite,
};
use socket2::SockAddr;
use std::{
Expand Down Expand Up @@ -42,7 +43,7 @@ use std::{
/// let buf = vec![0; 32];
///
/// // write data
/// let (result, _) = socket.write(b"hello world".as_slice()).await;
/// let (result, _) = socket.write(b"hello world".as_slice()).submit().await;
/// result.unwrap();
///
/// // read data
Expand Down Expand Up @@ -312,8 +313,8 @@ impl UdpSocket {
/// Writes data into the socket from the specified buffer.
///
/// Returns the original buffer and quantity of data written.
pub async fn write<T: BoundedBuf>(&self, buf: T) -> crate::BufResult<usize, T> {
self.inner.write(buf).await
pub fn write<T: BoundedBuf>(&self, buf: T) -> UnsubmittedWrite<T> {
self.inner.write(buf)
}

/// Writes data into the socket from a registered buffer.
Expand Down
2 changes: 1 addition & 1 deletion src/net/unix/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use std::{io, path::Path};
/// let tx = UnixStream::connect(&sock_file).await.unwrap();
/// let rx = rx_ch.await.expect("The spawned task expected to send a UnixStream");
///
/// tx.write(b"test" as &'static [u8]).await.0.unwrap();
/// tx.write(b"test" as &'static [u8]).submit().await.0.unwrap();
///
/// let (_, buf) = rx.read(vec![0; 4]).await;
///
Expand Down
7 changes: 4 additions & 3 deletions src/net/unix/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
buf::fixed::FixedBuf,
buf::{BoundedBuf, BoundedBufMut},
io::{SharedFd, Socket},
UnsubmittedWrite,
};
use socket2::SockAddr;
use std::{
Expand All @@ -27,7 +28,7 @@ use std::{
/// let mut stream = UnixStream::connect("/tmp/tokio-uring-unix-test.sock").await?;
///
/// // Write some data.
/// let (result, _) = stream.write(b"hello world!".as_slice()).await;
/// let (result, _) = stream.write(b"hello world!".as_slice()).submit().await;
/// result.unwrap();
///
/// Ok(())
Expand Down Expand Up @@ -98,8 +99,8 @@ impl UnixStream {

/// Write some data to the stream from the buffer, returning the original buffer and
/// quantity of data written.
pub async fn write<T: BoundedBuf>(&self, buf: T) -> crate::BufResult<usize, T> {
self.inner.write(buf).await
pub fn write<T: BoundedBuf>(&self, buf: T) -> UnsubmittedWrite<T> {
self.inner.write(buf)
}

/// Attempts to write an entire buffer to the stream.
Expand Down
14 changes: 13 additions & 1 deletion src/runtime/driver/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
//! The weak handle should be used by anything which is stored in the driver or does not need to
//! keep the driver alive for it's duration.

use io_uring::squeue;
use io_uring::{cqueue, squeue};
use std::cell::RefCell;
use std::io;
use std::ops::Deref;
Expand Down Expand Up @@ -63,6 +63,10 @@ impl Handle {
self.inner.borrow_mut().unregister_buffers(buffers)
}

pub(crate) fn submit_op_2(&self, sqe: squeue::Entry) -> usize {
self.inner.borrow_mut().submit_op_2(sqe)
}

pub(crate) fn submit_op<T, S, F>(&self, data: T, f: F) -> io::Result<Op<T, S>>
where
T: Completable,
Expand All @@ -78,6 +82,10 @@ impl Handle {
self.inner.borrow_mut().poll_op(op, cx)
}

pub(crate) fn poll_op_2(&self, index: usize, cx: &mut Context<'_>) -> Poll<cqueue::Entry> {
self.inner.borrow_mut().poll_op_2(index, cx)
}

pub(crate) fn poll_multishot_op<T>(
&self,
op: &mut Op<T, MultiCQEFuture>,
Expand All @@ -92,6 +100,10 @@ impl Handle {
pub(crate) fn remove_op<T, CqeType>(&self, op: &mut Op<T, CqeType>) {
self.inner.borrow_mut().remove_op(op)
}

pub(crate) fn remove_op_2<T: 'static>(&self, index: usize, data: T) {
self.inner.borrow_mut().remove_op_2(index, data)
}
}

impl WeakHandle {
Expand Down
Loading

0 comments on commit a021be7

Please sign in to comment.