Skip to content

Commit

Permalink
net/tcp: Add poll_accept, poll_shared_accept and poll_shutdown (#533)
Browse files Browse the repository at this point in the history
* wip: add poll_accept and poll_shared_accept

* add poll_shutdown

* fixup: docs

---------

Co-authored-by: Glauber Costa <[email protected]>
  • Loading branch information
dignifiedquire and glommer authored Apr 26, 2024
1 parent 562adb8 commit 3d7f7f9
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 22 deletions.
72 changes: 59 additions & 13 deletions glommio/src/net/tcp_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,20 @@ use crate::{
yolo_accept,
},
reactor::Reactor,
sys::Source,
GlommioError,
};
use futures_lite::{
future::poll_fn,
io::{AsyncBufRead, AsyncRead, AsyncWrite},
ready,
stream::{self, Stream},
};
use nix::sys::socket::SockaddrStorage;
use pin_project_lite::pin_project;
use socket2::{Domain, Protocol, Socket, Type};
use std::{
cell::RefCell,
io,
net::{self, Shutdown, SocketAddr, ToSocketAddrs},
os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
Expand Down Expand Up @@ -75,6 +78,7 @@ type Result<T> = crate::Result<T, ()>;
pub struct TcpListener {
reactor: Weak<Reactor>,
listener: net::TcpListener,
current_source: RefCell<Option<Source>>,
}

impl FromRawFd for TcpListener {
Expand All @@ -86,6 +90,7 @@ impl FromRawFd for TcpListener {
TcpListener {
reactor: Rc::downgrade(&crate::executor().reactor()),
listener,
current_source: Default::default(),
}
}
}
Expand Down Expand Up @@ -132,6 +137,7 @@ impl TcpListener {
Ok(TcpListener {
reactor: Rc::downgrade(&crate::executor().reactor()),
listener,
current_source: Default::default(),
})
}

Expand Down Expand Up @@ -164,19 +170,38 @@ impl TcpListener {
/// [`TcpStream`]: struct.TcpStream.html
/// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html
pub async fn shared_accept(&self) -> Result<AcceptedTcpStream> {
let reactor = self.reactor.upgrade().unwrap();
let raw_fd = self.listener.as_raw_fd();
if let Some(r) = yolo_accept(raw_fd) {
match r {
Ok(fd) => {
return Ok(AcceptedTcpStream { fd });
poll_fn(|cx| self.poll_shared_accept(cx)).await
}

/// Poll version of [`shared_accept`].
///
/// [`shared_accept`]: TcpListener::shared_accept
pub fn poll_shared_accept(&self, cx: &mut Context<'_>) -> Poll<Result<AcceptedTcpStream>> {
let mut poll_source = |source: Source| match source.poll_collect_rw(cx) {
Poll::Ready(Ok(fd)) => Poll::Ready(Ok(AcceptedTcpStream { fd: fd as RawFd })),
Poll::Ready(Err(err)) => Poll::Ready(Err(GlommioError::IoError(err))),
Poll::Pending => {
*self.current_source.borrow_mut() = Some(source);
Poll::Pending
}
};
match self.current_source.take() {
Some(source) => poll_source(source),
None => {
let reactor = self.reactor.upgrade().unwrap();
let raw_fd = self.listener.as_raw_fd();
match yolo_accept(raw_fd) {
Some(r) => match r {
Ok(fd) => Poll::Ready(Ok(AcceptedTcpStream { fd })),
Err(err) => Poll::Ready(Err(GlommioError::IoError(err))),
},
None => {
let source = reactor.accept(self.listener.as_raw_fd());
poll_source(source)
}
}
Err(err) => return Err(GlommioError::IoError(err)),
}
}
let source = reactor.accept(self.listener.as_raw_fd());
let fd = source.collect_rw().await?;
Ok(AcceptedTcpStream { fd: fd as RawFd })
}

/// Accepts a new incoming TCP connection in this executor
Expand Down Expand Up @@ -208,6 +233,19 @@ impl TcpListener {
Ok(a.bind_to_executor())
}

/// Poll version of [`accept`].
///
/// [`accept`]: TcpListener::accept
pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<TcpStream>> {
match ready!(self.poll_shared_accept(cx)) {
Ok(a) => {
let a = a.bind_to_executor();
Poll::Ready(Ok(a))
}
Err(err) => Poll::Ready(Err(err)),
}
}

/// Creates a stream of incoming connections
///
/// # Examples
Expand Down Expand Up @@ -554,9 +592,17 @@ impl<B: RxBuf> TcpStream<B> {

/// Shuts down the read, write, or both halves of this connection.
pub async fn shutdown(&self, how: Shutdown) -> Result<()> {
poll_fn(|cx| self.stream.poll_shutdown(cx, how))
.await
.map_err(Into::into)
poll_fn(|cx| self.poll_shutdown(cx, how)).await
}

/// Polling version of [`shutdown`].
///
/// [`shutdown`]: TcpStream::shutdown
pub fn poll_shutdown(&self, cx: &mut Context<'_>, how: Shutdown) -> Poll<Result<()>> {
match ready!(self.stream.poll_shutdown(cx, how)) {
Ok(()) => Poll::Ready(Ok(())),
Err(err) => Poll::Ready(Err(err.into())),
}
}

/// Sets the value of the `TCP_NODELAY` option on this socket.
Expand Down
19 changes: 10 additions & 9 deletions glommio/src/sys/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::{
path::PathBuf,
pin::Pin,
rc::Rc,
task::{Poll, Waker},
task::{Context, Poll, Waker},
time::Duration,
};

Expand Down Expand Up @@ -303,15 +303,16 @@ impl Source {
}

pub(crate) async fn collect_rw(&self) -> io::Result<usize> {
future::poll_fn(|cx| {
if let Some(result) = self.result() {
return Poll::Ready(result);
}
future::poll_fn(|cx| self.poll_collect_rw(cx)).await
}

self.add_waiter_many(cx.waker().clone());
Poll::Pending
})
.await
pub(crate) fn poll_collect_rw(&self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
if let Some(result) = self.result() {
return Poll::Ready(result);
}

self.add_waiter_many(cx.waker().clone());
Poll::Pending
}
}

Expand Down

0 comments on commit 3d7f7f9

Please sign in to comment.