Skip to content

Commit

Permalink
chore: Improved WebSocket protocols handler (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x676e67 authored Jan 26, 2025
1 parent f315e48 commit 2abe066
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 59 deletions.
2 changes: 1 addition & 1 deletion examples/base_url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use rquest::{Client, Impersonate};
#[tokio::main]
async fn main() -> Result<(), rquest::Error> {
env_logger::init_from_env(env_logger::Env::default().default_filter_or("debug"));

// Build a client to impersonate Edge131
let mut client = Client::builder()
.impersonate(Impersonate::Edge131)
Expand Down
2 changes: 1 addition & 1 deletion examples/headers_order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const HEADER_ORDER: &[HeaderName] = &[
#[tokio::main]
async fn main() -> Result<(), rquest::Error> {
env_logger::init_from_env(env_logger::Env::default().default_filter_or("debug"));

// Build a client to impersonate Chrome131
let client = rquest::Client::builder()
.impersonate(Impersonate::Chrome131)
Expand Down
2 changes: 1 addition & 1 deletion examples/set_cookie_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::sync::Arc;
#[tokio::main]
async fn main() -> Result<(), rquest::Error> {
env_logger::init_from_env(env_logger::Env::default().default_filter_or("debug"));

// Build a client to impersonate Chrome131
let mut client = rquest::Client::builder()
.impersonate(Impersonate::Chrome131)
Expand Down
2 changes: 1 addition & 1 deletion examples/set_cookies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rquest::Impersonate;
#[tokio::main]
async fn main() -> Result<(), rquest::Error> {
env_logger::init_from_env(env_logger::Env::default().default_filter_or("debug"));

// Build a client to impersonate Chrome131
let client = rquest::Client::builder()
.impersonate(Impersonate::Chrome131)
Expand Down
108 changes: 53 additions & 55 deletions src/client/websocket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
use crate::{Error, Response};
use async_tungstenite::tungstenite;
use futures_util::{Sink, SinkExt, Stream, StreamExt};
use http::{header, uri::Scheme, HeaderValue, StatusCode, Version};
use http::{header, uri::Scheme, HeaderMap, HeaderName, HeaderValue, StatusCode, Version};
pub use message::{CloseCode, Message};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tungstenite::protocol::WebSocketConfig;
Expand All @@ -30,7 +30,7 @@ pub type WebSocketStream =
pub struct WebSocketRequestBuilder {
inner: RequestBuilder,
nonce: Option<Cow<'static, str>>,
protocols: Option<Cow<'static, [String]>>,
protocols: Option<Vec<Cow<'static, str>>>,
config: WebSocketConfig,
}

Expand Down Expand Up @@ -92,14 +92,16 @@ impl WebSocketRequestBuilder {
///
/// ```
/// let request = WebSocketRequestBuilder::new(builder)
/// .protocols(vec!["protocol1".to_string(), "protocol2".to_string()])
/// .protocols(["protocol1", "protocol2"])
/// .build();
/// ```
pub fn protocols<P>(mut self, protocols: P) -> Self
where
P: Into<Cow<'static, [String]>>,
P: IntoIterator,
P::Item: Into<Cow<'static, str>>,
{
self.protocols = Some(protocols.into());
let protocols = protocols.into_iter().map(Into::into).collect();
self.protocols = Some(protocols);
self
}

Expand Down Expand Up @@ -216,7 +218,7 @@ impl WebSocketRequestBuilder {
if !protocols.is_empty() {
let subprotocols = protocols
.iter()
.map(|s| s.as_str())
.map(|s| s.as_ref())
.collect::<Vec<&str>>()
.join(", ");

Expand Down Expand Up @@ -249,7 +251,7 @@ impl WebSocketRequestBuilder {
pub struct WebSocketResponse {
inner: Response,
nonce: Cow<'static, str>,
protocols: Option<Cow<'static, [String]>>,
protocols: Option<Vec<Cow<'static, str>>>,
config: WebSocketConfig,
}

Expand All @@ -274,76 +276,50 @@ impl WebSocketResponse {
let (inner, protocol) = {
let headers = self.inner.headers();

// Check the version
if !matches!(self.inner.version(), Version::HTTP_11 | Version::HTTP_10) {
return Err(Error::new(
Kind::Upgrade,
Some(format!("unexpected version: {:?}", self.inner.version())),
));
}

// Check the status code
if self.inner.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::new(
Kind::Upgrade,
Some(format!("unexpected status code: {}", self.inner.status())),
));
}

// Check the connection header
if let Some(header) = headers.get(header::CONNECTION) {
if !header
.to_str()
.is_ok_and(|s| s.eq_ignore_ascii_case("upgrade"))
{
log::debug!("server responded with invalid Connection header: {header:?}");
return Err(Error::new(
Kind::Upgrade,
Some(format!("invalid connection header: {:?}", header)),
));
}
} else {
if !header_contains(self.inner.headers(), header::CONNECTION, "upgrade") {
log::debug!("missing Connection header");
return Err(Error::new(Kind::Upgrade, Some("missing connection header")));
}

// Check the upgrade header
if let Some(header) = headers.get(header::UPGRADE) {
if !header
.to_str()
.is_ok_and(|s| s.eq_ignore_ascii_case("websocket"))
{
log::debug!("server responded with invalid Upgrade header: {header:?}");
return Err(Error::new(
Kind::Upgrade,
Some(format!("invalid upgrade header: {:?}", header)),
));
}
} else {
log::debug!("missing Upgrade header");
return Err(Error::new(Kind::Upgrade, Some("missing upgrade header")));
if !header_eq(self.inner.headers(), header::UPGRADE, "websocket") {
log::debug!("server responded with invalid Upgrade header");
return Err(Error::new(Kind::Upgrade, Some("invalid upgrade header")));
}

// Check the accept key
if let Some(header) = headers.get(header::SEC_WEBSOCKET_ACCEPT) {
// Check the accept key
let expected_nonce =
tungstenite::handshake::derive_accept_key(self.nonce.as_bytes());
if !header.to_str().is_ok_and(|s| s == expected_nonce) {
log::debug!(
"server responded with invalid Sec-Websocket-Accept header: {header:?}"
);
return Err(Error::new(
Kind::Upgrade,
Some(format!("invalid accept key: {:?}", header)),
));
match headers.get(header::SEC_WEBSOCKET_ACCEPT) {
Some(header) => {
if !header.to_str().is_ok_and(|s| {
s == tungstenite::handshake::derive_accept_key(self.nonce.as_bytes())
}) {
log::debug!(
"server responded with invalid Sec-Websocket-Accept header: {header:?}"
);
return Err(Error::new(
Kind::Upgrade,
Some(format!("invalid accept key: {:?}", header)),
));
}
}
None => {
log::debug!("missing Sec-Websocket-Accept header");
return Err(Error::new(Kind::Upgrade, Some("missing accept key")));
}
} else {
log::debug!("missing Sec-Websocket-Accept header");
return Err(Error::new(Kind::Upgrade, Some("missing accept key")));
}

// Ensure the server responded with the requested protocol
let protocol = headers
.get(header::SEC_WEBSOCKET_PROTOCOL)
.and_then(|v| v.to_str().ok())
Expand All @@ -369,7 +345,7 @@ impl WebSocketResponse {
}
(false, Some(protocol)) => {
if let Some(ref protocols) = self.protocols {
if !protocols.contains(protocol) {
if protocols.iter().find(|p| *p == protocol).is_none() {
// the responded protocol is none which we requested
return Err(Error::new(
Kind::Status(self.res.status()),
Expand Down Expand Up @@ -407,6 +383,28 @@ impl WebSocketResponse {
}
}

fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = headers.get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else {
false
}
}

fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = headers.get(&key) {
header
} else {
return false;
};

if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
header.to_ascii_lowercase().contains(value)
} else {
false
}
}

/// A websocket connection
#[derive(Debug)]
pub struct WebSocket {
Expand Down

0 comments on commit 2abe066

Please sign in to comment.