From 3d7f7f90176b288717e04f31fffc186f83636185 Mon Sep 17 00:00:00 2001 From: Friedel Ziegelmayer Date: Fri, 26 Apr 2024 19:47:04 +0200 Subject: [PATCH] net/tcp: Add poll_accept, poll_shared_accept and poll_shutdown (#533) * wip: add poll_accept and poll_shared_accept * add poll_shutdown * fixup: docs --------- Co-authored-by: Glauber Costa --- glommio/src/net/tcp_socket.rs | 72 ++++++++++++++++++++++++++++------- glommio/src/sys/source.rs | 19 ++++----- 2 files changed, 69 insertions(+), 22 deletions(-) diff --git a/glommio/src/net/tcp_socket.rs b/glommio/src/net/tcp_socket.rs index d03cf78f2..81488c0e3 100644 --- a/glommio/src/net/tcp_socket.rs +++ b/glommio/src/net/tcp_socket.rs @@ -10,17 +10,20 @@ use crate::{ yolo_accept, }, reactor::Reactor, + sys::Source, GlommioError, }; use futures_lite::{ future::poll_fn, io::{AsyncBufRead, AsyncRead, AsyncWrite}, + ready, stream::{self, Stream}, }; use nix::sys::socket::SockaddrStorage; use pin_project_lite::pin_project; use socket2::{Domain, Protocol, Socket, Type}; use std::{ + cell::RefCell, io, net::{self, Shutdown, SocketAddr, ToSocketAddrs}, os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, @@ -75,6 +78,7 @@ type Result = crate::Result; pub struct TcpListener { reactor: Weak, listener: net::TcpListener, + current_source: RefCell>, } impl FromRawFd for TcpListener { @@ -86,6 +90,7 @@ impl FromRawFd for TcpListener { TcpListener { reactor: Rc::downgrade(&crate::executor().reactor()), listener, + current_source: Default::default(), } } } @@ -132,6 +137,7 @@ impl TcpListener { Ok(TcpListener { reactor: Rc::downgrade(&crate::executor().reactor()), listener, + current_source: Default::default(), }) } @@ -164,19 +170,38 @@ impl TcpListener { /// [`TcpStream`]: struct.TcpStream.html /// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html pub async fn shared_accept(&self) -> Result { - let reactor = self.reactor.upgrade().unwrap(); - let raw_fd = self.listener.as_raw_fd(); - if let Some(r) = yolo_accept(raw_fd) { - match r { - Ok(fd) => { - return Ok(AcceptedTcpStream { fd }); + poll_fn(|cx| self.poll_shared_accept(cx)).await + } + + /// Poll version of [`shared_accept`]. + /// + /// [`shared_accept`]: TcpListener::shared_accept + pub fn poll_shared_accept(&self, cx: &mut Context<'_>) -> Poll> { + let mut poll_source = |source: Source| match source.poll_collect_rw(cx) { + Poll::Ready(Ok(fd)) => Poll::Ready(Ok(AcceptedTcpStream { fd: fd as RawFd })), + Poll::Ready(Err(err)) => Poll::Ready(Err(GlommioError::IoError(err))), + Poll::Pending => { + *self.current_source.borrow_mut() = Some(source); + Poll::Pending + } + }; + match self.current_source.take() { + Some(source) => poll_source(source), + None => { + let reactor = self.reactor.upgrade().unwrap(); + let raw_fd = self.listener.as_raw_fd(); + match yolo_accept(raw_fd) { + Some(r) => match r { + Ok(fd) => Poll::Ready(Ok(AcceptedTcpStream { fd })), + Err(err) => Poll::Ready(Err(GlommioError::IoError(err))), + }, + None => { + let source = reactor.accept(self.listener.as_raw_fd()); + poll_source(source) + } } - Err(err) => return Err(GlommioError::IoError(err)), } } - let source = reactor.accept(self.listener.as_raw_fd()); - let fd = source.collect_rw().await?; - Ok(AcceptedTcpStream { fd: fd as RawFd }) } /// Accepts a new incoming TCP connection in this executor @@ -208,6 +233,19 @@ impl TcpListener { Ok(a.bind_to_executor()) } + /// Poll version of [`accept`]. + /// + /// [`accept`]: TcpListener::accept + pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll> { + match ready!(self.poll_shared_accept(cx)) { + Ok(a) => { + let a = a.bind_to_executor(); + Poll::Ready(Ok(a)) + } + Err(err) => Poll::Ready(Err(err)), + } + } + /// Creates a stream of incoming connections /// /// # Examples @@ -554,9 +592,17 @@ impl TcpStream { /// Shuts down the read, write, or both halves of this connection. pub async fn shutdown(&self, how: Shutdown) -> Result<()> { - poll_fn(|cx| self.stream.poll_shutdown(cx, how)) - .await - .map_err(Into::into) + poll_fn(|cx| self.poll_shutdown(cx, how)).await + } + + /// Polling version of [`shutdown`]. + /// + /// [`shutdown`]: TcpStream::shutdown + pub fn poll_shutdown(&self, cx: &mut Context<'_>, how: Shutdown) -> Poll> { + match ready!(self.stream.poll_shutdown(cx, how)) { + Ok(()) => Poll::Ready(Ok(())), + Err(err) => Poll::Ready(Err(err.into())), + } } /// Sets the value of the `TCP_NODELAY` option on this socket. diff --git a/glommio/src/sys/source.rs b/glommio/src/sys/source.rs index b4889fc40..55bf31c83 100644 --- a/glommio/src/sys/source.rs +++ b/glommio/src/sys/source.rs @@ -22,7 +22,7 @@ use std::{ path::PathBuf, pin::Pin, rc::Rc, - task::{Poll, Waker}, + task::{Context, Poll, Waker}, time::Duration, }; @@ -303,15 +303,16 @@ impl Source { } pub(crate) async fn collect_rw(&self) -> io::Result { - future::poll_fn(|cx| { - if let Some(result) = self.result() { - return Poll::Ready(result); - } + future::poll_fn(|cx| self.poll_collect_rw(cx)).await + } - self.add_waiter_many(cx.waker().clone()); - Poll::Pending - }) - .await + pub(crate) fn poll_collect_rw(&self, cx: &mut Context<'_>) -> Poll> { + if let Some(result) = self.result() { + return Poll::Ready(result); + } + + self.add_waiter_many(cx.waker().clone()); + Poll::Pending } }