From 0a9eee26b08b2ed4d5019c761e9ed4171f597384 Mon Sep 17 00:00:00 2001 From: timofey Date: Sun, 10 Mar 2024 01:29:42 +0000 Subject: [PATCH] CloseAll --- Cargo.toml | 2 +- src/close_all.rs | 163 +++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 3 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 src/close_all.rs diff --git a/Cargo.toml b/Cargo.toml index 2cc0ef3..adbbd83 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ members = ["internal/ruchei-sample"] [workspace.package] -version = "0.0.71" # ad7038ef3b571dc133c108e14e6bb0f8cdcd812d and earlier have invalid versions +version = "0.0.72" # ad7038ef3b571dc133c108e14e6bb0f8cdcd812d and earlier have invalid versions edition = "2021" publish = true license = "MIT OR Apache-2.0" diff --git a/src/close_all.rs b/src/close_all.rs new file mode 100644 index 0000000..4020ea1 --- /dev/null +++ b/src/close_all.rs @@ -0,0 +1,163 @@ +use std::{ + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use futures_util::{ + future::FusedFuture, + lock::{Mutex, OwnedMutexGuard, OwnedMutexLockFuture}, + ready, + stream::FusedStream, + Future, Sink, Stream, +}; +use pin_project::pin_project; + +struct Closed; + +/// Yielded by [`CloseAll`]. Gets closed when incoming stream terminates. +#[pin_project] +pub struct CloseOne { + #[pin] + stream: S, + #[pin] + closing: OwnedMutexLockFuture, + terminated: bool, + _out: PhantomData, +} + +impl> + Sink> Stream + for CloseOne +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if *this.terminated { + Poll::Ready(None) + } else if this.closing.is_terminated() || this.closing.poll(cx).is_ready() { + let r = ready!(this.stream.poll_close(cx)); + *this.terminated = true; + r?; + Poll::Ready(None) + } else { + Poll::Ready(match ready!(this.stream.poll_next(cx)) { + Some(Ok(item)) => Some(Ok(item)), + Some(Err(e)) => { + *this.terminated = true; + Some(Err(e)) + } + None => { + *this.terminated = true; + None + } + }) + } + } +} + +impl> + Sink> FusedStream + for CloseOne +{ + fn is_terminated(&self) -> bool { + self.terminated + } +} + +impl> + Sink> Sink + for CloseOne +{ + type Error = E; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if this.closing.is_terminated() { + Poll::Pending + } else { + this.stream.poll_ready(cx) + } + } + + fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> { + let this = self.project(); + if this.closing.is_terminated() { + Ok(()) + } else { + this.stream.start_send(item) + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if this.closing.is_terminated() { + Poll::Pending + } else { + this.stream.poll_flush(cx) + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if *this.terminated { + Poll::Ready(Ok(())) + } else { + let r = ready!(this.stream.poll_close(cx)); + *this.terminated = true; + Poll::Ready(r) + } + } +} + +/// Closes all yielded streams ([`CloseOne`]s) on termination of incoming stream. +#[pin_project] +pub struct CloseAll { + #[pin] + stream: R, + guard: Option>, + lock: Arc>, + _out: PhantomData, +} + +impl> Stream for CloseAll { + type Item = CloseOne; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + Poll::Ready(match ready!(this.stream.poll_next(cx)) { + Some(stream) => Some(CloseOne { + stream, + closing: this.lock.clone().lock_owned(), + terminated: false, + _out: PhantomData, + }), + None => { + this.guard.take(); + None + } + }) + } +} + +impl> FusedStream for CloseAll { + fn is_terminated(&self) -> bool { + self.guard.is_none() + } +} + +pub trait CloseAllExt: Sized { + fn close_all(self) -> CloseAll; +} + +impl, R: Stream> CloseAllExt for R { + fn close_all(self) -> CloseAll { + let lock = Arc::new(Mutex::new(Closed)); + let guard = lock.clone().try_lock_owned().unwrap(); + CloseAll { + stream: self, + guard: Some(guard), + lock, + _out: PhantomData, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 62da7a6..ddb7088 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ extern crate self as ruchei; pub mod callback; +pub mod close_all; pub mod concurrent; pub mod echo; pub mod group_by_key;