From 5dc65444d27c92d91c4f189ea8e6e57769f74f26 Mon Sep 17 00:00:00 2001
From: John Nunley <dev@notgull.net>
Date: Mon, 15 Jan 2024 10:27:39 -0800
Subject: [PATCH 1/2] ex: Use smol-hyper in hyper-client example

Signed-off-by: John Nunley <dev@notgull.net>
---
 Cargo.toml               |   6 +-
 examples/hyper-client.rs | 215 ++++++++++++++++-----------------------
 2 files changed, 95 insertions(+), 126 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 4888b7a..cf92da9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -37,11 +37,15 @@ ctrlc = "3"
 doc-comment = "0.3"
 futures = "0.3"
 http = "0.2"
+http-body-util = "0.1.0"
 http-types = "2"
-hyper = { version = "0.14", default-features = false, features = ["client", "http1", "server", "stream"] }
+hyper = { version = "1.0", default-features = false, features = ["client", "http1", "server"] }
+macro_rules_attribute = "0.2.0"
 native-tls = "0.2"
 scraper = "0.18"
 signal-hook = "0.3"
+smol-hyper = "0.1.0"
+smol-macros = "0.1.0"
 surf = { version = "2", default-features = false, features = ["h1-client"] }
 tempfile = "3"
 tokio = { version = "1", default-features = false, features = ["rt-multi-thread"] }
diff --git a/examples/hyper-client.rs b/examples/hyper-client.rs
index 63676cf..bd00997 100644
--- a/examples/hyper-client.rs
+++ b/examples/hyper-client.rs
@@ -6,175 +6,140 @@
 //! cargo run --example hyper-client
 //! ```
 
-use std::net::Shutdown;
-use std::net::{TcpStream, ToSocketAddrs};
+use std::convert::TryInto;
 use std::pin::Pin;
+use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use anyhow::{bail, Context as _, Error, Result};
+use anyhow::{bail, Context as _, Result};
 use async_native_tls::TlsStream;
-use http::Uri;
-use hyper::{Body, Client, Request, Response};
-use smol::{io, prelude::*, Async};
+use http_body_util::{BodyStream, Empty};
+use hyper::body::Incoming;
+use hyper::{Request, Response};
+use macro_rules_attribute::apply;
+use smol::{io, net::TcpStream, prelude::*, Executor};
+use smol_hyper::rt::FuturesIo;
+use smol_macros::main;
 
 /// Sends a request and fetches the response.
-async fn fetch(req: Request<Body>) -> Result<Response<Body>> {
-    Ok(Client::builder()
-        .executor(SmolExecutor)
-        .build::<_, Body>(SmolConnector)
-        .request(req)
-        .await?)
-}
+async fn fetch(
+    ex: &Arc<Executor<'static>>,
+    req: Request<Empty<&'static [u8]>>,
+) -> Result<Response<Incoming>> {
+    // Connect to the HTTP server.
+    let io = {
+        let host = req.uri().host().context("cannot parse host")?;
+
+        match req.uri().scheme_str() {
+            Some("http") => {
+                let stream = {
+                    let port = req.uri().port_u16().unwrap_or(80);
+                    TcpStream::connect((host, port)).await?
+                };
+                SmolStream::Plain(stream)
+            }
+            Some("https") => {
+                // In case of HTTPS, establish a secure TLS connection first.
+                let stream = {
+                    let port = req.uri().port_u16().unwrap_or(443);
+                    TcpStream::connect((host, port)).await?
+                };
+                let stream = async_native_tls::connect(host, stream).await?;
+                SmolStream::Tls(stream)
+            }
+            scheme => bail!("unsupported scheme: {:?}", scheme),
+        }
+    };
 
-fn main() -> Result<()> {
-    smol::block_on(async {
-        // Create a request.
-        let req = Request::get("https://www.rust-lang.org").body(Body::empty())?;
-
-        // Fetch the response.
-        let resp = fetch(req).await?;
-        println!("{:#?}", resp);
-
-        // Read the message body.
-        let body = resp
-            .into_body()
-            .try_fold(Vec::new(), |mut body, chunk| {
-                body.extend_from_slice(&chunk);
-                Ok(body)
-            })
-            .await?;
-        println!("{}", String::from_utf8_lossy(&body));
-
-        Ok(())
+    // Spawn the HTTP/1 connection.
+    let (mut sender, conn) = hyper::client::conn::http1::handshake(FuturesIo::new(io)).await?;
+    ex.spawn(async move {
+        if let Err(e) = conn.await {
+            println!("Connection failed: {:?}", e);
+        }
     })
-}
+    .detach();
 
-/// Spawns futures.
-#[derive(Clone)]
-struct SmolExecutor;
-
-impl<F: Future + Send + 'static> hyper::rt::Executor<F> for SmolExecutor {
-    fn execute(&self, fut: F) {
-        smol::spawn(async { drop(fut.await) }).detach();
-    }
+    // Get the result
+    let result = sender.send_request(req).await?;
+    Ok(result)
 }
 
-/// Connects to URLs.
-#[derive(Clone)]
-struct SmolConnector;
-
-impl hyper::service::Service<Uri> for SmolConnector {
-    type Response = SmolStream;
-    type Error = Error;
-    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
-
-    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        Poll::Ready(Ok(()))
-    }
-
-    fn call(&mut self, uri: Uri) -> Self::Future {
-        Box::pin(async move {
-            let host = uri.host().context("cannot parse host")?;
-
-            match uri.scheme_str() {
-                Some("http") => {
-                    let socket_addr = {
-                        let host = host.to_string();
-                        let port = uri.port_u16().unwrap_or(80);
-                        smol::unblock(move || (host.as_str(), port).to_socket_addrs())
-                            .await?
-                            .next()
-                            .context("cannot resolve address")?
-                    };
-                    let stream = Async::<TcpStream>::connect(socket_addr).await?;
-                    Ok(SmolStream::Plain(stream))
-                }
-                Some("https") => {
-                    // In case of HTTPS, establish a secure TLS connection first.
-                    let socket_addr = {
-                        let host = host.to_string();
-                        let port = uri.port_u16().unwrap_or(443);
-                        smol::unblock(move || (host.as_str(), port).to_socket_addrs())
-                            .await?
-                            .next()
-                            .context("cannot resolve address")?
-                    };
-                    let stream = Async::<TcpStream>::connect(socket_addr).await?;
-                    let stream = async_native_tls::connect(host, stream).await?;
-                    Ok(SmolStream::Tls(stream))
-                }
-                scheme => bail!("unsupported scheme: {:?}", scheme),
+#[apply(main!)]
+async fn main(ex: Arc<Executor<'static>>) -> Result<()> {
+    // Create a request.
+    let url: hyper::Uri = "https://www.rust-lang.org".try_into()?;
+    let req = Request::builder()
+        .header(
+            hyper::header::HOST,
+            url.authority().unwrap().clone().as_str(),
+        )
+        .uri(url)
+        .body(Empty::new())?;
+
+    // Fetch the response.
+    let resp = fetch(&ex, req).await?;
+    println!("{:#?}", resp);
+
+    // Read the message body.
+    let body: Vec<u8> = BodyStream::new(resp.into_body())
+        .try_fold(Vec::new(), |mut body, chunk| {
+            if let Some(chunk) = chunk.data_ref() {
+                body.extend_from_slice(chunk);
             }
+            Ok(body)
         })
-    }
+        .await?;
+    println!("{}", String::from_utf8_lossy(&body));
+
+    Ok(())
 }
 
 /// A TCP or TCP+TLS connection.
 enum SmolStream {
     /// A plain TCP connection.
-    Plain(Async<TcpStream>),
+    Plain(TcpStream),
 
     /// A TCP connection secured by TLS.
-    Tls(TlsStream<Async<TcpStream>>),
-}
-
-impl hyper::client::connect::Connection for SmolStream {
-    fn connected(&self) -> hyper::client::connect::Connected {
-        hyper::client::connect::Connected::new()
-    }
+    Tls(TlsStream<TcpStream>),
 }
 
-impl tokio::io::AsyncRead for SmolStream {
+impl AsyncRead for SmolStream {
     fn poll_read(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        buf: &mut tokio::io::ReadBuf<'_>,
-    ) -> Poll<io::Result<()>> {
+        buf: &mut [u8],
+    ) -> Poll<io::Result<usize>> {
         match &mut *self {
-            SmolStream::Plain(s) => {
-                Pin::new(s)
-                    .poll_read(cx, buf.initialize_unfilled())
-                    .map_ok(|size| {
-                        buf.advance(size);
-                    })
-            }
-            SmolStream::Tls(s) => {
-                Pin::new(s)
-                    .poll_read(cx, buf.initialize_unfilled())
-                    .map_ok(|size| {
-                        buf.advance(size);
-                    })
-            }
+            SmolStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
+            SmolStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
         }
     }
 }
 
-impl tokio::io::AsyncWrite for SmolStream {
+impl AsyncWrite for SmolStream {
     fn poll_write(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
         buf: &[u8],
     ) -> Poll<io::Result<usize>> {
         match &mut *self {
-            SmolStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
-            SmolStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
+            SmolStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
+            SmolStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
         }
     }
 
-    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
         match &mut *self {
-            SmolStream::Plain(s) => Pin::new(s).poll_flush(cx),
-            SmolStream::Tls(s) => Pin::new(s).poll_flush(cx),
+            SmolStream::Plain(stream) => Pin::new(stream).poll_close(cx),
+            SmolStream::Tls(stream) => Pin::new(stream).poll_close(cx),
         }
     }
 
-    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
         match &mut *self {
-            SmolStream::Plain(s) => {
-                s.get_ref().shutdown(Shutdown::Write)?;
-                Poll::Ready(Ok(()))
-            }
-            SmolStream::Tls(s) => Pin::new(s).poll_close(cx),
+            SmolStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
+            SmolStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
         }
     }
 }

From 76b46861ddf0eed271342b9da606bc8acbe8e2ad Mon Sep 17 00:00:00 2001
From: John Nunley <dev@notgull.net>
Date: Mon, 15 Jan 2024 10:47:14 -0800
Subject: [PATCH 2/2] ex: Use smol-hyper in hyper-server example

Signed-off-by: John Nunley <dev@notgull.net>
---
 examples/hyper-client.rs |   7 +-
 examples/hyper-server.rs | 219 +++++++++++++++------------------------
 2 files changed, 86 insertions(+), 140 deletions(-)

diff --git a/examples/hyper-client.rs b/examples/hyper-client.rs
index bd00997..bb09f2a 100644
--- a/examples/hyper-client.rs
+++ b/examples/hyper-client.rs
@@ -8,7 +8,6 @@
 
 use std::convert::TryInto;
 use std::pin::Pin;
-use std::sync::Arc;
 use std::task::{Context, Poll};
 
 use anyhow::{bail, Context as _, Result};
@@ -23,7 +22,7 @@ use smol_macros::main;
 
 /// Sends a request and fetches the response.
 async fn fetch(
-    ex: &Arc<Executor<'static>>,
+    ex: &Executor<'static>,
     req: Request<Empty<&'static [u8]>>,
 ) -> Result<Response<Incoming>> {
     // Connect to the HTTP server.
@@ -66,7 +65,7 @@ async fn fetch(
 }
 
 #[apply(main!)]
-async fn main(ex: Arc<Executor<'static>>) -> Result<()> {
+async fn main(ex: &Executor<'static>) -> Result<()> {
     // Create a request.
     let url: hyper::Uri = "https://www.rust-lang.org".try_into()?;
     let req = Request::builder()
@@ -78,7 +77,7 @@ async fn main(ex: Arc<Executor<'static>>) -> Result<()> {
         .body(Empty::new())?;
 
     // Fetch the response.
-    let resp = fetch(&ex, req).await?;
+    let resp = fetch(ex, req).await?;
     println!("{:#?}", resp);
 
     // Read the message body.
diff --git a/examples/hyper-server.rs b/examples/hyper-server.rs
index d086d21..90473ae 100644
--- a/examples/hyper-server.rs
+++ b/examples/hyper-server.rs
@@ -13,24 +13,54 @@
 //!
 //! Refer to `README.md` to see how to the TLS certificate was generated.
 
-use std::net::{Shutdown, TcpListener, TcpStream};
+use std::net::{TcpListener, TcpStream};
 use std::pin::Pin;
+use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use anyhow::{Error, Result};
+use anyhow::Result;
 use async_native_tls::{Identity, TlsAcceptor, TlsStream};
-use hyper::service::{make_service_fn, service_fn};
-use hyper::{Body, Request, Response, Server};
-use smol::{future, io, prelude::*, Async};
+use http_body_util::Full;
+use hyper::body::Incoming;
+use hyper::service::service_fn;
+use hyper::{Request, Response};
+use macro_rules_attribute::apply;
+use smol::{future, io, prelude::*, Async, Executor};
+use smol_hyper::rt::{FuturesIo, SmolTimer};
+use smol_macros::main;
 
 /// Serves a request and returns a response.
-async fn serve(req: Request<Body>, host: String) -> Result<Response<Body>> {
-    println!("Serving {}{}", host, req.uri());
-    Ok(Response::new(Body::from("Hello from hyper!")))
+async fn serve(req: Request<Incoming>) -> Result<Response<Full<&'static [u8]>>> {
+    println!("Serving {}", req.uri());
+    Ok(Response::new(Full::new("Hello from hyper!".as_bytes())))
+}
+
+/// Handle a new client.
+async fn handle_client(client: Async<TcpStream>, tls: Option<TlsAcceptor>) -> Result<()> {
+    // Wrap it in TLS if necessary.
+    let client = match &tls {
+        None => SmolStream::Plain(client),
+        Some(tls) => {
+            // In case of HTTPS, establish a secure TLS connection.
+            SmolStream::Tls(tls.accept(client).await?)
+        }
+    };
+
+    // Build the server.
+    hyper::server::conn::http1::Builder::new()
+        .timer(SmolTimer::new())
+        .serve_connection(FuturesIo::new(client), service_fn(serve))
+        .await?;
+
+    Ok(())
 }
 
 /// Listens for incoming connections and serves them.
-async fn listen(listener: Async<TcpListener>, tls: Option<TlsAcceptor>) -> Result<()> {
+async fn listen(
+    ex: &Arc<Executor<'static>>,
+    listener: Async<TcpListener>,
+    tls: Option<TlsAcceptor>,
+) -> Result<()> {
     // Format the full host address.
     let host = &match tls {
         None => format!("http://{}", listener.get_ref().local_addr()?),
@@ -38,86 +68,42 @@ async fn listen(listener: Async<TcpListener>, tls: Option<TlsAcceptor>) -> Resul
     };
     println!("Listening on {}", host);
 
-    // Start a hyper server.
-    Server::builder(SmolListener::new(&listener, tls))
-        .executor(SmolExecutor)
-        .serve(make_service_fn(move |_| {
-            let host = host.clone();
-            async { Ok::<_, Error>(service_fn(move |req| serve(req, host.clone()))) }
-        }))
-        .await?;
+    loop {
+        // Wait for a new client.
+        let (client, _) = listener.accept().await?;
 
-    Ok(())
+        // Spawn a task to handle this connection.
+        ex.spawn({
+            let tls = tls.clone();
+            async move {
+                if let Err(e) = handle_client(client, tls).await {
+                    println!("Error while handling client: {}", e);
+                }
+            }
+        })
+        .detach();
+    }
 }
 
-fn main() -> Result<()> {
+#[apply(main!)]
+async fn main(ex: &Arc<Executor<'static>>) -> Result<()> {
     // Initialize TLS with the local certificate, private key, and password.
     let identity = Identity::from_pkcs12(include_bytes!("identity.pfx"), "password")?;
     let tls = TlsAcceptor::from(native_tls::TlsAcceptor::new(identity)?);
 
     // Start HTTP and HTTPS servers.
-    smol::block_on(async {
-        let http = listen(Async::<TcpListener>::bind(([127, 0, 0, 1], 8000))?, None);
-        let https = listen(
-            Async::<TcpListener>::bind(([127, 0, 0, 1], 8001))?,
-            Some(tls),
-        );
-        future::try_zip(http, https).await?;
-        Ok(())
-    })
-}
-
-/// Spawns futures.
-#[derive(Clone)]
-struct SmolExecutor;
-
-impl<F: Future + Send + 'static> hyper::rt::Executor<F> for SmolExecutor {
-    fn execute(&self, fut: F) {
-        smol::spawn(async { drop(fut.await) }).detach();
-    }
-}
-
-/// Listens for incoming connections.
-struct SmolListener<'a> {
-    tls: Option<TlsAcceptor>,
-    incoming: Pin<Box<dyn Stream<Item = io::Result<Async<TcpStream>>> + Send + 'a>>,
-}
-
-impl<'a> SmolListener<'a> {
-    fn new(listener: &'a Async<TcpListener>, tls: Option<TlsAcceptor>) -> Self {
-        Self {
-            incoming: Box::pin(listener.incoming()),
-            tls,
-        }
-    }
-}
-
-impl hyper::server::accept::Accept for SmolListener<'_> {
-    type Conn = SmolStream;
-    type Error = Error;
-
-    fn poll_accept(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context,
-    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
-        let stream = smol::ready!(self.incoming.as_mut().poll_next(cx)).unwrap()?;
-
-        let stream = match &self.tls {
-            None => SmolStream::Plain(stream),
-            Some(tls) => {
-                // In case of HTTPS, start establishing a secure TLS connection.
-                let tls = tls.clone();
-                SmolStream::Handshake(Box::pin(async move {
-                    tls.accept(stream).await.map_err(|err| {
-                        println!("Failed to establish secure TLS connection: {:#?}", err);
-                        io::Error::new(io::ErrorKind::Other, Box::new(err))
-                    })
-                }))
-            }
-        };
-
-        Poll::Ready(Some(Ok(stream)))
-    }
+    let http = listen(
+        ex,
+        Async::<TcpListener>::bind(([127, 0, 0, 1], 8000))?,
+        None,
+    );
+    let https = listen(
+        ex,
+        Async::<TcpListener>::bind(([127, 0, 0, 1], 8001))?,
+        Some(tls),
+    );
+    future::try_zip(http, https).await?;
+    Ok(())
 }
 
 /// A TCP or TCP+TLS connection.
@@ -127,83 +113,44 @@ enum SmolStream {
 
     /// A TCP connection secured by TLS.
     Tls(TlsStream<Async<TcpStream>>),
-
-    /// A TCP connection that is in process of getting secured by TLS.
-    #[allow(clippy::type_complexity)]
-    Handshake(Pin<Box<dyn Future<Output = io::Result<TlsStream<Async<TcpStream>>>> + Send>>),
-}
-
-impl hyper::client::connect::Connection for SmolStream {
-    fn connected(&self) -> hyper::client::connect::Connected {
-        hyper::client::connect::Connected::new()
-    }
 }
 
-impl tokio::io::AsyncRead for SmolStream {
+impl AsyncRead for SmolStream {
     fn poll_read(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        buf: &mut tokio::io::ReadBuf<'_>,
-    ) -> Poll<io::Result<()>> {
-        loop {
-            match &mut *self {
-                SmolStream::Plain(s) => {
-                    return Pin::new(s)
-                        .poll_read(cx, buf.initialize_unfilled())
-                        .map_ok(|size| {
-                            buf.advance(size);
-                        });
-                }
-                SmolStream::Tls(s) => {
-                    return Pin::new(s)
-                        .poll_read(cx, buf.initialize_unfilled())
-                        .map_ok(|size| {
-                            buf.advance(size);
-                        });
-                }
-                SmolStream::Handshake(f) => {
-                    let s = smol::ready!(f.as_mut().poll(cx))?;
-                    *self = SmolStream::Tls(s);
-                }
-            }
+        buf: &mut [u8],
+    ) -> Poll<io::Result<usize>> {
+        match &mut *self {
+            Self::Plain(s) => Pin::new(s).poll_read(cx, buf),
+            Self::Tls(s) => Pin::new(s).poll_read(cx, buf),
         }
     }
 }
 
-impl tokio::io::AsyncWrite for SmolStream {
+impl AsyncWrite for SmolStream {
     fn poll_write(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
         buf: &[u8],
     ) -> Poll<io::Result<usize>> {
-        loop {
-            match &mut *self {
-                SmolStream::Plain(s) => return Pin::new(s).poll_write(cx, buf),
-                SmolStream::Tls(s) => return Pin::new(s).poll_write(cx, buf),
-                SmolStream::Handshake(f) => {
-                    let s = smol::ready!(f.as_mut().poll(cx))?;
-                    *self = SmolStream::Tls(s);
-                }
-            }
+        match &mut *self {
+            Self::Plain(s) => Pin::new(s).poll_write(cx, buf),
+            Self::Tls(s) => Pin::new(s).poll_write(cx, buf),
         }
     }
 
-    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
         match &mut *self {
-            SmolStream::Plain(s) => Pin::new(s).poll_flush(cx),
-            SmolStream::Tls(s) => Pin::new(s).poll_flush(cx),
-            SmolStream::Handshake(_) => Poll::Ready(Ok(())),
+            Self::Plain(s) => Pin::new(s).poll_close(cx),
+            Self::Tls(s) => Pin::new(s).poll_close(cx),
         }
     }
 
-    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
         match &mut *self {
-            SmolStream::Plain(s) => {
-                s.get_ref().shutdown(Shutdown::Write)?;
-                Poll::Ready(Ok(()))
-            }
-            SmolStream::Tls(s) => Pin::new(s).poll_close(cx),
-            SmolStream::Handshake(_) => Poll::Ready(Ok(())),
+            Self::Plain(s) => Pin::new(s).poll_close(cx),
+            Self::Tls(s) => Pin::new(s).poll_close(cx),
         }
     }
 }