From a53f05930a65c1a7399bffdbf6b77d55feebcbe8 Mon Sep 17 00:00:00 2001 From: Matthias Seitz Date: Wed, 4 Oct 2023 00:54:39 +0200 Subject: [PATCH] fix: always try to flush after ready --- examples/simple-google.rs | 29 ++++++++++++++++ src/conn.rs | 69 ++++++++++++++++++++------------------- src/listeners.rs | 2 +- 3 files changed, 65 insertions(+), 35 deletions(-) create mode 100644 examples/simple-google.rs diff --git a/examples/simple-google.rs b/examples/simple-google.rs new file mode 100644 index 00000000..83de9456 --- /dev/null +++ b/examples/simple-google.rs @@ -0,0 +1,29 @@ +use chromiumoxide::browser::BrowserConfigBuilder; +use chromiumoxide::Browser; +use futures::StreamExt; +use std::time::Duration; +use tokio::task; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + let (browser, mut handler) = Browser::launch( + BrowserConfigBuilder::default() + .request_timeout(Duration::from_secs(5)) + .build() + .unwrap(), + ) + .await + .unwrap(); + + let h = task::spawn(async move { + while let Some(h) = handler.next().await { + h.unwrap(); + } + }); + + let page = browser.new_page("https://www.google.com").await.unwrap(); + + println!("loaded page {:?}", page); + h.await.unwrap(); +} diff --git a/src/conn.rs b/src/conn.rs index ff5312b9..46d04b5d 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -1,11 +1,12 @@ use std::collections::VecDeque; use std::marker::PhantomData; use std::pin::Pin; +use std::task::ready; use async_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream}; use futures::stream::Stream; use futures::task::{Context, Poll}; -use futures::Sink; +use futures::{SinkExt, StreamExt}; use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId; use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId}; @@ -93,19 +94,15 @@ impl Connection { /// sink fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> { if self.needs_flush { - if let Poll::Ready(Ok(())) = Sink::poll_flush(Pin::new(&mut self.ws), cx) { + if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) { self.needs_flush = false; } } if self.pending_flush.is_none() && !self.needs_flush { if let Some(cmd) = self.pending_commands.pop_front() { - // if cmd.id.to_string().contains("1") { - // log::error!("CMD {:?}", cmd); - // return Ok(()) - // } tracing::trace!("Sending {:?}", cmd); let msg = serde_json::to_string(&cmd)?; - Sink::start_send(Pin::new(&mut self.ws), msg.into())?; + self.ws.start_send_unpin(msg.into())?; self.pending_flush = Some(cmd); } } @@ -119,38 +116,42 @@ impl Stream for Connection { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let pin = self.get_mut(); - // queue in the next message if not currently flushing - if let Err(err) = pin.start_send_next(cx) { - return Poll::Ready(Some(Err(err))); - } + loop { + // queue in the next message if not currently flushing + if let Err(err) = pin.start_send_next(cx) { + return Poll::Ready(Some(Err(err))); + } - // send the message - if let Some(call) = pin.pending_flush.take() { - if Sink::poll_ready(Pin::new(&mut pin.ws), cx).is_ready() { - pin.needs_flush = true; - } else { - pin.pending_flush = Some(call); + // send the message + if let Some(call) = pin.pending_flush.take() { + if pin.ws.poll_ready_unpin(cx).is_ready() { + pin.needs_flush = true; + // try another flush + continue; + } else { + pin.pending_flush = Some(call); + } } + break; } + // read from the ws - match Stream::poll_next(Pin::new(&mut pin.ws), cx) { - Poll::Ready(Some(Ok(msg))) => { - return match serde_json::from_slice::>(&msg.into_data()) { - Ok(msg) => { - tracing::trace!("Received {:?}", msg); - Poll::Ready(Some(Ok(msg))) - } - Err(err) => { - tracing::error!("Failed to deserialize WS response {}", err); - Poll::Ready(Some(Err(err.into()))) - } - }; - } - Poll::Ready(Some(Err(err))) => { - return Poll::Ready(Some(Err(CdpError::Ws(err)))); + match ready!(pin.ws.poll_next_unpin(cx)) { + Some(Ok(msg)) => match serde_json::from_slice::>(&msg.into_data()) { + Ok(msg) => { + tracing::trace!("Received {:?}", msg); + Poll::Ready(Some(Ok(msg))) + } + Err(err) => { + tracing::error!("Failed to deserialize WS response {}", err); + Poll::Ready(Some(Err(err.into()))) + } + }, + Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))), + None => { + // ws connection closed + Poll::Ready(None) } - _ => {} } - Poll::Pending } } diff --git a/src/listeners.rs b/src/listeners.rs index 6dcd3e39..ac482d48 100644 --- a/src/listeners.rs +++ b/src/listeners.rs @@ -26,7 +26,7 @@ impl EventListeners { method, kind, } = req; - let subs = self.listeners.entry(method).or_insert_with(Vec::new); + let subs = self.listeners.entry(method).or_default(); subs.push(EventListener { listener, kind,