From 3fc9a79dab56eb802614bee967426f9d477e5c96 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Mon, 30 Oct 2023 22:46:58 +0900 Subject: [PATCH 01/37] Add: `websocket` feature --- ohkami/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index c793618d..5d720891 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -28,6 +28,7 @@ byte_reader = "1.1.2" [features] rt_tokio = ["dep:tokio"] rt_async-std = ["dep:async-std"] +websocket = [] nightly = [] ##### DEBUG ##### From 016f9963f3290ad78ecd594f498cd1d3fe1a1d59 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Tue, 31 Oct 2023 18:28:02 +0900 Subject: [PATCH 02/37] Add: draft section `web socket` in README --- README.md | 34 ++++++++++++++++++++++ ohkami/Cargo.toml | 4 +-- ohkami/src/layer1_req_res/mod.rs | 5 +++- ohkami/src/layer1_req_res/websocket/mod.rs | 0 4 files changed, 40 insertions(+), 3 deletions(-) create mode 100644 ohkami/src/layer1_req_res/websocket/mod.rs diff --git a/README.md b/README.md index 714aabef..f5e85bbc 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,40 @@ async fn main() {
+### web socket +Activate `websocket` feature. + +```rust +use ohkami::prelude::*; +use ohkami::websocket::{WebSocket, Message}; + +async fn handle_websocket(ws: WebSocket) { + while let Some(Ok(message)) = ws.recv().await { + match message { + Message::Text(text) => { + let response = Message::from(text); + if let Err(e) = ws.send(response).await { + tracing::error!("{e}"); + break + } + } + Message::Close(_) => break, + other => tracing::warning!("Unsupported message type: {other}"), + } + } +} + +#[tokio::main] +async fn main() { + Ohkami::new(( + "/websocket" + .GET(handle_websocket) + )).howl(8080).await +} +``` + +
+ ### testing ```rust use ohkami::prelude::*; diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 5d720891..1d42b3ec 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -32,5 +32,5 @@ websocket = [] nightly = [] ##### DEBUG ##### -DEBUG = ["serde/derive", "tokio?/macros", "async-std?/attributes"] -# default = ["rt_tokio", "DEBUG"] \ No newline at end of file +DEBUG = ["websocket", "serde/derive", "tokio?/macros", "async-std?/attributes"] +default = ["rt_tokio", "DEBUG"] \ No newline at end of file diff --git a/ohkami/src/layer1_req_res/mod.rs b/ohkami/src/layer1_req_res/mod.rs index 011e3cc0..11440ac2 100644 --- a/ohkami/src/layer1_req_res/mod.rs +++ b/ohkami/src/layer1_req_res/mod.rs @@ -1,6 +1,9 @@ -mod request; pub use request::*; +mod request; pub use request::*; mod response; pub use response::*; +#[cfg(feature="websocket")] mod websocket; +#[cfg(feature="websocket")] pub use websocket::*; + #[cfg(test)] #[allow(unused)] mod __ { use serde::Serialize; diff --git a/ohkami/src/layer1_req_res/websocket/mod.rs b/ohkami/src/layer1_req_res/websocket/mod.rs new file mode 100644 index 00000000..e69de29b From e6098bdbde4e5fd66b79f7730568ea44efc65a71 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Tue, 31 Oct 2023 18:49:04 +0900 Subject: [PATCH 03/37] @2023-10-31 18:48+9:00 --- ohkami/Cargo.toml | 3 ++- ohkami/src/layer1_req_res/mod.rs | 3 --- ohkami/src/layer1_req_res/websocket/mod.rs | 0 ohkami/src/lib.rs | 8 +++++++ ohkami/src/x_websocket/message.rs | 26 ++++++++++++++++++++++ ohkami/src/x_websocket/mod.rs | 1 + 6 files changed, 37 insertions(+), 4 deletions(-) delete mode 100644 ohkami/src/layer1_req_res/websocket/mod.rs create mode 100644 ohkami/src/x_websocket/message.rs create mode 100644 ohkami/src/x_websocket/mod.rs diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 1d42b3ec..96e25466 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -24,11 +24,12 @@ serde_json = "1.0" chrono = "0.4" percent-encoding = "2.2.0" byte_reader = "1.1.2" +tungstenite = { version = "0.20", optional = true, default-features = false } [features] rt_tokio = ["dep:tokio"] rt_async-std = ["dep:async-std"] -websocket = [] +websocket = ["dep:tungstenite"] nightly = [] ##### DEBUG ##### diff --git a/ohkami/src/layer1_req_res/mod.rs b/ohkami/src/layer1_req_res/mod.rs index 11440ac2..ec95a6cf 100644 --- a/ohkami/src/layer1_req_res/mod.rs +++ b/ohkami/src/layer1_req_res/mod.rs @@ -1,9 +1,6 @@ mod request; pub use request::*; mod response; pub use response::*; -#[cfg(feature="websocket")] mod websocket; -#[cfg(feature="websocket")] pub use websocket::*; - #[cfg(test)] #[allow(unused)] mod __ { use serde::Serialize; diff --git a/ohkami/src/layer1_req_res/websocket/mod.rs b/ohkami/src/layer1_req_res/websocket/mod.rs deleted file mode 100644 index e69de29b..00000000 diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 7056f66f..b69dae03 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -255,6 +255,9 @@ mod layer5_ohkami; #[cfg(test)] mod layer6_testing; +#[cfg(feature="websocket")] +mod x_websocket; + /*===== visibility managements =====*/ @@ -282,6 +285,11 @@ pub mod testing { pub use crate::layer6_testing::*; } +#[cfg(feature="websocket")] +pub mod websocket { + pub use crate::x_websocket::*; +} + #[doc(hidden)] pub mod __internal__ { pub use crate::layer1_req_res::{ diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs new file mode 100644 index 00000000..c32e6a97 --- /dev/null +++ b/ohkami/src/x_websocket/message.rs @@ -0,0 +1,26 @@ +use std::borrow::Cow; + + +pub enum Message { + Text (String), + Binary(Vec), + Ping (PingPongFrame), + Pong (PingPongFrame), + Close (Option), +} +pub struct PingPongFrame { + buf: [u8; 125], + len: usize/* less than 125 */ +} +pub struct CloseFrame { + pub code: u16, + pub reason: Cow<'static, str>, +} + +impl Message { + fn into_tungstenite(self) -> tungstenite::Message { + match self { + + } + } +} diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs new file mode 100644 index 00000000..e935b02c --- /dev/null +++ b/ohkami/src/x_websocket/mod.rs @@ -0,0 +1 @@ +mod message; From 23cb4982b7128a8079bd792821ee1a82cf681a96 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Wed, 1 Nov 2023 15:57:50 +0900 Subject: [PATCH 04/37] @2023-11-01 15:57+9:00 --- README.md | 28 +++++----- ohkami/Cargo.toml | 3 +- ohkami/src/x_websocket/context.rs | 86 +++++++++++++++++++++++++++++++ ohkami/src/x_websocket/message.rs | 8 --- ohkami/src/x_websocket/mod.rs | 6 +++ 5 files changed, 108 insertions(+), 23 deletions(-) create mode 100644 ohkami/src/x_websocket/context.rs diff --git a/README.md b/README.md index f5e85bbc..2e4adf41 100644 --- a/README.md +++ b/README.md @@ -189,22 +189,24 @@ Activate `websocket` feature. ```rust use ohkami::prelude::*; -use ohkami::websocket::{WebSocket, Message}; - -async fn handle_websocket(ws: WebSocket) { - while let Some(Ok(message)) = ws.recv().await { - match message { - Message::Text(text) => { - let response = Message::from(text); - if let Err(e) = ws.send(response).await { - tracing::error!("{e}"); - break +use ohkami::websocket::{WebSocketContext, Message}; + +fn handle_websocket(c: WebSocketContext) -> Response { + c.on_upgrade(|ws| async move { + while let Some(Ok(message)) = ws.recv().await { + match message { + Message::Text(text) => { + let response = Message::from(text); + if let Err(e) = ws.send(response).await { + tracing::error!("{e}"); + break + } } + Message::Close(_) => break, + other => tracing::warning!("Unsupported message type: {other}"), } - Message::Close(_) => break, - other => tracing::warning!("Unsupported message type: {other}"), } - } + }).await } #[tokio::main] diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 96e25466..1d42b3ec 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -24,12 +24,11 @@ serde_json = "1.0" chrono = "0.4" percent-encoding = "2.2.0" byte_reader = "1.1.2" -tungstenite = { version = "0.20", optional = true, default-features = false } [features] rt_tokio = ["dep:tokio"] rt_async-std = ["dep:async-std"] -websocket = ["dep:tungstenite"] +websocket = [] nightly = [] ##### DEBUG ##### diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs new file mode 100644 index 00000000..255f95df --- /dev/null +++ b/ohkami/src/x_websocket/context.rs @@ -0,0 +1,86 @@ +use std::{future::Future, borrow::Cow}; +use super::{WebSocket}; +use crate::{Response}; + + +pub struct WebSocketContext { + config: Config, + + on_failed_upgrade: FU, + + selected_protocol: Option>, + sec_websocket_key: Cow<'static, str>, + sec_websocket_protocol: Option>, +} +pub struct Config { + write_buffer_size: usize, + max_write_buffer_size: usize, + max_message_size: Option, + max_frame_size: Option, + accept_unmasked_frames: bool, +} const _: () = { + impl Default for Config { + fn default() -> Self { + Self { + write_buffer_size: 128 * 1024, // 128 KiB + max_write_buffer_size: usize::MAX, + max_message_size: Some(64 << 20), + max_frame_size: Some(16 << 20), + accept_unmasked_frames: false, + } + } + } +}; +pub trait OnFailedUpgrade: Send + 'static { + fn handle(self, error: UpgradeError); +} +pub struct UpgradeError { /* TODO */ } +pub struct DefaultOnFailedUpgrade; const _: () = { + impl OnFailedUpgrade for DefaultOnFailedUpgrade { + fn handle(self, _: UpgradeError) { /* DO NOTHING (discard error) */ } + } +}; + +impl WebSocketContext { + pub fn write_buffer_size(mut self, size: usize) -> Self { + self.config.write_buffer_size = size; + self + } + pub fn max_write_buffer_size(mut self, size: usize) -> Self { + self.config.max_write_buffer_size = size; + self + } + pub fn max_message_size(mut self, size: usize) -> Self { + self.config.max_message_size = Some(size); + self + } + pub fn max_frame_size(mut self, size: usize) -> Self { + self.config.max_frame_size = Some(size); + self + } + pub fn accept_unmasked_frames(mut self) -> Self { + self.config.accept_unmasked_frames = true; + self + } +} + +impl WebSocketContext { + pub fn protocols>>(mut self, protocols: impl Iterator) -> Self { + if let Some(req_protocols) = &self.sec_websocket_protocol { + self.selected_protocol = protocols.map(Into::into) + .find(|p| req_protocols.split(',').any(|req_p| req_p.trim() == p)) + } + self + } +} + +impl WebSocketContext { + pub fn on_upgrade< + Fut: Future + Send + 'static, + >( + self, + callback: impl Fn(WebSocket) -> Fut + Send + 'static + ) -> Response { + todo!() + } +} diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs index c32e6a97..f8510b70 100644 --- a/ohkami/src/x_websocket/message.rs +++ b/ohkami/src/x_websocket/message.rs @@ -16,11 +16,3 @@ pub struct CloseFrame { pub code: u16, pub reason: Cow<'static, str>, } - -impl Message { - fn into_tungstenite(self) -> tungstenite::Message { - match self { - - } - } -} diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index e935b02c..f6fe00b8 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -1 +1,7 @@ +mod context; mod message; + + +pub struct WebSocket { + +} From 058cc7bc5aee16d62d328a59c83eab8bcb7dd99a Mon Sep 17 00:00:00 2001 From: kana-rus Date: Thu, 2 Nov 2023 00:46:44 +0900 Subject: [PATCH 05/37] @2023-11-02 00:46+9:00 --- ohkami/src/layer0_lib/status.rs | 7 +++ ohkami/src/layer0_lib/string.rs | 1 + ohkami/src/layer1_req_res/response/headers.rs | 39 +++++++----- ohkami/src/layer2_context/mod.rs | 60 +++++++------------ ohkami/src/x_websocket/context.rs | 43 ++++++++++++- 5 files changed, 94 insertions(+), 56 deletions(-) diff --git a/ohkami/src/layer0_lib/status.rs b/ohkami/src/layer0_lib/status.rs index 8cbc04ad..0550e6dc 100644 --- a/ohkami/src/layer0_lib/status.rs +++ b/ohkami/src/layer0_lib/status.rs @@ -1,19 +1,26 @@ #[derive(PartialEq)] pub enum Status { + SwitchingProtocols, + OK, Created, NoContent, + MovedPermanently, Found, + BadRequest, Unauthorized, Forbidden, NotFound, + InternalServerError, NotImplemented, } impl Status { #[inline(always)] pub(crate) const fn as_str(&self) -> &'static str { match self { + Self::SwitchingProtocols => "101 Switching Protocols", + Self::OK => "200 OK", Self::Created => "201 Created", Self::NoContent => "204 No Content", diff --git a/ohkami/src/layer0_lib/string.rs b/ohkami/src/layer0_lib/string.rs index 66e4a45e..2a9fd8df 100644 --- a/ohkami/src/layer0_lib/string.rs +++ b/ohkami/src/layer0_lib/string.rs @@ -14,3 +14,4 @@ pub trait IntoCows<'l> { } impl IntoCows<'static> for &'static str {fn into_cow(self) -> Cow<'static, str> {Cow::Borrowed(self)}} impl IntoCows<'static> for String {fn into_cow(self) -> Cow<'static, str> {Cow::Owned(self)}} +impl IntoCows<'static> for Cow<'static, str> {fn into_cow(self) -> Cow<'static, str> {self}} diff --git a/ohkami/src/layer1_req_res/response/headers.rs b/ohkami/src/layer1_req_res/response/headers.rs index 00ad1c3e..564e2104 100644 --- a/ohkami/src/layer1_req_res/response/headers.rs +++ b/ohkami/src/layer1_req_res/response/headers.rs @@ -2,18 +2,26 @@ #![allow(non_snake_case)] #![allow(unused)] // until .... -use std::{collections::BTreeMap, sync::OnceLock}; -use crate::{layer0_lib::now}; +use std::{collections::BTreeMap, sync::OnceLock, borrow::Cow}; +use crate::{layer0_lib::{now, IntoCows}}; -struct Header(Option<&'static str>); +struct Header(Option>); pub trait HeaderValue { - fn into_header_value(self) -> Option<&'static str>; -} -impl HeaderValue for &'static str {fn into_header_value(self) -> Option<&'static str> {Some(self)}} -impl HeaderValue for Option<&'static str> {fn into_header_value(self) -> Option<&'static str> {self}} - + fn into_header_value(self) -> Option>; +} const _: () = { + impl> HeaderValue for S { + fn into_header_value(self) -> Option> { + Some(self.into_cow()) + } + } + impl HeaderValue for Option<()> { + fn into_header_value(self) -> Option> { + None + } + } +}; macro_rules! ResponseHeaders { ($( @@ -23,18 +31,19 @@ macro_rules! ResponseHeaders { )*) => { /// Headers in a response. /// - /// In current version, this expects values are `&'static str` or `None`. + /// Expected values: &'static str, String, Cow<'static, str>, or `None` /// - /// - `&'static str` sets the header value to it - /// - `None` removes the header value + /// - `None` clears value of the header + /// - others set the header to thet value /// ///
/// /// - Content-Type /// - Content-Length /// - Access-Control-* + /// - headers related to WebSocket handshake /// - /// are managed by ohkami and MUST NOT be set by `.custom` ( `.custom` has to be used **ONLY** to set custom HTTP headers ) + /// are managed by ohkami and MUST NOT be set by `.custom` ( `.custom` has to be used **ONLY** to set custom HTTP headers like `X-MyApp-Data: amazing` ) pub struct ResponseHeaders { $( $group: bool, )* $($( $name: Header, )*)* @@ -77,7 +86,7 @@ macro_rules! ResponseHeaders { if self.$group { $( if let Some(value) = self.$name.0 { - h.push_str($key);h.push_str(value);h.push('\r');h.push('\n'); + h.push_str($key);h.push_str(&value);h.push('\r');h.push('\n'); } )* } @@ -142,8 +151,8 @@ impl ResponseHeaders { match value.into_header_value() { Some(value) => { self.custom.entry(key) - .and_modify(|v| *v = value) - .or_insert(value); + .and_modify(|v| *v = &value) + .or_insert(&value); self } None => { diff --git a/ohkami/src/layer2_context/mod.rs b/ohkami/src/layer2_context/mod.rs index 4ad1936a..89bb610b 100644 --- a/ohkami/src/layer2_context/mod.rs +++ b/ohkami/src/layer2_context/mod.rs @@ -82,52 +82,36 @@ impl Context { } } -impl Context { - #[inline] pub fn OK(&self) -> Response { - Response { - status: Status::OK, - headers: self.headers.to_string(), - content: None, - } - } - #[inline] pub fn Created(&self) -> Response { - Response { - status: Status::Created, - headers: self.headers.to_string(), - content: None, - } - } - #[inline] pub fn NoContent(&self) -> Response { - Response { - status: Status::NoContent, - headers: self.headers.to_string(), - content: None, - } - } -} - -macro_rules! impl_error_response { - ($( $name:ident ),*) => { +macro_rules! generate_response { + ($( $status:ident ),* $(,)?) => {$( impl Context { - $( - #[inline] pub fn $name(&self) -> Response { - Response { - status: Status::$name, - headers: self.headers.to_string(), - content: None, - } + #[inline] pub fn $status(&self) -> Response { + Response { + status: Status::$status, + headers: self.headers.to_string(), + content: None, } - )* + } } - }; -} impl_error_response!( + )*}; +} generate_response! { + SwitchingProtocols, + + OK, + Created, + NoContent, + + // MovedPermanently, + // Found, + BadRequest, Unauthorized, Forbidden, NotFound, + InternalServerError, - NotImplemented -); + NotImplemented, +} impl Context { #[inline] pub fn redirect_to(&self, location: impl AsStr) -> Response { diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index 255f95df..e4a0aa19 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,9 +1,10 @@ use std::{future::Future, borrow::Cow}; use super::{WebSocket}; -use crate::{Response}; +use crate::{Response, Context, __rt__}; pub struct WebSocketContext { + c: Context, config: Config, on_failed_upgrade: FU, @@ -12,6 +13,7 @@ pub struct WebSocketContext { sec_websocket_key: Cow<'static, str>, sec_websocket_protocol: Option>, } + pub struct Config { write_buffer_size: usize, max_write_buffer_size: usize, @@ -31,16 +33,30 @@ pub struct Config { } } }; + +pub enum UpgradeError { /* TODO */ } pub trait OnFailedUpgrade: Send + 'static { fn handle(self, error: UpgradeError); } -pub struct UpgradeError { /* TODO */ } pub struct DefaultOnFailedUpgrade; const _: () = { impl OnFailedUpgrade for DefaultOnFailedUpgrade { fn handle(self, _: UpgradeError) { /* DO NOTHING (discard error) */ } } }; + +impl WebSocketContext { + pub(crate) fn new(c: Context) -> Self { + Self {c, + config: Config::default(), + on_failed_upgrade: DefaultOnFailedUpgrade, + selected_protocol: None, + sec_websocket_key: todo!(), + sec_websocket_protocol: None, + } + } +} + impl WebSocketContext { pub fn write_buffer_size(mut self, size: usize) -> Self { self.config.write_buffer_size = size; @@ -81,6 +97,27 @@ impl WebSocketContext { self, callback: impl Fn(WebSocket) -> Fut + Send + 'static ) -> Response { - todo!() + let Self { + mut c, + config, + on_failed_upgrade, + selected_protocol, + sec_websocket_key, + sec_websocket_protocol, + } = self; + + __rt__::task::spawn(async move { + todo!() + }); + + c.headers + .custom("Connection", "Upgrade") + .custom("Upgrade", "websocket") + .custom("Sec-WebSocket-Accept", sign(sec_websocket_key.as_bytes())); + if let Some(protocol) = selected_protocol { + c.headers + .custom("Sec-WebSocket-Protocol", protocol); + } + c.SwitchingProtocols() } } From 603b860f1b385a9a624bb7ef4ee0f53b34d7fb3d Mon Sep 17 00:00:00 2001 From: kana-rus Date: Thu, 2 Nov 2023 15:01:35 +0900 Subject: [PATCH 06/37] @2023-11-02 15:01+9:00 --- ohkami/src/x_websocket/context.rs | 35 ++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index e4a0aa19..d8e6bed7 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,6 +1,7 @@ use std::{future::Future, borrow::Cow}; use super::{WebSocket}; -use crate::{Response, Context, __rt__}; +use crate::{Response, Context, __rt__, Request}; +use crate::http::{Method}; pub struct WebSocketContext { @@ -46,14 +47,34 @@ pub struct DefaultOnFailedUpgrade; const _: () = { impl WebSocketContext { - pub(crate) fn new(c: Context) -> Self { - Self {c, - config: Config::default(), + pub(crate) fn new(c: Context, req: &mut Request) -> Result> { + if req.method() != Method::GET { + return Err(Cow::Borrowed("Method is not `GET`")) + } + if req.header("Connection") != Some("upgrade") { + return Err(Cow::Borrowed("Connection header is not `upgrade`")) + } + if req.header("Upgrade") != Some("websocket") { + return Err(Cow::Borrowed("Upgrade header is not `websocket`")) + } + if req.header("Sec-WebSocket-Version") != Some("13") { + return Err(Cow::Borrowed("Sec-WebSocket-Version header is not `13`")) + } + + let sec_websocket_key = Cow::Owned(req.header("Sec-WebSocket-Key") + .ok_or(Cow::Borrowed("Sec-WebSocket-Key header is missing"))? + .to_string()); + + let sec_websocket_protocol = req.header("Sec-WebSocket-Protocol") + .map(|swp| Cow::Owned(swp.to_string())); + + Ok(Self {c, + config: Config::default(), on_failed_upgrade: DefaultOnFailedUpgrade, selected_protocol: None, - sec_websocket_key: todo!(), - sec_websocket_protocol: None, - } + sec_websocket_key, + sec_websocket_protocol, + }) } } From b068780cecfb68937c0ad2b08574a8dc99497a5a Mon Sep 17 00:00:00 2001 From: kana-rus Date: Thu, 2 Nov 2023 15:21:11 +0900 Subject: [PATCH 07/37] @2023-11-02 15:21+9:00 --- ohkami/src/layer1_req_res/response/headers.rs | 8 +++----- ohkami/src/x_websocket/context.rs | 5 +++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/ohkami/src/layer1_req_res/response/headers.rs b/ohkami/src/layer1_req_res/response/headers.rs index 564e2104..a88b26f4 100644 --- a/ohkami/src/layer1_req_res/response/headers.rs +++ b/ohkami/src/layer1_req_res/response/headers.rs @@ -47,7 +47,7 @@ macro_rules! ResponseHeaders { pub struct ResponseHeaders { $( $group: bool, )* $($( $name: Header, )*)* - custom: BTreeMap<&'static str, &'static str>, + custom: BTreeMap<&'static str, Cow<'static, str>>, cors_str: &'static str, } @@ -85,7 +85,7 @@ macro_rules! ResponseHeaders { $( if self.$group { $( - if let Some(value) = self.$name.0 { + if let Some(value) = &self.$name.0 { h.push_str($key);h.push_str(&value);h.push('\r');h.push('\n'); } )* @@ -150,9 +150,7 @@ impl ResponseHeaders { pub fn custom(&mut self, key: &'static str, value: impl HeaderValue) -> &mut Self { match value.into_header_value() { Some(value) => { - self.custom.entry(key) - .and_modify(|v| *v = &value) - .or_insert(&value); + self.custom.insert(key, value); self } None => { diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index d8e6bed7..8d211a57 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -142,3 +142,8 @@ impl WebSocketContext { c.SwitchingProtocols() } } + + +fn sign(key: &[u8]) -> String { + todo!() +} From 4cd17dd2f3e7303d0f849bf6c295facb78648a76 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 3 Nov 2023 00:38:45 +0900 Subject: [PATCH 08/37] @2023-11-03 00:38+9:00 --- ohkami/src/x_websocket/mod.rs | 3 +- ohkami/src/x_websocket/sign.rs | 82 ++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 ohkami/src/x_websocket/sign.rs diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index f6fe00b8..c89082fd 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -1,7 +1,8 @@ mod context; mod message; +mod sign; pub struct WebSocket { - + } diff --git a/ohkami/src/x_websocket/sign.rs b/ohkami/src/x_websocket/sign.rs new file mode 100644 index 00000000..15c9dca5 --- /dev/null +++ b/ohkami/src/x_websocket/sign.rs @@ -0,0 +1,82 @@ +fn sha1(message: &str) -> [u8; 160] { + //let message_bits = message.as_bytes(); +// + //u32::from_be_bytes(bytes) +// + //let ml = message.len() as u64; +// + //let mut h0: u32 = 0x67452301; + //let mut h1: u32 = 0xEFCDAB89; + //let mut h2: u32 = 0x98BADCFE; + //let mut h3: u32 = 0x10325476; + //let mut h4: u32 = 0xC3D2E1F0; +// + ////let message_u16 = message as u64; +// + todo!() +} + +const CHANK: usize = 64; +struct Digest { + h: [u32; 5], + x: [u8; CHANK], + nx: usize, + len: u64, +} + +const K0: u32 = 0x5A827999; +const K1: u32 = 0x6ED9EBA1; +const K2: u32 = 0x8F1BBCDC; +const K3: u32 = 0xCA62C1D6; + +impl Digest { + fn reset() -> Self { + Self { + h: [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0], + x: [0; CHANK], + nx: 0, + len: 0, + } + } + + fn write(&mut self, mut data: &[u8]) { + self.len += data.len() as u64; + if self.nx > 0 { + let n = (CHANK - self.nx).min(data.len()); + self.x[self.nx..(self.nx + n)].copy_from_slice(data); + self.nx += n; + if self.nx == CHANK { + self.block_x(); + self.nx = 0; + } + data = &data[n..] + } + } + + fn block(&mut self, mut data: &[u8]) { + + } + + fn block_x(&mut self) { + let mut w = [0u32; 16]; + + let (h0, h1, h2, h3, h4) = (self.h[0], self.h[1], self.h[2], self.h[3], self.h[4]); + while self.x.len() >= CHANK { + for i in 0..16 { + let j = i * 4; + w[i] = (self.x[j] as u32) << 24 | (self.x[j+1] as u32) << 16 | (self.x[j+2] as u32) << 8 | (self.x[j+3] as u32); + } + + let (mut a, mut b, mut c, mut d, mut e) = (h0, h1, h2, h3, h4); + + for i in 0..16 { + let f = (b & c) | ((!b) & d); + let t = a.rotate_left(5) + f + e + w[i & 0xf] + K0; + (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d) + } + for i in 16..20 { + tmp := + } + } + } +} From d52e6fed4c313abba265be8dd43da966f399c90d Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 3 Nov 2023 19:53:28 +0900 Subject: [PATCH 09/37] TODO: test of sign (sha1, base64) --- ohkami/src/x_websocket/context.rs | 12 ++- ohkami/src/x_websocket/sign.rs | 84 +-------------- ohkami/src/x_websocket/sign/base64.rs | 50 +++++++++ ohkami/src/x_websocket/sign/sha1.rs | 142 ++++++++++++++++++++++++++ 4 files changed, 202 insertions(+), 86 deletions(-) create mode 100644 ohkami/src/x_websocket/sign/base64.rs create mode 100644 ohkami/src/x_websocket/sign/sha1.rs diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index 8d211a57..ed6a768c 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,5 +1,5 @@ use std::{future::Future, borrow::Cow}; -use super::{WebSocket}; +use super::{WebSocket, sign}; use crate::{Response, Context, __rt__, Request}; use crate::http::{Method}; @@ -134,7 +134,7 @@ impl WebSocketContext { c.headers .custom("Connection", "Upgrade") .custom("Upgrade", "websocket") - .custom("Sec-WebSocket-Accept", sign(sec_websocket_key.as_bytes())); + .custom("Sec-WebSocket-Accept", sign(&sec_websocket_key)); if let Some(protocol) = selected_protocol { c.headers .custom("Sec-WebSocket-Protocol", protocol); @@ -143,7 +143,11 @@ impl WebSocketContext { } } +fn sign(sec_websocket_key: &str) -> String { + let mut sha1 = sign::Sha1::new(); + sha1.write(sec_websocket_key.as_bytes()); + sha1.write(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); -fn sign(key: &[u8]) -> String { - todo!() + let sec_websocket_accept_bytes = sign::encode_sha1_to_base64(sha1.sum()); + unsafe {String::from_utf8_unchecked(sec_websocket_accept_bytes.to_vec())} } diff --git a/ohkami/src/x_websocket/sign.rs b/ohkami/src/x_websocket/sign.rs index 15c9dca5..be6b4693 100644 --- a/ohkami/src/x_websocket/sign.rs +++ b/ohkami/src/x_websocket/sign.rs @@ -1,82 +1,2 @@ -fn sha1(message: &str) -> [u8; 160] { - //let message_bits = message.as_bytes(); -// - //u32::from_be_bytes(bytes) -// - //let ml = message.len() as u64; -// - //let mut h0: u32 = 0x67452301; - //let mut h1: u32 = 0xEFCDAB89; - //let mut h2: u32 = 0x98BADCFE; - //let mut h3: u32 = 0x10325476; - //let mut h4: u32 = 0xC3D2E1F0; -// - ////let message_u16 = message as u64; -// - todo!() -} - -const CHANK: usize = 64; -struct Digest { - h: [u32; 5], - x: [u8; CHANK], - nx: usize, - len: u64, -} - -const K0: u32 = 0x5A827999; -const K1: u32 = 0x6ED9EBA1; -const K2: u32 = 0x8F1BBCDC; -const K3: u32 = 0xCA62C1D6; - -impl Digest { - fn reset() -> Self { - Self { - h: [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0], - x: [0; CHANK], - nx: 0, - len: 0, - } - } - - fn write(&mut self, mut data: &[u8]) { - self.len += data.len() as u64; - if self.nx > 0 { - let n = (CHANK - self.nx).min(data.len()); - self.x[self.nx..(self.nx + n)].copy_from_slice(data); - self.nx += n; - if self.nx == CHANK { - self.block_x(); - self.nx = 0; - } - data = &data[n..] - } - } - - fn block(&mut self, mut data: &[u8]) { - - } - - fn block_x(&mut self) { - let mut w = [0u32; 16]; - - let (h0, h1, h2, h3, h4) = (self.h[0], self.h[1], self.h[2], self.h[3], self.h[4]); - while self.x.len() >= CHANK { - for i in 0..16 { - let j = i * 4; - w[i] = (self.x[j] as u32) << 24 | (self.x[j+1] as u32) << 16 | (self.x[j+2] as u32) << 8 | (self.x[j+3] as u32); - } - - let (mut a, mut b, mut c, mut d, mut e) = (h0, h1, h2, h3, h4); - - for i in 0..16 { - let f = (b & c) | ((!b) & d); - let t = a.rotate_left(5) + f + e + w[i & 0xf] + K0; - (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d) - } - for i in 16..20 { - tmp := - } - } - } -} +mod sha1; pub use sha1:: Sha1; +mod base64; pub use base64::encode_sha1_to_base64; diff --git a/ohkami/src/x_websocket/sign/base64.rs b/ohkami/src/x_websocket/sign/base64.rs new file mode 100644 index 00000000..f480361b --- /dev/null +++ b/ohkami/src/x_websocket/sign/base64.rs @@ -0,0 +1,50 @@ +/* https://github.com/golang/go/blob/master/src/encoding/base64/base64.go */ + +const ENCODER: [u8; 64] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +const PADDING: char = '='; + +const SIZE_FROM_SHA1: usize = 28; + +pub fn encode_sha1_to_base64(sha1_bytes: [u8; super::sha1::SIZE]) -> [u8; SIZE_FROM_SHA1] { + let mut dst = [0; SIZE_FROM_SHA1]; + + let (mut di, mut si) = (0, 0); + let n = (super::sha1::SIZE / 3) * 3; + while si < n { + let val = (sha1_bytes[si+0] as usize)<<16 | (sha1_bytes[si+1] as usize)<<8 | (sha1_bytes[si+2] as usize); + + dst[di+0] = ENCODER[val>>18&0x3F]; + dst[di+1] = ENCODER[val>>12&0x3F]; + dst[di+2] = ENCODER[val>>6&0x3F]; + dst[di+3] = ENCODER[val&0x3F]; + + si += 3; + di += 4; + } + + let remain = super::sha1::SIZE - si; + /* unreachable because `si` is a multiple of 3 and `sha1::SIZE` is 20 */ + // if remain == 0 {return dst} + + let mut val = (sha1_bytes[si+0] as usize) << 16; + if remain == 2 { + val |= (sha1_bytes[si+1] as usize) << 8; + } + + dst[di+0] = ENCODER[val>>18&0x3F]; + dst[di+1] = ENCODER[val>>12&0x3F]; + + match remain { + 2 => { + dst[di+2] = ENCODER[val>>6&0x3F]; + dst[di+3] = b'='; + } + 1 => { + dst[di+2] = b'='; + dst[di+3] = b'='; + } + _ => unsafe {std::hint::unreachable_unchecked()} + } + + dst +} diff --git a/ohkami/src/x_websocket/sign/sha1.rs b/ohkami/src/x_websocket/sign/sha1.rs new file mode 100644 index 00000000..88b55800 --- /dev/null +++ b/ohkami/src/x_websocket/sign/sha1.rs @@ -0,0 +1,142 @@ +pub const CHANK: usize = 64; +pub const SIZE: usize = 20; // bytes; 160 bits + +pub struct Sha1 { + h: [u32; 5], + x: [u8; CHANK], + nx: usize, + len: u64, +} + +const K0: u32 = 0x5A827999; +const K1: u32 = 0x6ED9EBA1; +const K2: u32 = 0x8F1BBCDC; +const K3: u32 = 0xCA62C1D6; + +// https://github.com/golang/go/blob/master/src/crypto/sha1/sha1.go +impl Sha1 { + pub fn new() -> Self { + Self { + h: [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0], + x: [0; CHANK], + nx: 0, + len: 0, + } + } + + pub fn write(&mut self, mut data: &[u8]) { + self.len += data.len() as u64; + if self.nx > 0 { + let n = (CHANK - self.nx).min(data.len()); + self.x[self.nx..(self.nx + n)].copy_from_slice(&data[..n]); + self.nx += n; + if self.nx == CHANK { + let mut p = [0; CHANK]; p.copy_from_slice(&self.x); + self.block(&p); + self.nx = 0; + } + data = &data[n..] + } + if data.len() >= CHANK { + let n = data.len() & (!(CHANK - 1)); + self.block(&data[..n]); + data = &data[n..] + } + if data.len() > 0 { + self.nx = (data.len()).min(self.x.len()); + self.x.copy_from_slice(data); + } + } + + pub fn sum(mut self) -> [u8; SIZE] { + let mut len = self.len; + + let mut tmp = [0; 64 + 8]; + tmp[0] = 0x80; + let t = if len%64 < 56 { + 56 - len%64 + } else { + 64 + 56 - len%64 + }; + + len <<= 3; + let padlen = &mut tmp[..(t as usize + 8)]; + padlen[(t as usize)..].copy_from_slice(&len.to_be_bytes()); + self.write(padlen); + + #[cfg(debug_assertions)] assert_eq!(self.nx, 0); + + let mut digest = [0; SIZE]; + digest[0.. 4].copy_from_slice(&self.h[0].to_be_bytes()); + digest[4.. 8].copy_from_slice(&self.h[1].to_be_bytes()); + digest[8.. 12].copy_from_slice(&self.h[2].to_be_bytes()); + digest[12..16].copy_from_slice(&self.h[3].to_be_bytes()); + digest[16.. ].copy_from_slice(&self.h[4].to_be_bytes()); + digest + } +} + +// https://github.com/golang/go/blob/master/src/crypto/sha1/sha1block.go +impl Sha1 { + fn block(&mut self, mut p: &[u8]) { + let mut w = [0u32; 16]; + + let (mut h0, mut h1, mut h2, mut h3, mut h4) = (self.h[0], self.h[1], self.h[2], self.h[3], self.h[4]); + while p.len() >= CHANK { + for i in 0..16 { + let j = i * 4; + w[i] = (p[j] as u32) << 24 | (p[j+1] as u32) << 16 | (p[j+2] as u32) << 8 | (p[j+3] as u32); + } + + let (mut a, mut b, mut c, mut d, mut e) = (h0, h1, h2, h3, h4); + + for i in 0..16 { + let f = (b & c) | ((!b) & d); + let t = a.rotate_left(5) + f + e + w[i&0xf] + K0; + (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d) + } + for i in 16..20 { + let tmp = w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]; + w[i&0xf] = tmp.rotate_left(1); + + let f = (b & c) | ((!b) & d); + let t = a.rotate_left(5) + f + e + w[i & 0xf] + K0; + (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d) + } + for i in 20..40 { + let tmp = w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]; + w[i&0xf] = tmp.rotate_left(1); + + let f = b ^ c ^ d; + let t = a.rotate_left(5) + f + e + w[i&0xf] + K1; + (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d); + } + for i in 40..60 { + let tmp = w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]; + w[i&0xf] = tmp.rotate_left(1); + + let f = ((b | c) & d) | (b & c); + let t = a.rotate_left(5) + f + e + w[i&0xf] + K2; + (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d); + } + for i in 60..80 { + let tmp = w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]; + w[i&0xf] = tmp.rotate_left(1); + + let f = b ^ c ^ d; + let t = a.rotate_left(5) + f + e + w[i&0xf] + K3; + (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d); + } + + h0 += a; + h1 += b; + h2 += c; + h3 += d; + h4 += e; + + p = &p[CHANK..] + } + + (self.h[0], self.h[1], self.h[2], self.h[3], self.h[4]) = (h0, h1, h2, h3, h4) + } +} From e325b36032859ebafc5257a87f77e8ca6fb91354 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sat, 4 Nov 2023 02:55:04 +0900 Subject: [PATCH 10/37] TODO: test sha1 --- ohkami/Cargo.toml | 6 +- ohkami/src/lib.rs | 5 + ohkami/src/x_websocket/context.rs | 4 +- ohkami/src/x_websocket/sign.rs | 38 ++++- ohkami/src/x_websocket/sign/base64.rs | 203 ++++++++++++++++++++++---- 5 files changed, 218 insertions(+), 38 deletions(-) diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 1d42b3ec..2842ab44 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -33,4 +33,8 @@ nightly = [] ##### DEBUG ##### DEBUG = ["websocket", "serde/derive", "tokio?/macros", "async-std?/attributes"] -default = ["rt_tokio", "DEBUG"] \ No newline at end of file +default = [ + "rt_tokio", + "DEBUG", + #"nightly" +] \ No newline at end of file diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index b69dae03..dd22f36c 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -182,8 +182,13 @@ #![doc(html_root_url = "https://docs.rs/ohkami")] +#![allow(incomplete_features)] #![cfg_attr(feature="nightly", feature( try_trait_v2, + generic_arg_infer, + + /* imcomplete features */ + generic_const_exprs, ))] diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index ed6a768c..b6b694a5 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -147,7 +147,5 @@ fn sign(sec_websocket_key: &str) -> String { let mut sha1 = sign::Sha1::new(); sha1.write(sec_websocket_key.as_bytes()); sha1.write(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - - let sec_websocket_accept_bytes = sign::encode_sha1_to_base64(sha1.sum()); - unsafe {String::from_utf8_unchecked(sec_websocket_accept_bytes.to_vec())} + sign::Base64::<{sign::SHA1_SIZE}>::encode(sha1.sum()) } diff --git a/ohkami/src/x_websocket/sign.rs b/ohkami/src/x_websocket/sign.rs index be6b4693..6f4ff444 100644 --- a/ohkami/src/x_websocket/sign.rs +++ b/ohkami/src/x_websocket/sign.rs @@ -1,2 +1,36 @@ -mod sha1; pub use sha1:: Sha1; -mod base64; pub use base64::encode_sha1_to_base64; +mod sha1; pub use sha1:: {Sha1, SIZE as SHA1_SIZE}; +mod base64; pub use base64::{Base64}; + +#[cfg(test)] mod sign_test { + use super::*; + + #[test] fn test_sha1() {// https://github.com/golang/go/blob/master/src/crypto/sha1/sha1_test.go + + } + + #[test] fn test_base64() {// https://github.com/golang/go/blob/master/src/encoding/base64/base64_test.go + // RFC 3548 examples + assert_eq!(Base64::<6>::encode(*b"\x14\xfb\x9c\x03\xd9\x7e"), "FPucA9l+"); + assert_eq!(Base64::<5>::encode(*b"\x14\xfb\x9c\x03\xd9"), "FPucA9k="); + assert_eq!(Base64::<4>::encode(*b"\x14\xfb\x9c\x03"), "FPucAw=="); + + // RFC 4648 examples + assert_eq!(Base64::<0>::encode(*b""), ""); + assert_eq!(Base64::<1>::encode(*b"f"), "Zg=="); + assert_eq!(Base64::<2>::encode(*b"fo"), "Zm8="); + assert_eq!(Base64::<3>::encode(*b"foo"), "Zm9v"); + assert_eq!(Base64::<4>::encode(*b"foob"), "Zm9vYg=="); + assert_eq!(Base64::<5>::encode(*b"fooba"), "Zm9vYmE="); + assert_eq!(Base64::<6>::encode(*b"foobar"), "Zm9vYmFy"); + + // Wikipedia examples + assert_eq!(Base64::<5>::encode(*b"sure."), "c3VyZS4="); + assert_eq!(Base64::<4>::encode(*b"sure"), "c3VyZQ=="); + assert_eq!(Base64::<3>::encode(*b"sur"), "c3Vy"); + assert_eq!(Base64::<2>::encode(*b"su"), "c3U="); + assert_eq!(Base64::<8>::encode(*b"leasure."), "bGVhc3VyZS4="); + assert_eq!(Base64::<7>::encode(*b"easure."), "ZWFzdXJlLg=="); + assert_eq!(Base64::<6>::encode(*b"asure."), "YXN1cmUu"); + assert_eq!(Base64::<5>::encode(*b"sure."), "c3VyZS4="); + } +} diff --git a/ohkami/src/x_websocket/sign/base64.rs b/ohkami/src/x_websocket/sign/base64.rs index f480361b..379bad8b 100644 --- a/ohkami/src/x_websocket/sign/base64.rs +++ b/ohkami/src/x_websocket/sign/base64.rs @@ -1,50 +1,189 @@ /* https://github.com/golang/go/blob/master/src/encoding/base64/base64.go */ const ENCODER: [u8; 64] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -const PADDING: char = '='; +const PADDING: u8 = b'='; -const SIZE_FROM_SHA1: usize = 28; +pub struct Base64< + const SRC_SIZE: usize, + #[cfg(feature="nightly")] const DST_SIZE: usize = {(SRC_SIZE + 2) / 3 * 4}, + #[cfg(feature="nightly")] const SRC_SIZE_REM3_0: bool = {SRC_SIZE % 3 == 0}, + #[cfg(feature="nightly")] const SRC_SIZE_REM3_1: bool = {SRC_SIZE % 3 == 1}, + #[cfg(feature="nightly")] const SRC_SIZE_REM3_2: bool = {SRC_SIZE % 3 == 2}, +>; -pub fn encode_sha1_to_base64(sha1_bytes: [u8; super::sha1::SIZE]) -> [u8; SIZE_FROM_SHA1] { - let mut dst = [0; SIZE_FROM_SHA1]; +#[cfg(feature="nightly")] impl< + const SRC_SIZE: usize, + const DST_SIZE: usize, + const SRC_SIZE_REM3_0: bool, + const SRC_SIZE_REM3_1: bool, + const SRC_SIZE_REM3_2: bool, +> Base64 { + pub fn encode(src: [u8; SRC_SIZE]) -> String { + if SRC_SIZE == 0 {// may deleted by compiler when `SRC_SIZE` is not 0 + return String::new() + } - let (mut di, mut si) = (0, 0); - let n = (super::sha1::SIZE / 3) * 3; - while si < n { - let val = (sha1_bytes[si+0] as usize)<<16 | (sha1_bytes[si+1] as usize)<<8 | (sha1_bytes[si+2] as usize); + #[cfg(feature="nightly")] + let mut dst = vec![0; DST_SIZE]; - dst[di+0] = ENCODER[val>>18&0x3F]; - dst[di+1] = ENCODER[val>>12&0x3F]; - dst[di+2] = ENCODER[val>>6&0x3F]; - dst[di+3] = ENCODER[val&0x3F]; + let (mut di, mut si) = (0, 0); + let n = (SRC_SIZE / 3) * 3; // `n` is `SRC_SIZE - (SRC_SIZE % 3)` + while si < n { + let val = (src[si+0] as usize)<<16 | (src[si+1] as usize)<<8 | (src[si+2] as usize); - si += 3; - di += 4; - } + dst[di+0] = ENCODER[val>>18&0x3F]; + dst[di+1] = ENCODER[val>>12&0x3F]; + dst[di+2] = ENCODER[val>>6&0x3F]; + dst[di+3] = ENCODER[val&0x3F]; - let remain = super::sha1::SIZE - si; - /* unreachable because `si` is a multiple of 3 and `sha1::SIZE` is 20 */ - // if remain == 0 {return dst} + si += 3; + di += 4; + } + + if SRC_SIZE_REM3_0 {// may deleted by compiler when `SRC_SIZE` is not a multiple of 3 + return (|| unsafe {String::from_utf8_unchecked(dst)})() + } - let mut val = (sha1_bytes[si+0] as usize) << 16; - if remain == 2 { - val |= (sha1_bytes[si+1] as usize) << 8; - } + let mut val = (src[si+0] as usize) << 16; + if SRC_SIZE_REM3_2 {// may be deleted by compiler when `SRC_SIZE` is congruent to 2 mod 3 + val |= (src[si+1] as usize) << 8; + } - dst[di+0] = ENCODER[val>>18&0x3F]; - dst[di+1] = ENCODER[val>>12&0x3F]; + dst[di+0] = ENCODER[val>>18&0x3F]; + dst[di+1] = ENCODER[val>>12&0x3F]; - match remain { - 2 => { + if SRC_SIZE_REM3_2 {// may be deleted by compiler when `SRC_SIZE` is congruent to 2 mod 3 dst[di+2] = ENCODER[val>>6&0x3F]; - dst[di+3] = b'='; + dst[di+3] = PADDING; } - 1 => { - dst[di+2] = b'='; - dst[di+3] = b'='; + if SRC_SIZE_REM3_1 {// may be deleted by compiler when `SRC_SIZE` is congruent to 1 mod 3 + dst[di+2] = PADDING; + dst[di+3] = PADDING; } - _ => unsafe {std::hint::unreachable_unchecked()} + + unsafe {String::from_utf8_unchecked(dst)} } +} + + +#[cfg(not(feature="nightly"))] impl< + const SRC_SIZE: usize, +> Base64 { + pub fn encode(src: [u8; SRC_SIZE]) -> String { + if SRC_SIZE == 0 {// may deleted by compiler when `SRC_SIZE` is not 0 + return String::new() + } + + let mut dst = vec![0; (SRC_SIZE + 2) / 3 * 4]; + + let (mut di, mut si) = (0, 0); + let n = (SRC_SIZE / 3) * 3; // `n` is `SRC_SIZE - (SRC_SIZE % 3)` + while si < n { + let val = (src[si+0] as usize)<<16 | (src[si+1] as usize)<<8 | (src[si+2] as usize); - dst + dst[di+0] = ENCODER[val>>18&0x3F]; + dst[di+1] = ENCODER[val>>12&0x3F]; + dst[di+2] = ENCODER[val>>6&0x3F]; + dst[di+3] = ENCODER[val&0x3F]; + + si += 3; + di += 4; + } + + let remain = SRC_SIZE - si; // `remain` is `SRC_SIZE % 3` + if remain == 0 { + return (|| unsafe {String::from_utf8_unchecked(dst)})() + } + + let mut val = (src[si+0] as usize) << 16; + if remain == 2 { + val |= (src[si+1] as usize) << 8; + } + + dst[di+0] = ENCODER[val>>18&0x3F]; + dst[di+1] = ENCODER[val>>12&0x3F]; + + match remain { + 2 => { + dst[di+2] = ENCODER[val>>6&0x3F]; + dst[di+3] = PADDING; + } + 1 => { + dst[di+2] = PADDING; + dst[di+3] = PADDING; + } + _ => unsafe {std::hint::unreachable_unchecked()} + } + + unsafe {String::from_utf8_unchecked(dst)} + } } + +//} + +//impl Base64 { +// pub fn encode(src: [u8; SRC_SIZE]) -> String { +// if src.len() == 0 { +// return String::new() +// } +// +// let mut dst = vec![0; (src.len() + 2) / 3 * 4]; +// +// let (mut di, mut si) = (0, 0); +// let n = (SRC_SIZE / 3) * 3; // `n` is `SRC_SIZE - (SRC_SIZE % 3)` +// while si < n { +// let val = (src[si+0] as usize)<<16 | (src[si+1] as usize)<<8 | (src[si+2] as usize); +// +// dst[di+0] = ENCODER[val>>18&0x3F]; +// dst[di+1] = ENCODER[val>>12&0x3F]; +// dst[di+2] = ENCODER[val>>6&0x3F]; +// dst[di+3] = ENCODER[val&0x3F]; +// +// si += 3; +// di += 4; +// } +// +// #[cfg(feature="nightly")] if SRC_SIZE_REM3_0 {// may deleted by compiler when `SRC_SIZE` is not a multiple of 3 +// return (|| unsafe {String::from_utf8_unchecked(dst)})() +// } +// +// #[cfg(not(feature="nightly"))] let remain = SRC_SIZE - si; // `remain` is `SRC_SIZE % 3` +// #[cfg(not(feature="nightly"))] { +// if remain == 0 {return (|| unsafe {String::from_utf8_unchecked(dst)})()} +// } +// +// let mut val = (src[si+0] as usize) << 16; +// #[cfg(feature="nightly")] if SRC_SIZE_REM3_2 {// may be deleted by compiler when `SRC_SIZE` is congruent to 2 mod 3 +// val |= (src[si+1] as usize) << 8; +// } +// #[cfg(not(feature="nightly"))] if remain == 2 { +// val |= (src[si+1] as usize) << 8; +// } +// +// dst[di+0] = ENCODER[val>>18&0x3F]; +// dst[di+1] = ENCODER[val>>12&0x3F]; +// +// #[cfg(feature="nightly")] if SRC_SIZE_REM3_2 {// may be deleted by compiler when `SRC_SIZE` is congruent to 2 mod 3 +// dst[di+2] = ENCODER[val>>6&0x3F]; +// dst[di+3] = PADDING; +// } +// #[cfg(feature="nightly")] if SRC_SIZE_REM3_1 {// may be deleted by compiler when `SRC_SIZE` is congruent to 1 mod 3 +// dst[di+2] = PADDING; +// dst[di+3] = PADDING; +// } +// #[cfg(not(feature="nightly"))] match remain { +// 2 => { +// dst[di+2] = ENCODER[val>>6&0x3F]; +// dst[di+3] = PADDING; +// } +// 1 => { +// dst[di+2] = PADDING; +// dst[di+3] = PADDING; +// } +// _ => unsafe {std::hint::unreachable_unchecked()} +// } +// +// unsafe {String::from_utf8_unchecked(dst)} +// } +//} +// \ No newline at end of file From 4bfde186c8b78c3dd7af0aad13e2fdaf293d2a1b Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sat, 4 Nov 2023 18:17:26 +0900 Subject: [PATCH 11/37] TODO: test sha1 --- ohkami/src/x_websocket/sign.rs | 59 +++++++++++++++++++++++++++++ ohkami/src/x_websocket/sign/sha1.rs | 37 +++++++++--------- 2 files changed, 78 insertions(+), 18 deletions(-) diff --git a/ohkami/src/x_websocket/sign.rs b/ohkami/src/x_websocket/sign.rs index 6f4ff444..c76e0936 100644 --- a/ohkami/src/x_websocket/sign.rs +++ b/ohkami/src/x_websocket/sign.rs @@ -5,7 +5,66 @@ mod base64; pub use base64::{Base64}; use super::*; #[test] fn test_sha1() {// https://github.com/golang/go/blob/master/src/crypto/sha1/sha1_test.go + for (encoded/* hex literal */, input) in [ + //("76245dbf96f661bd221046197ab8b9f063f11bad", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\n"),//, "sha\x01\v\xa0)I\xdeq(8h\x9ev\xe5\x88[\xf8\x81\x17\xba4Daaaaaaaaaaaaaaaaaaaaaa\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x96"}, + ("da39a3ee5e6b4b0d3255bfef95601890afd80709", ""), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, + //("86f7e437faa5a7fce15d1ddcb9eaeaea377667b8", "a"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, + //("da23614e02469a0d7c7bd1bdab5c9c474b1904dc", "ab"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, + //("a9993e364706816aba3e25717850c26c9cd0d89d", "abc"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, + //("81fe8bfe87576c3ecb22426f8e57847382917acf", "abcd"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0ab\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"}, + //("03de6c570bfe24bfc328ccd7ca46b76eadaf4334", "abcde"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0ab\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"}, + //("1f8ac10f23c5b5bc1167bda84b833e5c057a77d2", "abcdef"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0abc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"}, + //("2fb5e13419fc89246865e7a324f476ec624e8740", "abcdefg"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0abc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"}, + //("425af12a0743502b322e93a015bcf868e324d56a", "abcdefgh"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0abcd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04"}, + //("c63b19f1e4c8b5f76b25c49b8b87f57d8e4872a1", "abcdefghi"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0abcd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04"}, + //("d68c19a0a345b7eab78d5e11e991c026ec60db63", "abcdefghij"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0abcde\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05"}, + //("ebf81ddcbe5bf13aaabdc4d65354fdf2044f38a7", "Discard medicine more than two years old."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0Discard medicine mor\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x14"}, + //("e5dea09392dd886ca63531aaa00571dc07554bb6", "He who has a shady past knows that nice guys finish last."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0He who has a shady past know\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1c"}, + //("45988f7234467b94e3e9494434c96ee3609d8f8f", "I wouldn't marry him with a ten foot pole."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0I wouldn't marry him \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x15"}, + //("55dee037eb7460d5a692d1ce11330b260e40c988", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0Free! Free!/A trip/to Mars/f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1c"}, + //("b7bc5fb91080c7de6b582ea281f8a396d7c0aee8", "The days of the digital watch are numbered. -Tom Stoppard"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0The days of the digital watch\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1d"}, + //("c3aed9358f7c77f523afe86135f06b95b3999797", "Nepal premier won't resign."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0Nepal premier\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r"}, + //("6e29d302bf6e3a5e4305ff318d983197d6906bb9", "For every action there is an equal and opposite government program."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0For every action there is an equa\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00!"}, + //("597f6a540010f94c15d71806a99a2c8710e747bd", "His money is twice tainted: 'taint yours and 'taint mine."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0His money is twice tainted: \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1c"}, + //("6859733b2590a8a091cecf50086febc5ceef1e80", "There is no reason for any individual to have a computer in their home. -Ken Olsen, 1977"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0There is no reason for any individual to hav\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,"}, + //("514b2630ec089b8aee18795fc0cf1f4860cdacad", "It's a tiny change to the code and not completely disgusting. - Bob Manchek"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0It's a tiny change to the code and no\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00%"}, + //("c5ca0d4a7b6676fc7aa72caa41cc3d5df567ed69", "size: a.out: bad magic"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0size: a.out\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\f"}, + //("74c51fa9a04eadc8c1bbeaa7fc442f834b90a00a", "The major problem is with sendmail. -Mark Horton"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0The major problem is wit\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x18"}, + //("0b4c4ce5f52c3ad2821852a8dc00217fa18b8b66", "Give me a rock, paper and scissors and I will move the world. CCFestoon"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0Give me a rock, paper and scissors a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00$"}, + //("3ae7937dd790315beb0f48330e8642237c61550a", "If the enemy is within range, then so are you."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0If the enemy is within \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x17"}, + //("410a2b296df92b9a47412b13281df8f830a9f44b", "It's well we cannot hear the screams/That we create in others' dreams."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0It's well we cannot hear the scream\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#"}, + //("841e7c85ca1adcddbdd0187f1289acb5c642f7f5", "You remind me of a TV show, but that's all right: I watch it anyway."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0You remind me of a TV show, but th\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\""}, + //("163173b825d03b952601376b25212df66763e1db", "C is as portable as Stonehedge!!"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0C is as portable\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10"}, + //("32b0377f2687eb88e22106f133c586ab314d5279", "Even if I could be Shakespeare, I think I should still choose to be Faraday. - A. Huxley"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0Even if I could be Shakespeare, I think I sh\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,"}, + //("0885aaf99b569542fd165fa44e322718f4a984e0", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule"), //"sha\x01x}\xf4\r\xeb\xf2\x10\x87\xe8[\xb2JA$D\xb7\u063ax8em\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00B"}, + //("6627d6904d71420b0bf3886ab629623538689f45", "How can you write a big system without C++? -Paul Glick"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0How can you write a big syst\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1c"}, + ] { + let expected = std::array::from_fn(|i| i).map(|i| -> [u8; 8] { + [encoded.as_bytes()[2*i], encoded.as_bytes()[2*i+1]].map(|b| match b { + b'0' => [0, 0, 0, 0], + b'1' => [0, 0, 0, 1], + b'2' => [0, 0, 1, 0], + b'3' => [0, 0, 1, 1], + b'4' => [0, 1, 0, 0], + b'5' => [0, 1, 0, 1], + b'6' => [0, 1, 1, 0], + b'7' => [0, 1, 1, 1], + b'8' => [1, 0, 0, 0], + b'9' => [1, 0, 0, 1], + b'a' => [1, 0, 1, 0], + b'b' => [1, 0, 1, 1], + b'c' => [1, 1, 0, 0], + b'd' => [1, 1, 0, 1], + b'e' => [1, 1, 1, 0], + b'f' => [1, 1, 1, 1], + _ => unreachable!() + }).concat().try_into().unwrap() + }).map(|bits| bits.into_iter().fold(0, |byte, b| byte * 2 + b)); + let mut s = Sha1::new(); + s.write(input.as_bytes()); + assert_eq!(s.sum(), expected); + } } #[test] fn test_base64() {// https://github.com/golang/go/blob/master/src/encoding/base64/base64_test.go diff --git a/ohkami/src/x_websocket/sign/sha1.rs b/ohkami/src/x_websocket/sign/sha1.rs index 88b55800..358ab3a7 100644 --- a/ohkami/src/x_websocket/sign/sha1.rs +++ b/ohkami/src/x_websocket/sign/sha1.rs @@ -24,27 +24,26 @@ impl Sha1 { } } - pub fn write(&mut self, mut data: &[u8]) { - self.len += data.len() as u64; + pub fn write(&mut self, mut p: &[u8]) { + self.len += p.len() as u64; if self.nx > 0 { - let n = (CHANK - self.nx).min(data.len()); - self.x[self.nx..(self.nx + n)].copy_from_slice(&data[..n]); + let n = (CHANK - self.nx).min(p.len()); + self.x[self.nx..(self.nx + n)].copy_from_slice(&p[..n]); self.nx += n; if self.nx == CHANK { - let mut p = [0; CHANK]; p.copy_from_slice(&self.x); - self.block(&p); + self.block(&self.x.clone()); self.nx = 0; } - data = &data[n..] + p = &p[n..] } - if data.len() >= CHANK { - let n = data.len() & (!(CHANK - 1)); - self.block(&data[..n]); - data = &data[n..] + if p.len() >= CHANK { + let n = p.len() & (!(CHANK - 1)); + self.block(&p[..n]); + p = &p[n..] } - if data.len() > 0 { - self.nx = (data.len()).min(self.x.len()); - self.x.copy_from_slice(data); + if p.len() > 0 { + self.nx = self.x.len().min(p.len()); + self.x.copy_from_slice(p); } } @@ -60,9 +59,11 @@ impl Sha1 { }; len <<= 3; - let padlen = &mut tmp[..(t as usize + 8)]; - padlen[(t as usize)..].copy_from_slice(&len.to_be_bytes()); - self.write(padlen); + //let padlen = &mut tmp[..(t as usize + 8)]; + //padlen[(t as usize)..].copy_from_slice(&len.to_be_bytes()); + //self.write(padlen); + tmp[(t as usize)..(t as usize + 8)].copy_from_slice(&len.to_be_bytes()); + self.write(&tmp[..(t as usize + 8)]); #[cfg(debug_assertions)] assert_eq!(self.nx, 0); @@ -92,7 +93,7 @@ impl Sha1 { for i in 0..16 { let f = (b & c) | ((!b) & d); - let t = a.rotate_left(5) + f + e + w[i&0xf] + K0; + let t = dbg!(a.rotate_left(5)) + dbg!(f) + e + w[i&0xf] + K0; (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d) } for i in 16..20 { From e59d2332c160792e75f33ecde3ed60f839543156 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sun, 5 Nov 2023 18:36:51 +0900 Subject: [PATCH 12/37] passed --- ohkami/src/x_websocket/sign.rs | 90 ++++++++++++----------------- ohkami/src/x_websocket/sign/sha1.rs | 54 ++++++++++------- 2 files changed, 72 insertions(+), 72 deletions(-) diff --git a/ohkami/src/x_websocket/sign.rs b/ohkami/src/x_websocket/sign.rs index c76e0936..c0dd7ef0 100644 --- a/ohkami/src/x_websocket/sign.rs +++ b/ohkami/src/x_websocket/sign.rs @@ -6,60 +6,46 @@ mod base64; pub use base64::{Base64}; #[test] fn test_sha1() {// https://github.com/golang/go/blob/master/src/crypto/sha1/sha1_test.go for (encoded/* hex literal */, input) in [ - //("76245dbf96f661bd221046197ab8b9f063f11bad", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\n"),//, "sha\x01\v\xa0)I\xdeq(8h\x9ev\xe5\x88[\xf8\x81\x17\xba4Daaaaaaaaaaaaaaaaaaaaaa\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x96"}, - ("da39a3ee5e6b4b0d3255bfef95601890afd80709", ""), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - //("86f7e437faa5a7fce15d1ddcb9eaeaea377667b8", "a"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - //("da23614e02469a0d7c7bd1bdab5c9c474b1904dc", "ab"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, - //("a9993e364706816aba3e25717850c26c9cd0d89d", "abc"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, - //("81fe8bfe87576c3ecb22426f8e57847382917acf", "abcd"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0ab\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"}, - //("03de6c570bfe24bfc328ccd7ca46b76eadaf4334", "abcde"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0ab\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"}, - //("1f8ac10f23c5b5bc1167bda84b833e5c057a77d2", "abcdef"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0abc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"}, - //("2fb5e13419fc89246865e7a324f476ec624e8740", "abcdefg"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0abc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"}, - //("425af12a0743502b322e93a015bcf868e324d56a", "abcdefgh"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0abcd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04"}, - //("c63b19f1e4c8b5f76b25c49b8b87f57d8e4872a1", "abcdefghi"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0abcd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04"}, - //("d68c19a0a345b7eab78d5e11e991c026ec60db63", "abcdefghij"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0abcde\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05"}, - //("ebf81ddcbe5bf13aaabdc4d65354fdf2044f38a7", "Discard medicine more than two years old."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0Discard medicine mor\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x14"}, - //("e5dea09392dd886ca63531aaa00571dc07554bb6", "He who has a shady past knows that nice guys finish last."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0He who has a shady past know\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1c"}, - //("45988f7234467b94e3e9494434c96ee3609d8f8f", "I wouldn't marry him with a ten foot pole."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0I wouldn't marry him \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x15"}, - //("55dee037eb7460d5a692d1ce11330b260e40c988", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0Free! Free!/A trip/to Mars/f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1c"}, - //("b7bc5fb91080c7de6b582ea281f8a396d7c0aee8", "The days of the digital watch are numbered. -Tom Stoppard"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0The days of the digital watch\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1d"}, - //("c3aed9358f7c77f523afe86135f06b95b3999797", "Nepal premier won't resign."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0Nepal premier\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r"}, - //("6e29d302bf6e3a5e4305ff318d983197d6906bb9", "For every action there is an equal and opposite government program."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0For every action there is an equa\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00!"}, - //("597f6a540010f94c15d71806a99a2c8710e747bd", "His money is twice tainted: 'taint yours and 'taint mine."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0His money is twice tainted: \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1c"}, - //("6859733b2590a8a091cecf50086febc5ceef1e80", "There is no reason for any individual to have a computer in their home. -Ken Olsen, 1977"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0There is no reason for any individual to hav\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,"}, - //("514b2630ec089b8aee18795fc0cf1f4860cdacad", "It's a tiny change to the code and not completely disgusting. - Bob Manchek"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0It's a tiny change to the code and no\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00%"}, - //("c5ca0d4a7b6676fc7aa72caa41cc3d5df567ed69", "size: a.out: bad magic"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0size: a.out\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\f"}, - //("74c51fa9a04eadc8c1bbeaa7fc442f834b90a00a", "The major problem is with sendmail. -Mark Horton"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0The major problem is wit\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x18"}, - //("0b4c4ce5f52c3ad2821852a8dc00217fa18b8b66", "Give me a rock, paper and scissors and I will move the world. CCFestoon"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0Give me a rock, paper and scissors a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00$"}, - //("3ae7937dd790315beb0f48330e8642237c61550a", "If the enemy is within range, then so are you."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0If the enemy is within \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x17"}, - //("410a2b296df92b9a47412b13281df8f830a9f44b", "It's well we cannot hear the screams/That we create in others' dreams."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0It's well we cannot hear the scream\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#"}, - //("841e7c85ca1adcddbdd0187f1289acb5c642f7f5", "You remind me of a TV show, but that's all right: I watch it anyway."), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0You remind me of a TV show, but th\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\""}, - //("163173b825d03b952601376b25212df66763e1db", "C is as portable as Stonehedge!!"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0C is as portable\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10"}, - //("32b0377f2687eb88e22106f133c586ab314d5279", "Even if I could be Shakespeare, I think I should still choose to be Faraday. - A. Huxley"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0Even if I could be Shakespeare, I think I sh\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,"}, - //("0885aaf99b569542fd165fa44e322718f4a984e0", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule"), //"sha\x01x}\xf4\r\xeb\xf2\x10\x87\xe8[\xb2JA$D\xb7\u063ax8em\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00B"}, - //("6627d6904d71420b0bf3886ab629623538689f45", "How can you write a big system without C++? -Paul Glick"), //"sha\x01gE#\x01\xef\u036b\x89\x98\xba\xdc\xfe\x102Tv\xc3\xd2\xe1\xf0How can you write a big syst\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1c"}, + ("76245dbf96f661bd221046197ab8b9f063f11bad", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\n"), + ("da39a3ee5e6b4b0d3255bfef95601890afd80709", ""), + ("86f7e437faa5a7fce15d1ddcb9eaeaea377667b8", "a"), + ("da23614e02469a0d7c7bd1bdab5c9c474b1904dc", "ab"), + ("a9993e364706816aba3e25717850c26c9cd0d89d", "abc"), + ("81fe8bfe87576c3ecb22426f8e57847382917acf", "abcd"), + ("03de6c570bfe24bfc328ccd7ca46b76eadaf4334", "abcde"), + ("1f8ac10f23c5b5bc1167bda84b833e5c057a77d2", "abcdef"), + ("2fb5e13419fc89246865e7a324f476ec624e8740", "abcdefg"), + ("425af12a0743502b322e93a015bcf868e324d56a", "abcdefgh"), + ("c63b19f1e4c8b5f76b25c49b8b87f57d8e4872a1", "abcdefghi"), + ("d68c19a0a345b7eab78d5e11e991c026ec60db63", "abcdefghij"), + ("ebf81ddcbe5bf13aaabdc4d65354fdf2044f38a7", "Discard medicine more than two years old."), + ("e5dea09392dd886ca63531aaa00571dc07554bb6", "He who has a shady past knows that nice guys finish last."), + ("45988f7234467b94e3e9494434c96ee3609d8f8f", "I wouldn't marry him with a ten foot pole."), + ("55dee037eb7460d5a692d1ce11330b260e40c988", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave"), + ("b7bc5fb91080c7de6b582ea281f8a396d7c0aee8", "The days of the digital watch are numbered. -Tom Stoppard"), + ("c3aed9358f7c77f523afe86135f06b95b3999797", "Nepal premier won't resign."), + ("6e29d302bf6e3a5e4305ff318d983197d6906bb9", "For every action there is an equal and opposite government program."), + ("597f6a540010f94c15d71806a99a2c8710e747bd", "His money is twice tainted: 'taint yours and 'taint mine."), + ("6859733b2590a8a091cecf50086febc5ceef1e80", "There is no reason for any individual to have a computer in their home. -Ken Olsen, 1977"), + ("514b2630ec089b8aee18795fc0cf1f4860cdacad", "It's a tiny change to the code and not completely disgusting. - Bob Manchek"), + ("c5ca0d4a7b6676fc7aa72caa41cc3d5df567ed69", "size: a.out: bad magic"), + ("74c51fa9a04eadc8c1bbeaa7fc442f834b90a00a", "The major problem is with sendmail. -Mark Horton"), + ("0b4c4ce5f52c3ad2821852a8dc00217fa18b8b66", "Give me a rock, paper and scissors and I will move the world. CCFestoon"), + ("3ae7937dd790315beb0f48330e8642237c61550a", "If the enemy is within range, then so are you."), + ("410a2b296df92b9a47412b13281df8f830a9f44b", "It's well we cannot hear the screams/That we create in others' dreams."), + ("841e7c85ca1adcddbdd0187f1289acb5c642f7f5", "You remind me of a TV show, but that's all right: I watch it anyway."), + ("163173b825d03b952601376b25212df66763e1db", "C is as portable as Stonehedge!!"), + ("32b0377f2687eb88e22106f133c586ab314d5279", "Even if I could be Shakespeare, I think I should still choose to be Faraday. - A. Huxley"), + ("0885aaf99b569542fd165fa44e322718f4a984e0", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule"), + ("6627d6904d71420b0bf3886ab629623538689f45", "How can you write a big system without C++? -Paul Glick"), ] { - let expected = std::array::from_fn(|i| i).map(|i| -> [u8; 8] { - [encoded.as_bytes()[2*i], encoded.as_bytes()[2*i+1]].map(|b| match b { - b'0' => [0, 0, 0, 0], - b'1' => [0, 0, 0, 1], - b'2' => [0, 0, 1, 0], - b'3' => [0, 0, 1, 1], - b'4' => [0, 1, 0, 0], - b'5' => [0, 1, 0, 1], - b'6' => [0, 1, 1, 0], - b'7' => [0, 1, 1, 1], - b'8' => [1, 0, 0, 0], - b'9' => [1, 0, 0, 1], - b'a' => [1, 0, 1, 0], - b'b' => [1, 0, 1, 1], - b'c' => [1, 1, 0, 0], - b'd' => [1, 1, 0, 1], - b'e' => [1, 1, 1, 0], - b'f' => [1, 1, 1, 1], + let expected = std::array::from_fn(|i| i).map(|i| + [encoded.as_bytes()[2*i], encoded.as_bytes()[2*i+1]].map(|b| match b { + b @ b'0'..=b'9' => b - b'0', + b @ b'a'..=b'f' => 10 + b - b'a', _ => unreachable!() - }).concat().try_into().unwrap() - }).map(|bits| bits.into_iter().fold(0, |byte, b| byte * 2 + b)); + }).into_iter().fold(0, |byte, b| byte * 2_u8.pow(4) + b) + ); let mut s = Sha1::new(); s.write(input.as_bytes()); diff --git a/ohkami/src/x_websocket/sign/sha1.rs b/ohkami/src/x_websocket/sign/sha1.rs index 358ab3a7..b68d3f86 100644 --- a/ohkami/src/x_websocket/sign/sha1.rs +++ b/ohkami/src/x_websocket/sign/sha1.rs @@ -1,11 +1,15 @@ +#[cfg(not(target_pointer_width = "64"))] +compile_error!{ "pointer width must be 64" } + pub const CHANK: usize = 64; pub const SIZE: usize = 20; // bytes; 160 bits +#[derive(Debug)] pub struct Sha1 { h: [u32; 5], x: [u8; CHANK], nx: usize, - len: u64, + len: usize, } const K0: u32 = 0x5A827999; @@ -25,7 +29,7 @@ impl Sha1 { } pub fn write(&mut self, mut p: &[u8]) { - self.len += p.len() as u64; + self.len += p.len(); if self.nx > 0 { let n = (CHANK - self.nx).min(p.len()); self.x[self.nx..(self.nx + n)].copy_from_slice(&p[..n]); @@ -42,8 +46,9 @@ impl Sha1 { p = &p[n..] } if p.len() > 0 { - self.nx = self.x.len().min(p.len()); - self.x.copy_from_slice(p); + let n = (self.x.len()).min(p.len()); + self.nx = n; + self.x[..n].copy_from_slice(&p[..n]); } } @@ -59,11 +64,8 @@ impl Sha1 { }; len <<= 3; - //let padlen = &mut tmp[..(t as usize + 8)]; - //padlen[(t as usize)..].copy_from_slice(&len.to_be_bytes()); - //self.write(padlen); - tmp[(t as usize)..(t as usize + 8)].copy_from_slice(&len.to_be_bytes()); - self.write(&tmp[..(t as usize + 8)]); + tmp[t..(t + 8)].copy_from_slice(&len.to_be_bytes()); + self.write(&tmp[..(t + 8)]); #[cfg(debug_assertions)] assert_eq!(self.nx, 0); @@ -72,7 +74,7 @@ impl Sha1 { digest[4.. 8].copy_from_slice(&self.h[1].to_be_bytes()); digest[8.. 12].copy_from_slice(&self.h[2].to_be_bytes()); digest[12..16].copy_from_slice(&self.h[3].to_be_bytes()); - digest[16.. ].copy_from_slice(&self.h[4].to_be_bytes()); + digest[16..20].copy_from_slice(&self.h[4].to_be_bytes()); digest } } @@ -80,6 +82,18 @@ impl Sha1 { // https://github.com/golang/go/blob/master/src/crypto/sha1/sha1block.go impl Sha1 { fn block(&mut self, mut p: &[u8]) { + fn wrapping_sum(u32_1: u32, u32_2: u32, u32_3: u32, u32_4: u32, u32_5: u32) -> u32 { + u32_1.wrapping_add( + u32_2.wrapping_add( + u32_3.wrapping_add( + u32_4.wrapping_add( + u32_5 + ) + ) + ) + ) + } + let mut w = [0u32; 16]; let (mut h0, mut h1, mut h2, mut h3, mut h4) = (self.h[0], self.h[1], self.h[2], self.h[3], self.h[4]); @@ -93,7 +107,7 @@ impl Sha1 { for i in 0..16 { let f = (b & c) | ((!b) & d); - let t = dbg!(a.rotate_left(5)) + dbg!(f) + e + w[i&0xf] + K0; + let t = wrapping_sum(a.rotate_left(5), f, e, w[i&0xf], K0); (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d) } for i in 16..20 { @@ -101,7 +115,7 @@ impl Sha1 { w[i&0xf] = tmp.rotate_left(1); let f = (b & c) | ((!b) & d); - let t = a.rotate_left(5) + f + e + w[i & 0xf] + K0; + let t = wrapping_sum(a.rotate_left(5), f, e, w[i & 0xf], K0); (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d) } for i in 20..40 { @@ -109,7 +123,7 @@ impl Sha1 { w[i&0xf] = tmp.rotate_left(1); let f = b ^ c ^ d; - let t = a.rotate_left(5) + f + e + w[i&0xf] + K1; + let t = wrapping_sum(a.rotate_left(5), f, e, w[i&0xf], K1); (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d); } for i in 40..60 { @@ -117,7 +131,7 @@ impl Sha1 { w[i&0xf] = tmp.rotate_left(1); let f = ((b | c) & d) | (b & c); - let t = a.rotate_left(5) + f + e + w[i&0xf] + K2; + let t = wrapping_sum(a.rotate_left(5), f, e, w[i&0xf], K2); (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d); } for i in 60..80 { @@ -125,15 +139,15 @@ impl Sha1 { w[i&0xf] = tmp.rotate_left(1); let f = b ^ c ^ d; - let t = a.rotate_left(5) + f + e + w[i&0xf] + K3; + let t = wrapping_sum(a.rotate_left(5), f, e, w[i&0xf], K3); (a, b, c, d, e) = (t, a, b.rotate_left(30), c, d); } - h0 += a; - h1 += b; - h2 += c; - h3 += d; - h4 += e; + h0 = h0.wrapping_add(a); + h1 = h1.wrapping_add(b); + h2 = h2.wrapping_add(c); + h3 = h3.wrapping_add(d); + h4 = h4.wrapping_add(e); p = &p[CHANK..] } From 81dbfc363a8831b836d9f26326606653a6f420a6 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Mon, 6 Nov 2023 19:03:09 +0900 Subject: [PATCH 13/37] TODO: parse web socket message --- ohkami/src/layer5_ohkami/howl.rs | 10 +++++----- ohkami/src/lib.rs | 5 +++++ ohkami/src/x_websocket/message.rs | 32 ++++++++++++++++++++++++++++++- ohkami/src/x_websocket/mod.rs | 17 ++++++++++++++++ 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/ohkami/src/layer5_ohkami/howl.rs b/ohkami/src/layer5_ohkami/howl.rs index 9aa33e57..b2e3948c 100644 --- a/ohkami/src/layer5_ohkami/howl.rs +++ b/ohkami/src/layer5_ohkami/howl.rs @@ -65,10 +65,10 @@ impl Ohkami { } )); - let router = Arc::clone(&router); - let c = Context::new(); - - if let Err(e) = __rt__::task::spawn({let stream = stream.clone(); + if let Err(e) = __rt__::task::spawn({ + let router = router.clone(); + let stream = stream.clone(); + async move { let stream = &mut *stream.lock().await; @@ -76,7 +76,7 @@ impl Ohkami { let mut req = unsafe {Pin::new_unchecked(&mut req)}; req.as_mut().read(stream).await; - let res = router.handle(c, req.get_mut()).await; + let res = router.handle(Context::new(), req.get_mut()).await; res.send(stream).await } }).await { diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index dd22f36c..d2f0c27b 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -228,6 +228,11 @@ mod __rt__ { #[cfg(feature="rt_async-std")] pub(crate) use async_std::net::TcpListener; + #[cfg(feature="rt_tokio")] + pub(crate) use tokio::net::TcpStream; + #[cfg(feature="rt_async-std")] + pub(crate) use async_std::net::TcpStream; + #[cfg(feature="rt_tokio")] pub(crate) use tokio::task; #[cfg(feature="rt_async-std")] diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs index f8510b70..2482c6c8 100644 --- a/ohkami/src/x_websocket/message.rs +++ b/ohkami/src/x_websocket/message.rs @@ -1,4 +1,5 @@ -use std::borrow::Cow; +use std::{borrow::Cow, io::Result}; +use crate::{__rt__::AsyncReader}; pub enum Message { @@ -16,3 +17,32 @@ pub struct CloseFrame { pub code: u16, pub reason: Cow<'static, str>, } + +const _: (/* `From` impls */) = { + impl From<&str> for Message { + fn from(string: &str) -> Self { + Self::Text(string.to_string()) + } + } + impl From for Message { + fn from(string: String) -> Self { + Self::Text(string) + } + } + impl From<&[u8]> for Message { + fn from(data: &[u8]) -> Self { + Self::Binary(data.to_vec()) + } + } + impl From> for Message { + fn from(data: Vec) -> Self { + Self::Binary(data) + } + } +}; + +impl Message { + pub(super) async fn from(stream: impl AsyncReader + Unpin) -> Result { + + } +} diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index c89082fd..3c7fb498 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -2,7 +2,24 @@ mod context; mod message; mod sign; +use std::{sync::Arc, io::Result}; +use crate::__rt__::{TcpStream, Mutex, AsyncReader, AsyncWriter}; +use self::message::Message; + pub struct WebSocket { + stream: Arc>, +} + +impl WebSocket { + fn new(stream: Arc>) -> Self { + Self { stream } + } +} +impl WebSocket { + //pub async fn recv(&self) -> Option> { + // ( self.stream.lock().await) + // .rea + //} } From 3403dc12127153976f58d9f409222f7e75450252 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Mon, 6 Nov 2023 23:32:56 +0900 Subject: [PATCH 14/37] @2023-11-06 23:32+9:00 --- ohkami/src/x_websocket/frame.rs | 84 +++++++++++++++++++++++++++++++ ohkami/src/x_websocket/message.rs | 6 +-- ohkami/src/x_websocket/mod.rs | 1 + 3 files changed, 88 insertions(+), 3 deletions(-) create mode 100644 ohkami/src/x_websocket/frame.rs diff --git a/ohkami/src/x_websocket/frame.rs b/ohkami/src/x_websocket/frame.rs new file mode 100644 index 00000000..d3ba0dde --- /dev/null +++ b/ohkami/src/x_websocket/frame.rs @@ -0,0 +1,84 @@ +use std::io::{Error, ErrorKind}; +use crate::__rt__::{AsyncReader}; + + +pub enum OpCode { + /* data op codes */ + Continue /* 0x0 */, + Text /* 0x1 */, + Binary /* 0x2 */, + /* control op codes */ + Close /* 0x8 */, + Ping /* 0x9 */, + Pong /* 0xa */, + /* reserved op codes */ + Reserved /* 0x[3-7,b-f] */, +} impl From for OpCode { + fn from(byte: u8) -> Self {match byte { + 0x0 => Self::Continue, 0x1 => Self::Text, 0x2 => Self::Binary, + 0x8 => Self::Close, 0x9 => Self::Ping, 0xa => Self::Pong, + 0x3..=0x7 | 0xb..=0xf => Self::Reserved, + _ => panic!("OpCode out of range: {byte}") + }} +} + +pub enum CloseCode { + Normal, Away, Protocol, Unsupported, Status, Abnormal, Invalid, + Policy, Size, Extension, Error, Restart,Again, Tls, Reserved, + Iana(u16), Library(u16), Bad(u16), +} impl From for CloseCode { + fn from(code: u16) -> Self {match code { + 1000 => Self::Normal, 1001 => Self::Away, 1002 => Self::Protocol, 1003 => Self::Unsupported, + 1005 => Self::Status, 1006 => Self::Abnormal, 1007 => Self::Invalid, 1008 => Self::Policy, + 1009 => Self::Size, 1010 => Self::Extension, 1011 => Self::Error, 1012 => Self::Restart, + 1013 => Self::Again, 1015 => Self::Tls, 1016..=2999 => Self::Reserved, + 3000..=3999 => Self::Iana(code), 4000..=4999 => Self::Library(code), _ => Self::Bad(code), + }} +} + +pub struct Frame { + pub is_final: bool, + pub opcode: OpCode, + pub mask: Option<[u8; 4]>, + pub payload: Vec, +} impl Frame { + pub async fn read_header(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { + let [first, second] = { + let mut head = [0; 2]; + stream.read_exact(&mut head).await?; + head + }; + + let is_final = first & 0x80 != 0; + let opcode = OpCode::from(first & 0x0F); + if matches!(opcode, OpCode::Reserved) { + return Err(Error::new(ErrorKind::Unsupported, "Ohkami doesn't handle reserved op codes")) + } + + let length = { + let length_byte = second & 0x7F; + let length_part_length = match length_byte {126=>2, 127=>8, _=>0}; + match length_part_length { + 0 => length_byte as u64, + _ => { + let mut bytes = [0; 8]; + stream.read_exact(&mut bytes[(8 - length_part_length)..]).await?; + u64::from_be_bytes(bytes) + } + } + }; + + let mask = (second & 0x7F != 0).then_some((|| async move { + let mut mask_bytes = [0; 4]; + stream.read_exact(&mut mask_bytes).await?; + Result::<_, Error>::Ok(mask_bytes) + })().await?); + + Ok(Some((Self { + is_final, + opcode, + mask, + payload: Vec::new(), + }, length))) + } +} diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs index 2482c6c8..3d6d869d 100644 --- a/ohkami/src/x_websocket/message.rs +++ b/ohkami/src/x_websocket/message.rs @@ -42,7 +42,7 @@ const _: (/* `From` impls */) = { }; impl Message { - pub(super) async fn from(stream: impl AsyncReader + Unpin) -> Result { - - } + //pub(super) async fn from(stream: impl AsyncReader + Unpin) -> Result { + // + //} } diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index 3c7fb498..19099b07 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -1,5 +1,6 @@ mod context; mod message; +mod frame; mod sign; use std::{sync::Arc, io::Result}; From 50cb41c5d9abc86baedc3bbc315814114a232514 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Tue, 7 Nov 2023 10:21:11 +0900 Subject: [PATCH 15/37] @2023-11-07 10:21+9:00 --- ohkami/src/x_websocket/frame.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ohkami/src/x_websocket/frame.rs b/ohkami/src/x_websocket/frame.rs index d3ba0dde..480ccc43 100644 --- a/ohkami/src/x_websocket/frame.rs +++ b/ohkami/src/x_websocket/frame.rs @@ -68,16 +68,16 @@ pub struct Frame { } }; - let mask = (second & 0x7F != 0).then_some((|| async move { + let mask = (second & 0x7F != 0).then(|| async move { let mut mask_bytes = [0; 4]; stream.read_exact(&mut mask_bytes).await?; Result::<_, Error>::Ok(mask_bytes) - })().await?); + }); Ok(Some((Self { is_final, opcode, - mask, + mask: match mask {None => None, Some(f) => Some(f.await?)}, payload: Vec::new(), }, length))) } From 303a7ac8ae1db11345ab227d71a98ae435c42da2 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Tue, 7 Nov 2023 12:48:13 +0900 Subject: [PATCH 16/37] @2023-11-07 12:48+9:00 --- ohkami/src/x_websocket/mod.rs | 11 +++++++---- ohkami/src/x_websocket/sign/sha1.rs | 3 --- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index 19099b07..c25d1078 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -1,3 +1,6 @@ +#[cfg(not(target_pointer_width = "64"))] +compile_error!{ "pointer width must be 64" } + mod context; mod message; mod frame; @@ -19,8 +22,8 @@ impl WebSocket { } impl WebSocket { - //pub async fn recv(&self) -> Option> { - // ( self.stream.lock().await) - // .rea - //} + async fn handle(self, handle_message: impl Fn(Message) -> Message) { + let stream = &mut *self.stream.lock().await; + //while let Some(Ok(_)) = stream.re; + } } diff --git a/ohkami/src/x_websocket/sign/sha1.rs b/ohkami/src/x_websocket/sign/sha1.rs index b68d3f86..44c80f06 100644 --- a/ohkami/src/x_websocket/sign/sha1.rs +++ b/ohkami/src/x_websocket/sign/sha1.rs @@ -1,6 +1,3 @@ -#[cfg(not(target_pointer_width = "64"))] -compile_error!{ "pointer width must be 64" } - pub const CHANK: usize = 64; pub const SIZE: usize = 20; // bytes; 160 bits From 66ed48a4b26d0c9379d09730caee6f804be08890 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Tue, 7 Nov 2023 17:01:26 +0900 Subject: [PATCH 17/37] @2023-11-07 17:01+9:00 --- ohkami/src/x_websocket/frame.rs | 56 ++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/ohkami/src/x_websocket/frame.rs b/ohkami/src/x_websocket/frame.rs index 480ccc43..91f7cfc5 100644 --- a/ohkami/src/x_websocket/frame.rs +++ b/ohkami/src/x_websocket/frame.rs @@ -37,12 +37,13 @@ pub enum CloseCode { } pub struct Frame { - pub is_final: bool, - pub opcode: OpCode, - pub mask: Option<[u8; 4]>, - pub payload: Vec, + pub is_final: bool, + pub opcode: OpCode, + pub mask: Option<[u8; 4]>, + pub payload_len: usize, + pub payload: Vec, } impl Frame { - pub async fn read_header(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { + pub async fn read_header(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { let [first, second] = { let mut head = [0; 2]; stream.read_exact(&mut head).await?; @@ -55,30 +56,41 @@ pub struct Frame { return Err(Error::new(ErrorKind::Unsupported, "Ohkami doesn't handle reserved op codes")) } - let length = { - let length_byte = second & 0x7F; - let length_part_length = match length_byte {126=>2, 127=>8, _=>0}; - match length_part_length { - 0 => length_byte as u64, + let payload_len = { + let payload_len_byte = second & 0x7F; + let len_part_size = match payload_len_byte {126=>2, 127=>8, _=>0}; + match len_part_size { + 0 => payload_len_byte as usize, _ => { let mut bytes = [0; 8]; - stream.read_exact(&mut bytes[(8 - length_part_length)..]).await?; - u64::from_be_bytes(bytes) + if let Err(e) = stream.read_exact(&mut bytes[(8 - len_part_size)..]).await { + return match e.kind() { + ErrorKind::UnexpectedEof => Ok(None), + _ => Err(e.into()), + } + } + usize::from_be_bytes(bytes) } } }; - let mask = (second & 0x7F != 0).then(|| async move { + let mask = if second & 0x80 == 0 {None} else { let mut mask_bytes = [0; 4]; - stream.read_exact(&mut mask_bytes).await?; - Result::<_, Error>::Ok(mask_bytes) - }); + if let Err(e) = stream.read_exact(&mut mask_bytes).await { + return match e.kind() { + ErrorKind::UnexpectedEof => Ok(None), + _ => Err(e.into()), + } + } + Some(mask_bytes) + }; + + let payload = { + let mut payload = Vec::with_capacity(payload_len); + stream.read_exact(&mut payload).await?; + payload + }; - Ok(Some((Self { - is_final, - opcode, - mask: match mask {None => None, Some(f) => Some(f.await?)}, - payload: Vec::new(), - }, length))) + Ok(Some(Self { is_final, opcode, payload_len, mask, payload })) } } From dee766ddd5547ba2759a58e9542b55790c92e87e Mon Sep 17 00:00:00 2001 From: kana-rus Date: Tue, 7 Nov 2023 22:33:18 +0900 Subject: [PATCH 18/37] @2023-11-07 23:33+9:00 --- ohkami/src/x_websocket/frame.rs | 21 +++++------- ohkami/src/x_websocket/message.rs | 56 ++++++++++++++++++++++++++++--- 2 files changed, 61 insertions(+), 16 deletions(-) diff --git a/ohkami/src/x_websocket/frame.rs b/ohkami/src/x_websocket/frame.rs index 91f7cfc5..a31514ed 100644 --- a/ohkami/src/x_websocket/frame.rs +++ b/ohkami/src/x_websocket/frame.rs @@ -2,6 +2,7 @@ use std::io::{Error, ErrorKind}; use crate::__rt__::{AsyncReader}; +#[derive(PartialEq)] pub enum OpCode { /* data op codes */ Continue /* 0x0 */, @@ -12,12 +13,12 @@ pub enum OpCode { Ping /* 0x9 */, Pong /* 0xa */, /* reserved op codes */ - Reserved /* 0x[3-7,b-f] */, + // Reserved /* 0x[3-7,b-f] */, } impl From for OpCode { fn from(byte: u8) -> Self {match byte { 0x0 => Self::Continue, 0x1 => Self::Text, 0x2 => Self::Binary, 0x8 => Self::Close, 0x9 => Self::Ping, 0xa => Self::Pong, - 0x3..=0x7 | 0xb..=0xf => Self::Reserved, + // 0x3..=0x7 | 0xb..=0xf => Self::Reserved, _ => panic!("OpCode out of range: {byte}") }} } @@ -37,13 +38,12 @@ pub enum CloseCode { } pub struct Frame { - pub is_final: bool, - pub opcode: OpCode, - pub mask: Option<[u8; 4]>, - pub payload_len: usize, - pub payload: Vec, + pub is_final: bool, + pub opcode: OpCode, + pub mask: Option<[u8; 4]>, + pub payload: Vec, } impl Frame { - pub async fn read_header(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { + pub async fn read_from(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { let [first, second] = { let mut head = [0; 2]; stream.read_exact(&mut head).await?; @@ -52,9 +52,6 @@ pub struct Frame { let is_final = first & 0x80 != 0; let opcode = OpCode::from(first & 0x0F); - if matches!(opcode, OpCode::Reserved) { - return Err(Error::new(ErrorKind::Unsupported, "Ohkami doesn't handle reserved op codes")) - } let payload_len = { let payload_len_byte = second & 0x7F; @@ -91,6 +88,6 @@ pub struct Frame { payload }; - Ok(Some(Self { is_final, opcode, payload_len, mask, payload })) + Ok(Some(Self { is_final, opcode, mask, payload })) } } diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs index 3d6d869d..279d2120 100644 --- a/ohkami/src/x_websocket/message.rs +++ b/ohkami/src/x_websocket/message.rs @@ -1,5 +1,6 @@ -use std::{borrow::Cow, io::Result}; +use std::{borrow::Cow, io::{Error, ErrorKind}}; use crate::{__rt__::AsyncReader}; +use super::frame::{Frame, OpCode}; pub enum Message { @@ -42,7 +43,54 @@ const _: (/* `From` impls */) = { }; impl Message { - //pub(super) async fn from(stream: impl AsyncReader + Unpin) -> Result { - // - //} + pub(super) async fn read_from(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { + let head_frame = match Frame::read_from(stream).await? { + Some(frame) => frame, + None => return Ok(None), + }; + + match &head_frame.opcode { + OpCode::Text => { + let mut payload = String::from_utf8(head_frame.payload) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Text frame's payload is not valid UTF-8"))?; + if !head_frame.is_final { + while let Ok(Some(next_frame)) = Frame::read_from(stream).await { + if next_frame.opcode != OpCode::Continue { + return Err(Error::new(ErrorKind::InvalidData, "Expected continue frame")); + } + payload.push_str(std::str::from_utf8(&next_frame.payload) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Text frame's payload is not valid UTF-8"))? + ); + if next_frame.is_final { + break + } + } + } + Ok(Some(Message::Text(payload))) + } + OpCode::Binary => { + let mut payload = head_frame.payload; + if !head_frame.is_final { + while let Ok(Some(mut next_frame)) = Frame::read_from(stream).await { + if next_frame.opcode != OpCode::Continue { + return Err(Error::new(ErrorKind::InvalidData, "Expected continue frame")); + } + payload.append( + &mut next_frame.payload + ); + if next_frame.is_final { + break + } + } + } + Ok(Some(Message::Binary(payload))) + } + OpCode::Ping => { + todo!() + } + OpCode::Close => return Ok(None), + OpCode::Pong => return Err(Error::new(ErrorKind::InvalidData, "Unexpected pong frame")), + OpCode::Continue => return Err(Error::new(ErrorKind::InvalidData, "Unexpected continue frame")), + } + } } From 7503eb385630cf48e0dc2eb868376e9f1a78a916 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Tue, 7 Nov 2023 23:13:54 +0900 Subject: [PATCH 19/37] @2023-11-07 23:13+9:00 --- ohkami/src/x_websocket/frame.rs | 38 +++++++++++++++++++++++-------- ohkami/src/x_websocket/message.rs | 29 ++++++++++++++++++++--- 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/ohkami/src/x_websocket/frame.rs b/ohkami/src/x_websocket/frame.rs index a31514ed..ebfd0594 100644 --- a/ohkami/src/x_websocket/frame.rs +++ b/ohkami/src/x_websocket/frame.rs @@ -1,5 +1,5 @@ use std::io::{Error, ErrorKind}; -use crate::__rt__::{AsyncReader}; +use crate::__rt__::{AsyncReader, AsyncWriter}; #[derive(PartialEq)] @@ -14,13 +14,23 @@ pub enum OpCode { Pong /* 0xa */, /* reserved op codes */ // Reserved /* 0x[3-7,b-f] */, -} impl From for OpCode { - fn from(byte: u8) -> Self {match byte { - 0x0 => Self::Continue, 0x1 => Self::Text, 0x2 => Self::Binary, - 0x8 => Self::Close, 0x9 => Self::Ping, 0xa => Self::Pong, - // 0x3..=0x7 | 0xb..=0xf => Self::Reserved, - _ => panic!("OpCode out of range: {byte}") - }} +} impl OpCode { + fn from_byte(byte: u8) -> Result { + Ok(match byte { + 0x0 => Self::Continue, 0x1 => Self::Text, 0x2 => Self::Binary, + 0x8 => Self::Close, 0x9 => Self::Ping, 0xa => Self::Pong, + 0x3..=0x7 | 0xb..=0xf => return Err(Error::new( + ErrorKind::Unsupported, "Ohkami doesn't handle reserved opcodes")), + _ => return Err(Error::new( + ErrorKind::InvalidData, "OpCode out of range")), + }) + } + fn into_byte(self) -> u8 { + match self { + Self::Continue => 0x0, Self::Text => 0x1, Self::Binary => 0x2, + Self::Close => 0x8, Self::Ping => 0x9, Self::Pong => 0xa, + } + } } pub enum CloseCode { @@ -43,7 +53,7 @@ pub struct Frame { pub mask: Option<[u8; 4]>, pub payload: Vec, } impl Frame { - pub async fn read_from(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { + pub(super) async fn read_from(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { let [first, second] = { let mut head = [0; 2]; stream.read_exact(&mut head).await?; @@ -51,7 +61,7 @@ pub struct Frame { }; let is_final = first & 0x80 != 0; - let opcode = OpCode::from(first & 0x0F); + let opcode = OpCode::from_byte(first & 0x0F)?; let payload_len = { let payload_len_byte = second & 0x7F; @@ -90,4 +100,12 @@ pub struct Frame { Ok(Some(Self { is_final, opcode, mask, payload })) } + + pub(super) async fn write_to(self, stream: &mut (impl AsyncWriter + Unpin)) -> Result<(), Error> { + fn into_bytes(frame: Frame) -> Vec { + + } + + stream.write_all(&into_bytes(self)).await + } } diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs index 279d2120..287fb033 100644 --- a/ohkami/src/x_websocket/message.rs +++ b/ohkami/src/x_websocket/message.rs @@ -1,5 +1,5 @@ use std::{borrow::Cow, io::{Error, ErrorKind}}; -use crate::{__rt__::AsyncReader}; +use crate::{__rt__::{AsyncReader, AsyncWriter}}; use super::frame::{Frame, OpCode}; @@ -8,7 +8,7 @@ pub enum Message { Binary(Vec), Ping (PingPongFrame), Pong (PingPongFrame), - Close (Option), + Close (CloseFrame), } pub struct PingPongFrame { buf: [u8; 125], @@ -16,7 +16,7 @@ pub struct PingPongFrame { } pub struct CloseFrame { pub code: u16, - pub reason: Cow<'static, str>, + pub reason: Option>, } const _: (/* `From` impls */) = { @@ -42,6 +42,29 @@ const _: (/* `From` impls */) = { } }; +impl Message { + pub(super) async fn send(self, stream: &mut (impl AsyncWriter + Unpin)) -> Result<(), Error> { + fn into_frame(message: Message) -> Frame { + let (opcode, payload) = match message { + Message::Text (text) => (OpCode::Text, text.into_bytes()), + Message::Binary(vec) => (OpCode::Binary, vec), + Message::Ping (PingPongFrame { buf, len }) => (OpCode::Ping, buf[..len].to_vec()), + Message::Pong (PingPongFrame { buf, len }) => (OpCode::Pong, buf[..len].to_vec()), + Message::Close (CloseFrame { code, reason }) => { + let mut payload = code.to_be_bytes().to_vec(); + if let Some(reason_text) = reason { + payload.extend_from_slice(reason_text.as_bytes()) + } + (OpCode::Close, payload) + } + }; + Frame { is_final: false, mask: None, opcode, payload } + } + + into_frame(self).write_to(stream).await + } +} + impl Message { pub(super) async fn read_from(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { let head_frame = match Frame::read_from(stream).await? { From cb6690b78a04e5676be46ec8083355143c348438 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Wed, 8 Nov 2023 22:14:32 +0900 Subject: [PATCH 20/37] @2023-11-08 22:14+9:00 --- ohkami/src/x_websocket/context.rs | 72 ++++++++++++++++++------------- ohkami/src/x_websocket/frame.rs | 21 ++++++++- ohkami/src/x_websocket/mod.rs | 16 ++++--- 3 files changed, 70 insertions(+), 39 deletions(-) diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index b6b694a5..8e4591a2 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,17 +1,20 @@ -use std::{future::Future, borrow::Cow}; +use std::{future::Future, borrow::Cow, sync::Arc}; use super::{WebSocket, sign}; -use crate::{Response, Context, __rt__, Request}; +use crate::{Response, Context, Request}; +use crate::__rt__::{task, Mutex, TcpStream}; use crate::http::{Method}; -pub struct WebSocketContext { +pub struct WebSocketContext { c: Context, + stream: Arc>, + config: Config, - on_failed_upgrade: FU, + on_failed_upgrade: Box, - selected_protocol: Option>, sec_websocket_key: Cow<'static, str>, + selected_protocol: Option>, sec_websocket_protocol: Option>, } @@ -36,18 +39,9 @@ pub struct Config { }; pub enum UpgradeError { /* TODO */ } -pub trait OnFailedUpgrade: Send + 'static { - fn handle(self, error: UpgradeError); -} -pub struct DefaultOnFailedUpgrade; const _: () = { - impl OnFailedUpgrade for DefaultOnFailedUpgrade { - fn handle(self, _: UpgradeError) { /* DO NOTHING (discard error) */ } - } -}; - impl WebSocketContext { - pub(crate) fn new(c: Context, req: &mut Request) -> Result> { + pub(crate) fn new(c: Context, stream: Arc>, req: &mut Request) -> Result> { if req.method() != Method::GET { return Err(Cow::Borrowed("Method is not `GET`")) } @@ -68,9 +62,9 @@ impl WebSocketContext { let sec_websocket_protocol = req.header("Sec-WebSocket-Protocol") .map(|swp| Cow::Owned(swp.to_string())); - Ok(Self {c, + Ok(Self {c, stream, config: Config::default(), - on_failed_upgrade: DefaultOnFailedUpgrade, + on_failed_upgrade: Box::new(|_| (/* discard error */)), selected_protocol: None, sec_websocket_key, sec_websocket_protocol, @@ -78,7 +72,7 @@ impl WebSocketContext { } } -impl WebSocketContext { +impl WebSocketContext { pub fn write_buffer_size(mut self, size: usize) -> Self { self.config.write_buffer_size = size; self @@ -99,9 +93,7 @@ impl WebSocketContext { self.config.accept_unmasked_frames = true; self } -} -impl WebSocketContext { pub fn protocols>>(mut self, protocols: impl Iterator) -> Self { if let Some(req_protocols) = &self.sec_websocket_protocol { self.selected_protocol = protocols.map(Into::into) @@ -111,15 +103,23 @@ impl WebSocketContext { } } -impl WebSocketContext { +impl WebSocketContext { pub fn on_upgrade< Fut: Future + Send + 'static, >( self, - callback: impl Fn(WebSocket) -> Fut + Send + 'static + handler: impl Fn(WebSocket) -> Fut + Send + Sync + 'static ) -> Response { + fn sign(sec_websocket_key: &str) -> String { + let mut sha1 = sign::Sha1::new(); + sha1.write(sec_websocket_key.as_bytes()); + sha1.write(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + sign::Base64::<{sign::SHA1_SIZE}>::encode(sha1.sum()) + } + let Self { mut c, + stream, config, on_failed_upgrade, selected_protocol, @@ -127,8 +127,25 @@ impl WebSocketContext { sec_websocket_protocol, } = self; - __rt__::task::spawn(async move { - todo!() + task::spawn({ + #[cfg(debug_assertions)] let mut __loop_count = 0; + + let stream = loop { + #[cfg(debug_assertions)] { + if __loop_count == usize::MAX {panic!("Infinite loop in web socket handshake")}} + + if Arc::strong_count(&stream) == 1 { + break Arc::into_inner(stream).unwrap().into_inner() + } + + #[cfg(debug_assertions)] { + __loop_count += 1} + }; + + async move { + let ws = WebSocket::new(stream); + handler(ws).await + } }); c.headers @@ -142,10 +159,3 @@ impl WebSocketContext { c.SwitchingProtocols() } } - -fn sign(sec_websocket_key: &str) -> String { - let mut sha1 = sign::Sha1::new(); - sha1.write(sec_websocket_key.as_bytes()); - sha1.write(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - sign::Base64::<{sign::SHA1_SIZE}>::encode(sha1.sum()) -} diff --git a/ohkami/src/x_websocket/frame.rs b/ohkami/src/x_websocket/frame.rs index ebfd0594..187fd305 100644 --- a/ohkami/src/x_websocket/frame.rs +++ b/ohkami/src/x_websocket/frame.rs @@ -103,7 +103,26 @@ pub struct Frame { pub(super) async fn write_to(self, stream: &mut (impl AsyncWriter + Unpin)) -> Result<(), Error> { fn into_bytes(frame: Frame) -> Vec { - + let Frame { is_final, opcode, mask, payload } = frame; + + let (payload_len_byte, payload_len_bytes) = match payload.len() { + ..=125 => (payload.len() as u8, None), + 126..=65535 => (126, Some((|| (payload.len() as u16).to_be_bytes().to_vec())())), + _ => (127, Some((|| (payload.len() as u64).to_be_bytes().to_vec())())), + }; + + let first = is_final.then_some(1).unwrap_or(0) << 7 + opcode.into_byte(); + let second = mask.is_some().then_some(1).unwrap_or(0) << 7 + payload_len_byte; + + let mut header_bytes = vec![first, second]; + if let Some(mut payload_len_bytes) = payload_len_bytes { + header_bytes.append(&mut payload_len_bytes) + } + if let Some(mask_bytes) = mask { + header_bytes.extend_from_slice(&mask_bytes) + } + + [header_bytes, payload].concat() } stream.write_all(&into_bytes(self)).await diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index c25d1078..5175f2f1 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -6,24 +6,26 @@ mod message; mod frame; mod sign; -use std::{sync::Arc, io::Result}; -use crate::__rt__::{TcpStream, Mutex, AsyncReader, AsyncWriter}; +use std::io::{Error, ErrorKind}; +use crate::__rt__::TcpStream; use self::message::Message; pub struct WebSocket { - stream: Arc>, + stream: TcpStream } impl WebSocket { - fn new(stream: Arc>) -> Self { + fn new(stream: TcpStream) -> Self { Self { stream } } } impl WebSocket { - async fn handle(self, handle_message: impl Fn(Message) -> Message) { - let stream = &mut *self.stream.lock().await; - //while let Some(Ok(_)) = stream.re; + pub async fn recv(&mut self) -> Result, Error> { + Message::read_from(&mut self.stream).await + } + pub async fn send(&mut self, message: Message) -> Result<(), Error> { + message.send(&mut self.stream).await } } From 6c10467228a59b1ebcc7669f756d8b7738aa6984 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 10 Nov 2023 00:35:21 +0900 Subject: [PATCH 21/37] @2023-11-10 00:35+9:00 --- ohkami/src/layer4_router/radix.rs | 3 ++ ohkami/src/layer6_testing/mod.rs | 53 +++++++++++++++++++++++++++++-- ohkami/src/x_websocket/context.rs | 27 ++++++++++------ 3 files changed, 71 insertions(+), 12 deletions(-) diff --git a/ohkami/src/layer4_router/radix.rs b/ohkami/src/layer4_router/radix.rs index b1432299..11a0d5a4 100644 --- a/ohkami/src/layer4_router/radix.rs +++ b/ohkami/src/layer4_router/radix.rs @@ -6,6 +6,9 @@ use crate::{ layer3_fang_handler::{Handler, FrontFang, PathParams, BackFang}, }; +#[cfg(feature="websocket")] +use {std::sync::Arc, crate::__rt__::{Mutex, AsyncReader, AsyncWriter}}; + /*===== defs =====*/ pub(crate) struct RadixRouter { diff --git a/ohkami/src/layer6_testing/mod.rs b/ohkami/src/layer6_testing/mod.rs index 74c68743..9cef6241 100644 --- a/ohkami/src/layer6_testing/mod.rs +++ b/ohkami/src/layer6_testing/mod.rs @@ -8,6 +8,9 @@ use byte_reader::Reader; use crate::{Response, Request, Ohkami, Context}; use crate::layer0_lib::{IntoCows, Status, Method, ContentType}; +#[cfg(feature="websocket")] +use {std::sync::Arc, crate::__rt__::Mutex}; + pub trait Testing { fn oneshot(&self, req: TestRequest) -> TestFuture; @@ -17,11 +20,57 @@ pub struct TestFuture( impl Future for TestFuture { type Output = TestResponse; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - unsafe {self.map_unchecked_mut(|this| this.0.as_mut())} - .poll(cx) + unsafe {self.map_unchecked_mut(|this| this.0.as_mut())}.poll(cx) } } +#[cfg(feature="websocket")] +pub struct TestWebSocket(Vec); +#[cfg(feature="websocket")] const _: () = { + impl TestWebSocket { + fn new(size: usize) -> Self { + Self(Vec::with_capacity(size)) + } + } + + #[cfg(feature="rt_tokio")] const _: () = { + impl tokio::io::AsyncRead for TestWebSocket { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + let mut pin_inner = unsafe {self.map_unchecked_mut(|this| &mut this.0)}; + + let amt = std::cmp::min(pin_inner.len(), buf.remaining()); + let (a, b) = pin_inner.split_at(amt); + buf.put_slice(a); + *pin_inner = b.to_vec(); + + std::task::Poll::Ready(Ok(())) + } + } + impl tokio::io::AsyncWrite for TestWebSocket { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + unsafe {self.map_unchecked_mut(|this| &mut this.0)} + .poll_write(cx, buf) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + unsafe {self.map_unchecked_mut(|this| &mut this.0)} + .poll_flush(cx) + } + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + unsafe {self.map_unchecked_mut(|this| &mut this.0)} + .poll_shutdown(cx) + } + } + }; +}; + impl Testing for Ohkami { fn oneshot(&self, request: TestRequest) -> TestFuture { let router = { diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index 8e4591a2..362a4ec5 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -128,18 +128,25 @@ impl WebSocketContext { } = self; task::spawn({ - #[cfg(debug_assertions)] let mut __loop_count = 0; - - let stream = loop { - #[cfg(debug_assertions)] { - if __loop_count == usize::MAX {panic!("Infinite loop in web socket handshake")}} - - if Arc::strong_count(&stream) == 1 { - break Arc::into_inner(stream).unwrap().into_inner() + //#[cfg(debug_assertions)] let mut __loop_count = 0; +// + //let stream = loop { + // #[cfg(debug_assertions)] { + // if __loop_count == usize::MAX {panic!("Infinite loop in web socket handshake")}} +// + // if Arc::strong_count(&stream) == 1 { + // break Arc::into_inner(stream).unwrap().into_inner() + // } +// + // #[cfg(debug_assertions)] { + // __loop_count += 1} + //}; + let stream = { + while Arc::strong_count(&stream) > 1 { + } - #[cfg(debug_assertions)] { - __loop_count += 1} + Arc::into_inner(stream).unwrap().into_inner() }; async move { From ff7a5b28161bc3e0c5b33952605aa25fadbd5b9f Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 10 Nov 2023 23:34:39 +0900 Subject: [PATCH 22/37] @2023-11-10 23:34+9:00 --- ohkami/src/lib.rs | 7 ++++ ohkami/src/x_websocket/mod.rs | 5 ++- ohkami/src/x_websocket/upgrade.rs | 68 +++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 ohkami/src/x_websocket/upgrade.rs diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index d2f0c27b..1d9e15d5 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -222,6 +222,13 @@ mod __rt__ { #[cfg(feature="rt_tokio")] pub(crate) use tokio::sync::Mutex; + #[cfg(feature="rt_async-std")] + pub(crate) use async_std::sync::Mutex; + + #[cfg(feature="rt_tokio")] + pub(crate) use tokio::sync::RwLock; + #[cfg(feature="rt_async-std")] + pub(crate) use async_std::sync::RwLock; #[cfg(feature="rt_tokio")] pub(crate) use tokio::net::TcpListener; diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index 5175f2f1..396508bf 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -3,11 +3,12 @@ compile_error!{ "pointer width must be 64" } mod context; mod message; +mod upgrade; mod frame; mod sign; -use std::io::{Error, ErrorKind}; -use crate::__rt__::TcpStream; +use std::{io::{Error, ErrorKind}, sync::{Arc, OnceLock, atomic::AtomicUsize}, collections::HashMap}; +use crate::__rt__::{TcpStream, RwLock}; use self::message::Message; diff --git a/ohkami/src/x_websocket/upgrade.rs b/ohkami/src/x_websocket/upgrade.rs new file mode 100644 index 00000000..4bfba3bd --- /dev/null +++ b/ohkami/src/x_websocket/upgrade.rs @@ -0,0 +1,68 @@ +use std::{sync::{Arc, OnceLock}, future::Future, pin::Pin}; +use crate::__rt__::{TcpStream, Mutex}; +type UpgradeLock = Mutex; + + +pub static UPGRADE_STREAMS: OnceLock = OnceLock::new(); +pub async fn wait_upgrade(arc_stream: Arc>) -> UpgradeID { + UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) + .push(arc_stream).await +} +pub async fn assume_upgraded(id: UpgradeID) -> TcpStream { + UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) + .get(id).await +} + +pub struct UpgradeID(usize); + +pub struct UpgradeStreams { + streams: UpgradeLock>, +} const _: () = { + impl UpgradeStreams { + fn new() -> Self { + Self { + streams: UpgradeLock::new(Vec::new()), + } + } + } + impl UpgradeStreams { + async fn push(&self, arc_stream: Arc>) -> UpgradeID { + let mut this = self.streams.lock().await; + let id = match this.iter().position(|us| us.is_available()) { + Some(i) => {this[i] = UpgradeStream::new(arc_stream); i} + None => {this.push(UpgradeStream::new(arc_stream)); this.len()-1} + }; + UpgradeID(id) + } + async fn get(&self, id: UpgradeID) -> TcpStream { + let mut this = self.streams.lock().await; + Pin::new(this.get_mut(id.0).unwrap()).await + } + } +}; + +struct UpgradeStream( + Option>> +); const _: () = { + impl UpgradeStream { + fn new(arc_stream: Arc>) -> Self { + Self(Some(arc_stream)) + } + fn is_available(&self) -> bool { + self.0.is_none() + } + } + impl Default for UpgradeStream { + fn default() -> Self {Self(None)} + } + impl Future for UpgradeStream { + type Output = TcpStream; + fn poll(self: std::pin::Pin<&mut Self>, _: &mut std::task::Context<'_>) -> std::task::Poll { + match Arc::strong_count(self.0.as_ref().unwrap()) { + 1 => std::task::Poll::Ready(Arc::into_inner(self.get_mut().0.take().unwrap()).unwrap().into_inner()), + _ => std::task::Poll::Pending, + } + } + } +}; + From 13ae665ace51f97cd919b8fb80a2870bf78d79c3 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sat, 11 Nov 2023 01:02:14 +0900 Subject: [PATCH 23/37] @2023-11-11 01:02+9:00 --- ohkami/src/layer2_context/mod.rs | 10 ++- .../handler/into_handler.rs | 32 ++++----- ohkami/src/layer3_fang_handler/handler/mod.rs | 20 +++--- ohkami/src/layer4_router/radix.rs | 70 ++++++++++++------- ohkami/src/layer5_ohkami/howl.rs | 19 +++-- ohkami/src/layer6_testing/mod.rs | 2 +- ohkami/src/x_websocket/context.rs | 32 ++------- ohkami/src/x_websocket/mod.rs | 2 + ohkami/src/x_websocket/upgrade.rs | 60 +++++++++++----- 9 files changed, 148 insertions(+), 99 deletions(-) diff --git a/ohkami/src/layer2_context/mod.rs b/ohkami/src/layer2_context/mod.rs index 89bb610b..282faf19 100644 --- a/ohkami/src/layer2_context/mod.rs +++ b/ohkami/src/layer2_context/mod.rs @@ -73,12 +73,20 @@ use crate::{ /// } /// ``` pub struct Context { + #[cfg(feature="websocket")] + pub(crate) upgrade_id: Option, + pub headers: ResponseHeaders, } impl Context { #[inline(always)] pub(crate) fn new() -> Self { - Self { headers: ResponseHeaders::new() } + Self { + #[cfg(feature="websocket")] + upgrade_id: None, + + headers: ResponseHeaders::new(), + } } } diff --git a/ohkami/src/layer3_fang_handler/handler/into_handler.rs b/ohkami/src/layer3_fang_handler/handler/into_handler.rs index fd503ff2..df8ee44a 100644 --- a/ohkami/src/layer3_fang_handler/handler/into_handler.rs +++ b/ohkami/src/layer3_fang_handler/handler/into_handler.rs @@ -18,7 +18,7 @@ const _: (/* only Context */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |_, c, _| + Handler::new(false, move |_, c, _| Box::pin({ let res = self(c); async {res.await} @@ -37,7 +37,7 @@ const _: (/* PathParam */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |_, c, params| + Handler::new(false, move |_, c, params| match <$param_type as PathParam>::parse(unsafe {params.assume_init_first().as_bytes()}) { Ok(p1) => Box::pin({ let res = self(c, p1); @@ -62,7 +62,7 @@ const _: (/* PathParam */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |_, c, params| + Handler::new(false, move |_, c, params| // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code match ::parse(unsafe {params.assume_init_first().as_bytes()}) { @@ -85,7 +85,7 @@ const _: (/* PathParam */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |_, c, params| { + Handler::new(false, move |_, c, params| { let (p1_range, p2_range) = params.assume_init_extract(); // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed twice before this code @@ -117,7 +117,7 @@ const _: (/* FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, _| + Handler::new(false, move |req, c, _| match Item1::parse(&req) { Ok(item1) => Box::pin({ let res = self(c, item1); @@ -138,7 +138,7 @@ const _: (/* FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, _| + Handler::new(false, move |req, c, _| match Item1::parse(&req) { Ok(item1) => match Item2::parse(&req) { Ok(item2) => Box::pin({ @@ -165,7 +165,7 @@ const _: (/* FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, _| + Handler::new(false, move |req, c, _| match Item1::parse(&req) { Ok(item1) => match Item2::parse(&req) { Ok(item2) => match Item3::parse(&req) { @@ -202,7 +202,7 @@ const _: (/* single PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, params| + Handler::new(false, move |req, c, params| // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code match <$param_type as PathParam>::parse(unsafe {params.assume_init_first().as_bytes()}) { @@ -231,7 +231,7 @@ const _: (/* single PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, params| + Handler::new(false, move |req, c, params| // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code match <$param_type as PathParam>::parse(unsafe {params.assume_init_first().as_bytes()}) { @@ -266,7 +266,7 @@ const _: (/* single PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, params| + Handler::new(false, move |req, c, params| // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code match <$param_type as PathParam>::parse(unsafe {params.assume_init_first().as_bytes()}) { @@ -313,7 +313,7 @@ const _: (/* one PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, params| + Handler::new(false, move |req, c, params| // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code match P1::parse(unsafe {params.assume_init_first().as_bytes()}) { @@ -342,7 +342,7 @@ const _: (/* one PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, params| + Handler::new(false, move |req, c, params| // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code match P1::parse(unsafe {params.assume_init_first().as_bytes()}) { @@ -377,7 +377,7 @@ const _: (/* one PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, params| + Handler::new(false, move |req, c, params| // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code match P1::parse(unsafe {params.assume_init_first().as_bytes()}) { @@ -420,7 +420,7 @@ const _: (/* two PathParams and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, params| { + Handler::new(false, move |req, c, params| { let (p1_range, p2_range) = params.assume_init_extract(); // SAFETY: Due to the architecture of `Router`, @@ -457,7 +457,7 @@ const _: (/* two PathParams and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, params| { + Handler::new(false, move |req, c, params| { let (p1_range, p2_range) = params.assume_init_extract(); // SAFETY: Due to the architecture of `Router`, @@ -500,7 +500,7 @@ const _: (/* two PathParams and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(move |req, c, params| { + Handler::new(false, move |req, c, params| { let (p1_range, p2_range) = params.assume_init_extract(); // SAFETY: Due to the architecture of `Router`, diff --git a/ohkami/src/layer3_fang_handler/handler/mod.rs b/ohkami/src/layer3_fang_handler/handler/mod.rs index 01a4cde1..4dfbde69 100644 --- a/ohkami/src/layer3_fang_handler/handler/mod.rs +++ b/ohkami/src/layer3_fang_handler/handler/mod.rs @@ -18,8 +18,9 @@ pub(crate) type PathParams = List; #[cfg(not(test))] -pub struct Handler( - pub(crate) Box Pin< Box @@ -27,12 +28,13 @@ pub struct Handler( > > + Send + Sync + 'static > -); +} #[cfg(test)] #[derive(Clone)] -pub struct Handler( - pub(crate) Arc Pin< Box @@ -40,11 +42,11 @@ pub struct Handler( > > + Send + Sync + 'static > -); +} impl Handler { - fn new(proc: (impl + fn new(requires_upgrade: bool, proc: (impl Fn(&mut Request, Context, PathParams) -> Pin< Box @@ -53,7 +55,7 @@ impl Handler { > + Send + Sync + 'static ) ) -> Self { - #[cfg(not(test))] {Self(Box::new(proc))} - #[cfg(test)] {Self(Arc::new(proc))} + #[cfg(not(test))] {Self { requires_upgrade, proc: Box::new(proc) }} + #[cfg(test)] {Self { requires_upgrade, proc: Arc::new(proc) }} } } diff --git a/ohkami/src/layer4_router/radix.rs b/ohkami/src/layer4_router/radix.rs index 11a0d5a4..c46909f2 100644 --- a/ohkami/src/layer4_router/radix.rs +++ b/ohkami/src/layer4_router/radix.rs @@ -4,11 +4,9 @@ use crate::{ Response, layer0_lib::{Method, Status, Slice}, layer3_fang_handler::{Handler, FrontFang, PathParams, BackFang}, + websocket::{request_upgrade_id, UpgradeID}, }; -#[cfg(feature="websocket")] -use {std::sync::Arc, crate::__rt__::{Mutex, AsyncReader, AsyncWriter}}; - /*===== defs =====*/ pub(crate) struct RadixRouter { @@ -51,7 +49,7 @@ impl RadixRouter { &self, mut c: Context, req: &mut Request, - ) -> Response { + ) -> (Response, Option) { let mut params = PathParams::new(); let search_result = match req.method() { Method::GET => self.GET .search(&mut c, req/*.path_bytes()*/, &mut params), @@ -65,17 +63,17 @@ impl RadixRouter { for ff in front { if let Err(err_res) = ff.0(&mut c, req) { - return err_res + return (err_res, None) } } let target = match self.GET.search(&mut c, req/*.path_bytes()*/, &mut params) { Ok(Some(node)) => node, - Ok(None) => return c.NotFound(), - Err(err_res) => return err_res, + Ok(None) => return (c.NotFound(), None), + Err(err_res) => return (err_res, None), }; - let Response { headers, .. } = target.handle(c, req, params).await; + let Response { headers, .. } = target.handle_discarding_upgrade(c, req, params).await; let mut res = Response { headers, status: Status::NoContent, @@ -86,18 +84,18 @@ impl RadixRouter { res = bf.0(res) } - return res + return (res, None) } Method::OPTIONS => { let Some((cors_str, cors)) = crate::layer3_fang_handler::builtin::CORS.get() else { - return c.InternalServerError() + return (c.InternalServerError(), None) }; let (front, back) = self.OPTIONSfangs; for ff in front { if let Err(err_res) = ff.0(&mut c, req) { - return err_res + return (err_res, None) } } @@ -105,33 +103,33 @@ impl RadixRouter { { let Some(origin) = req.header("Origin") else { - return c.BadRequest() + return (c.BadRequest(), None) }; if !cors.AllowOrigin.matches(origin) { - return c.Forbidden() + return (c.Forbidden(), None) } if req.header("Authorization").is_some() && !cors.AllowCredentials { - return c.Forbidden() + return (c.Forbidden(), None) } if let Some(request_method) = req.header("Access-Control-Request-Method") { let request_method = Method::from_bytes(request_method.as_bytes()); let Some(allow_methods) = cors.AllowMethods.as_ref() else { - return c.Forbidden() + return (c.Forbidden(), None) }; if !allow_methods.contains(&request_method) { - return c.Forbidden() + return (c.Forbidden(), None) } } if let Some(request_headers) = req.header("Access-Control-Request-Headers") { let mut request_headers = request_headers.split(',').map(|h| h.trim_matches(' ')); let Some(allow_headers) = cors.AllowHeaders.as_ref() else { - return c.Forbidden() + return (c.Forbidden(), None) }; if !request_headers.all(|h| allow_headers.contains(&h)) { - return c.Forbidden() + return (c.Forbidden(), None) } } } @@ -142,29 +140,53 @@ impl RadixRouter { res = bf.0(res) } - return res + return (res, None) } }; let target = match search_result { Ok(Some(node)) => node, - Ok(None) => return c.NotFound(), - Err(err_res) => return err_res, + Ok(None) => return (c.NotFound(), None), + Err(err_res) => return (err_res, None), }; - target.handle(c, req, params).await + #[cfg(not(test))] {target.handle(c, req, params).await} + #[cfg(test)] {(target.handle_discarding_upgrade(c, req, params).await, None)} } } impl Node { #[inline] pub(super) async fn handle(&self, + mut c: Context, + req: &mut Request, + params: PathParams, + ) -> (Response, Option) { + match &self.handler { + Some(Handler { requires_upgrade, proc }) => { + let upgrade_id = match (*requires_upgrade).then(|| async { + let id = request_upgrade_id().await; + c.upgrade_id = Some(id); + id + }) {None => None, Some(id) => Some(id.await)}; + + let mut res = proc(req, c, params).await; + for b in self.back { + res = b.0(res); + } + + (res, upgrade_id) + } + None => (c.NotFound(), None) + } + } + #[inline] pub(super) async fn handle_discarding_upgrade(&self, c: Context, req: &mut Request, params: PathParams, ) -> Response { match &self.handler { - Some(h) => { - let mut res = h.0(req, c, params).await; + Some(Handler { requires_upgrade:_, proc }) => { + let mut res = proc(req, c, params).await; for b in self.back { res = b.0(res); } diff --git a/ohkami/src/layer5_ohkami/howl.rs b/ohkami/src/layer5_ohkami/howl.rs index b2e3948c..d088a7be 100644 --- a/ohkami/src/layer5_ohkami/howl.rs +++ b/ohkami/src/layer5_ohkami/howl.rs @@ -1,6 +1,6 @@ use std::{sync::Arc, pin::Pin}; use super::{Ohkami}; -use crate::{__rt__, Request, Context}; +use crate::{__rt__, Request, Context, websocket::reserve_upgrade}; #[cfg(feature="rt_async-std")] use crate::__rt__::StreamExt; @@ -65,7 +65,7 @@ impl Ohkami { } )); - if let Err(e) = __rt__::task::spawn({ + match __rt__::task::spawn({ let router = router.clone(); let stream = stream.clone(); @@ -76,15 +76,22 @@ impl Ohkami { let mut req = unsafe {Pin::new_unchecked(&mut req)}; req.as_mut().read(stream).await; - let res = router.handle(Context::new(), req.get_mut()).await; - res.send(stream).await + let (res, upgrade_id) = router.handle(Context::new(), req.get_mut()).await; + res.send(stream).await; + + upgrade_id } }).await { - (|| async { + Ok(upgrade_id) => { + if let Some(id) = upgrade_id { + reserve_upgrade(id, stream).await + } + } + Err(e) => (|| async { println!("Fatal error: {e}"); let res = Context::new().InternalServerError(); res.send(&mut *stream.lock().await).await - })().await + })().await, } } } diff --git a/ohkami/src/layer6_testing/mod.rs b/ohkami/src/layer6_testing/mod.rs index 9cef6241..517da100 100644 --- a/ohkami/src/layer6_testing/mod.rs +++ b/ohkami/src/layer6_testing/mod.rs @@ -86,7 +86,7 @@ impl Testing for Ohkami { let mut req = unsafe {Pin::new_unchecked(&mut req)}; req.as_mut().read(&mut &request.encode_request()[..]).await; - let res = router.handle(Context::new(), &mut req).await; + let (res, _) = router.handle(Context::new(), &mut req).await; TestResponse::new(res) }; diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index 362a4ec5..7235e47a 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,13 +1,13 @@ use std::{future::Future, borrow::Cow, sync::Arc}; -use super::{WebSocket, sign}; +use super::{WebSocket, sign, assume_upgraded}; use crate::{Response, Context, Request}; -use crate::__rt__::{task, Mutex, TcpStream}; +use crate::__rt__::{task, TcpStream}; use crate::http::{Method}; pub struct WebSocketContext { c: Context, - stream: Arc>, + stream: TcpStream, config: Config, @@ -41,7 +41,10 @@ pub struct Config { pub enum UpgradeError { /* TODO */ } impl WebSocketContext { - pub(crate) fn new(c: Context, stream: Arc>, req: &mut Request) -> Result> { + pub(crate) async fn new(c: Context, req: &mut Request) -> Result> { + let id = c.upgrade_id.ok_or(Cow::Borrowed("Failed to upgrade"))?; + let stream = assume_upgraded(id).await; + if req.method() != Method::GET { return Err(Cow::Borrowed("Method is not `GET`")) } @@ -128,27 +131,6 @@ impl WebSocketContext { } = self; task::spawn({ - //#[cfg(debug_assertions)] let mut __loop_count = 0; -// - //let stream = loop { - // #[cfg(debug_assertions)] { - // if __loop_count == usize::MAX {panic!("Infinite loop in web socket handshake")}} -// - // if Arc::strong_count(&stream) == 1 { - // break Arc::into_inner(stream).unwrap().into_inner() - // } -// - // #[cfg(debug_assertions)] { - // __loop_count += 1} - //}; - let stream = { - while Arc::strong_count(&stream) > 1 { - - } - - Arc::into_inner(stream).unwrap().into_inner() - }; - async move { let ws = WebSocket::new(stream); handler(ws).await diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index 396508bf..c40af8d0 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -11,6 +11,8 @@ use std::{io::{Error, ErrorKind}, sync::{Arc, OnceLock, atomic::AtomicUsize}, co use crate::__rt__::{TcpStream, RwLock}; use self::message::Message; +pub(crate) use upgrade::{UpgradeID, request_upgrade_id, reserve_upgrade, assume_upgraded}; + pub struct WebSocket { stream: TcpStream diff --git a/ohkami/src/x_websocket/upgrade.rs b/ohkami/src/x_websocket/upgrade.rs index 4bfba3bd..17ab56af 100644 --- a/ohkami/src/x_websocket/upgrade.rs +++ b/ohkami/src/x_websocket/upgrade.rs @@ -4,15 +4,22 @@ type UpgradeLock = Mutex; pub static UPGRADE_STREAMS: OnceLock = OnceLock::new(); -pub async fn wait_upgrade(arc_stream: Arc>) -> UpgradeID { + +pub async fn request_upgrade_id() -> UpgradeID { + UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) + .reserve().await +} +pub async fn reserve_upgrade(id: UpgradeID, stream: Arc>) { UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) - .push(arc_stream).await + .set(id, stream).await } +//pub async fn cancel_upgrade pub async fn assume_upgraded(id: UpgradeID) -> TcpStream { UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) .get(id).await } +#[derive(Clone, Copy)] pub struct UpgradeID(usize); pub struct UpgradeStreams { @@ -25,14 +32,27 @@ pub struct UpgradeStreams { } } } + impl UpgradeStreams { - async fn push(&self, arc_stream: Arc>) -> UpgradeID { + async fn reserve(&self) -> UpgradeID { let mut this = self.streams.lock().await; - let id = match this.iter().position(|us| us.is_available()) { - Some(i) => {this[i] = UpgradeStream::new(arc_stream); i} - None => {this.push(UpgradeStream::new(arc_stream)); this.len()-1} - }; - UpgradeID(id) + match this.iter().position(UpgradeStream::is_empty) { + Some(i) => { + this[i].reserved = true; + UpgradeID(i) + } + None => { + this.push(UpgradeStream { + reserved: true, + stream: None, + }); + UpgradeID(this.len() - 1) + }, + } + } + async fn set(&self, id: UpgradeID, stream: Arc>) { + let mut this = self.streams.lock().await; + this[id.0].stream = Some(stream) } async fn get(&self, id: UpgradeID) -> TcpStream { let mut this = self.streams.lock().await; @@ -41,25 +61,31 @@ pub struct UpgradeStreams { } }; -struct UpgradeStream( - Option>> -); const _: () = { +struct UpgradeStream { + reserved: bool, + stream: Option>>, +} const _: () = { impl UpgradeStream { fn new(arc_stream: Arc>) -> Self { - Self(Some(arc_stream)) + Self { + reserved: false, + stream: Some(arc_stream), + } } - fn is_available(&self) -> bool { - self.0.is_none() + fn is_empty(&self) -> bool { + self.stream.is_none() && !self.reserved } } impl Default for UpgradeStream { - fn default() -> Self {Self(None)} + fn default() -> Self { + Self { reserved: false, stream: None } + } } impl Future for UpgradeStream { type Output = TcpStream; fn poll(self: std::pin::Pin<&mut Self>, _: &mut std::task::Context<'_>) -> std::task::Poll { - match Arc::strong_count(self.0.as_ref().unwrap()) { - 1 => std::task::Poll::Ready(Arc::into_inner(self.get_mut().0.take().unwrap()).unwrap().into_inner()), + match Arc::strong_count(self.stream.as_ref().unwrap()) { + 1 => std::task::Poll::Ready(Arc::into_inner(self.get_mut().stream.take().unwrap()).unwrap().into_inner()), _ => std::task::Poll::Pending, } } From 4738f5b6ced06dd6b124213c0b4a98118a1c2b74 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sat, 11 Nov 2023 15:50:37 +0900 Subject: [PATCH 24/37] @2023-11-11 15:50+9:00 --- ohkami/src/lib.rs | 11 +++--- ohkami/src/x_websocket/context.rs | 2 +- ohkami/src/x_websocket/mod.rs | 34 +++++-------------- ohkami/src/x_websocket/websocket.rs | 52 +++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 32 deletions(-) create mode 100644 ohkami/src/x_websocket/websocket.rs diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 1d9e15d5..ba079c18 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -210,7 +210,7 @@ "); -/*===== runtime dependency injection layer =====*/ +/*===== async runtime dependency layer =====*/ mod __rt__ { #[cfg(all(feature="rt_tokio", feature="DEBUG"))] @@ -225,11 +225,6 @@ mod __rt__ { #[cfg(feature="rt_async-std")] pub(crate) use async_std::sync::Mutex; - #[cfg(feature="rt_tokio")] - pub(crate) use tokio::sync::RwLock; - #[cfg(feature="rt_async-std")] - pub(crate) use async_std::sync::RwLock; - #[cfg(feature="rt_tokio")] pub(crate) use tokio::net::TcpListener; #[cfg(feature="rt_async-std")] @@ -240,6 +235,10 @@ mod __rt__ { #[cfg(feature="rt_async-std")] pub(crate) use async_std::net::TcpStream; + #[cfg(all(feature="rt_tokio", feature="websocket"))] + pub(crate) use tokio::net::tcp::{ReadHalf, WriteHalf}; + /* async-std doesn't have `split` */ + #[cfg(feature="rt_tokio")] pub(crate) use tokio::task; #[cfg(feature="rt_async-std")] diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index 7235e47a..85349606 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,4 +1,4 @@ -use std::{future::Future, borrow::Cow, sync::Arc}; +use std::{future::Future, borrow::Cow}; use super::{WebSocket, sign, assume_upgraded}; use crate::{Response, Context, Request}; use crate::__rt__::{task, TcpStream}; diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index c40af8d0..df11a0e7 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -1,34 +1,18 @@ #[cfg(not(target_pointer_width = "64"))] compile_error!{ "pointer width must be 64" } +mod websocket; mod context; mod message; mod upgrade; mod frame; mod sign; -use std::{io::{Error, ErrorKind}, sync::{Arc, OnceLock, atomic::AtomicUsize}, collections::HashMap}; -use crate::__rt__::{TcpStream, RwLock}; -use self::message::Message; - -pub(crate) use upgrade::{UpgradeID, request_upgrade_id, reserve_upgrade, assume_upgraded}; - - -pub struct WebSocket { - stream: TcpStream -} - -impl WebSocket { - fn new(stream: TcpStream) -> Self { - Self { stream } - } -} - -impl WebSocket { - pub async fn recv(&mut self) -> Result, Error> { - Message::read_from(&mut self.stream).await - } - pub async fn send(&mut self, message: Message) -> Result<(), Error> { - message.send(&mut self.stream).await - } -} +pub use { + message::{Message}, + websocket::{WebSocket}, + context::{WebSocketContext, UpgradeError}, +}; +pub(crate) use { + upgrade::{UpgradeID, request_upgrade_id, reserve_upgrade, assume_upgraded}, +}; diff --git a/ohkami/src/x_websocket/websocket.rs b/ohkami/src/x_websocket/websocket.rs new file mode 100644 index 00000000..5fde6062 --- /dev/null +++ b/ohkami/src/x_websocket/websocket.rs @@ -0,0 +1,52 @@ +use std::io::Error; +use super::Message; +use crate::__rt__::{TcpStream}; + + +pub struct WebSocket { + stream: TcpStream +} + +impl WebSocket { + pub(crate) fn new(stream: TcpStream) -> Self { + Self { stream } + } +} + +impl WebSocket { + pub async fn recv(&mut self) -> Result, Error> { + Message::read_from(&mut self.stream).await + } + pub async fn send(&mut self, message: Message) -> Result<(), Error> { + message.send(&mut self.stream).await + } +} + +#[cfg(feature="rt_tokio")] const _: () = { + impl WebSocket { + pub fn split(&mut self) -> (ReadHalf, WriteHalf) { + let (rh, wh) = self.stream.split(); + (ReadHalf(rh), WriteHalf(wh)) + } + } + + + use crate::__rt__::{ + ReadHalf as TcpReadHalf, + WriteHalf as TcpWriteHalf, + }; + + pub struct ReadHalf<'ws>(TcpReadHalf<'ws>); + impl<'ws> ReadHalf<'ws> { + pub async fn recv(&mut self) -> Result, Error> { + Message::read_from(&mut self.0).await + } + } + + pub struct WriteHalf<'ws>(TcpWriteHalf<'ws>); + impl<'ws> WriteHalf<'ws> { + pub async fn send(&mut self, message: Message) -> Result<(), Error> { + message.send(&mut self.0).await + } + } +}; From 4446f517bcfdc272912d87288c2fbb220878fe9e Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sat, 11 Nov 2023 16:34:07 +0900 Subject: [PATCH 25/37] @2023-11-11 16:34+9:00 --- .../handler/into_handler.rs | 20 +++++++++ ohkami/src/x_websocket/context.rs | 41 +++++++++++-------- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/ohkami/src/layer3_fang_handler/handler/into_handler.rs b/ohkami/src/layer3_fang_handler/handler/into_handler.rs index df8ee44a..d93b055d 100644 --- a/ohkami/src/layer3_fang_handler/handler/into_handler.rs +++ b/ohkami/src/layer3_fang_handler/handler/into_handler.rs @@ -5,6 +5,8 @@ use crate::{ Response, layer1_req_res::{FromRequest, FromBuffer as PathParam}, }; +#[cfg(feature="websocket")] +use crate::websocket::WebSocketContext; pub trait IntoHandler { @@ -543,3 +545,21 @@ const _: (/* two PathParams and FromRequest items */) = { } } }; + +#[cfg(feature="websocket")] +const _: (/* requires upgrade to websocket */) = { + impl IntoHandler<(WebSocketContext,)> for F + where + F: Fn(WebSocketContext) -> Fut + Send + Sync + 'static, + Fut: Future + Send + Sync + 'static, + { + fn into_handler(self) -> Handler { + Handler::new(true, move |req, c, _| { + match WebSocketContext::new(c, req) { + Ok(wsc) => Box::pin(self(wsc)), + Err(res) => Box::pin(async {res}), + } + }) + } + } +}; diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index 85349606..73bc3f26 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -5,13 +5,12 @@ use crate::__rt__::{task, TcpStream}; use crate::http::{Method}; -pub struct WebSocketContext { +pub struct WebSocketContext { c: Context, - stream: TcpStream, config: Config, - on_failed_upgrade: Box, + on_failed_upgrade: UFH, sec_websocket_key: Cow<'static, str>, selected_protocol: Option>, @@ -38,36 +37,42 @@ pub struct Config { } }; -pub enum UpgradeError { /* TODO */ } +pub trait UpgradeFailureHandler { + fn handle(self, error: UpgradeError); +} +pub enum UpgradeError { + NotRequestedUpgrade, +} +pub struct DefaultUpgradeFailureHandler; +impl UpgradeFailureHandler for DefaultUpgradeFailureHandler { + fn handle(self, _: UpgradeError) {/* discard error */} +} impl WebSocketContext { - pub(crate) async fn new(c: Context, req: &mut Request) -> Result> { - let id = c.upgrade_id.ok_or(Cow::Borrowed("Failed to upgrade"))?; - let stream = assume_upgraded(id).await; - + pub(crate) fn new(c: Context, req: &mut Request) -> Result { if req.method() != Method::GET { - return Err(Cow::Borrowed("Method is not `GET`")) + return Err((|| c.BadRequest().text("Method is not `GET`"))()) } if req.header("Connection") != Some("upgrade") { - return Err(Cow::Borrowed("Connection header is not `upgrade`")) + return Err((|| c.BadRequest().text("Connection header is not `upgrade`"))()) } if req.header("Upgrade") != Some("websocket") { - return Err(Cow::Borrowed("Upgrade header is not `websocket`")) + return Err((|| c.BadRequest().text("Upgrade header is not `websocket`"))()) } if req.header("Sec-WebSocket-Version") != Some("13") { - return Err(Cow::Borrowed("Sec-WebSocket-Version header is not `13`")) + return Err((|| c.BadRequest().text("Sec-WebSocket-Version header is not `13`"))()) } let sec_websocket_key = Cow::Owned(req.header("Sec-WebSocket-Key") - .ok_or(Cow::Borrowed("Sec-WebSocket-Key header is missing"))? + .ok_or_else(|| c.BadRequest().text("Sec-WebSocket-Key header is missing"))? .to_string()); let sec_websocket_protocol = req.header("Sec-WebSocket-Protocol") .map(|swp| Cow::Owned(swp.to_string())); - Ok(Self {c, stream, + Ok(Self {c, config: Config::default(), - on_failed_upgrade: Box::new(|_| (/* discard error */)), + on_failed_upgrade: DefaultUpgradeFailureHandler, selected_protocol: None, sec_websocket_key, sec_websocket_protocol, @@ -122,7 +127,6 @@ impl WebSocketContext { let Self { mut c, - stream, config, on_failed_upgrade, selected_protocol, @@ -132,6 +136,11 @@ impl WebSocketContext { task::spawn({ async move { + let stream = match c.upgrade_id { + None => return on_failed_upgrade.handle(UpgradeError::NotRequestedUpgrade), + Some(id) => assume_upgraded(id).await, + }; + let ws = WebSocket::new(stream); handler(ws).await } From b5e8040f8b1fa544b1ae2e09cfe143d15186ee70 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sat, 11 Nov 2023 21:03:41 +0900 Subject: [PATCH 26/37] @2023-11-11 21:03+9:00 --- .../handler/into_handler.rs | 495 +++++------------- ohkami/src/layer4_router/radix.rs | 4 +- ohkami/src/layer5_ohkami/howl.rs | 10 + ohkami/src/layer6_testing/mod.rs | 51 -- ohkami/src/x_websocket/context.rs | 2 +- ohkami/src/x_websocket/upgrade.rs | 6 - 6 files changed, 152 insertions(+), 416 deletions(-) diff --git a/ohkami/src/layer3_fang_handler/handler/into_handler.rs b/ohkami/src/layer3_fang_handler/handler/into_handler.rs index d93b055d..629fa380 100644 --- a/ohkami/src/layer3_fang_handler/handler/into_handler.rs +++ b/ohkami/src/layer3_fang_handler/handler/into_handler.rs @@ -13,6 +13,16 @@ pub trait IntoHandler { fn into_handler(self) -> Handler; } +#[cold] fn __bad_request( + c: &Context, + e: std::borrow::Cow<'static, str>, +) -> std::pin::Pin>> { + Box::pin({ + let res = c.BadRequest().text(e.to_string()); + async {res} + }) +} + const _: (/* only Context */) = { impl IntoHandler<(Context,)> for F where @@ -21,10 +31,7 @@ const _: (/* only Context */) = { { fn into_handler(self) -> Handler { Handler::new(false, move |_, c, _| - Box::pin({ - let res = self(c); - async {res.await} - }) + Box::pin(self(c)) ) } } @@ -41,14 +48,8 @@ const _: (/* PathParam */) = { fn into_handler(self) -> Handler { Handler::new(false, move |_, c, params| match <$param_type as PathParam>::parse(unsafe {params.assume_init_first().as_bytes()}) { - Ok(p1) => Box::pin({ - let res = self(c, p1); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), + Ok(p1) => Box::pin(self(c, p1)), + Err(e) => __bad_request(&c, e) } ) } @@ -68,14 +69,8 @@ const _: (/* PathParam */) = { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code match ::parse(unsafe {params.assume_init_first().as_bytes()}) { - Ok(p1) => Box::pin({ - let res = self(c, (p1,)); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), + Ok(p1) => Box::pin(self(c, (p1,))), + Err(e) => __bad_request(&c, e) } ) } @@ -88,24 +83,11 @@ const _: (/* PathParam */) = { { fn into_handler(self) -> Handler { Handler::new(false, move |_, c, params| { - let (p1_range, p2_range) = params.assume_init_extract(); - // SAFETY: Due to the architecture of `Router`, - // `params` has already `append`ed twice before this code - match ::parse(unsafe {p1_range.as_bytes()}) { - Ok(p1) => match ::parse(unsafe {p2_range.as_bytes()}) { - Ok(p2) => Box::pin({ - let res = self(c, (p1, p2)); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), + let (p1, p2) = params.assume_init_extract(); + let (p1, p2) = unsafe {(p1.as_bytes(), p2.as_bytes())}; + match (::parse(p1), ::parse(p2)) { + (Ok(p1), Ok(p2)) => Box::pin(self(c, (p1, p2))), + (Err(e), _) | (_, Err(e)) => __bad_request(&c, e), } }) } @@ -120,15 +102,9 @@ const _: (/* FromRequest items */) = { { fn into_handler(self) -> Handler { Handler::new(false, move |req, c, _| - match Item1::parse(&req) { - Ok(item1) => Box::pin({ - let res = self(c, item1); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) + match Item1::parse(req) { + Ok(item1) => Box::pin(self(c, item1)), + Err(e) => __bad_request(&c, e) } ) } @@ -141,54 +117,9 @@ const _: (/* FromRequest items */) = { { fn into_handler(self) -> Handler { Handler::new(false, move |req, c, _| - match Item1::parse(&req) { - Ok(item1) => match Item2::parse(&req) { - Ok(item2) => Box::pin({ - let res = self(c, item1, item2); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - ) - } - } - - impl IntoHandler<(Context, Item1, Item2, Item3)> for F - where - F: Fn(Context, Item1, Item2, Item3) -> Fut + Send + Sync + 'static, - Fut: Future + Send + Sync + 'static, - { - fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, _| - match Item1::parse(&req) { - Ok(item1) => match Item2::parse(&req) { - Ok(item2) => match Item3::parse(&req) { - Ok(item3) => Box::pin({ - let res = self(c, item1, item2, item3); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) + match (Item1::parse(req), Item2::parse(req)) { + (Ok(item1), Ok(item2)) => Box::pin(self(c, item1, item2)), + (Err(e), _) | (_, Err(e)) => __bad_request(&c, e), } ) } @@ -204,26 +135,16 @@ const _: (/* single PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| + Handler::new(false, move |req, c, params| { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code - match <$param_type as PathParam>::parse(unsafe {params.assume_init_first().as_bytes()}) { - Ok(p1) => match Item1::parse(&req) { - Ok(item1) => Box::pin({ - let res = self(c, p1, item1); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), + let p1 = unsafe {params.assume_init_first().as_bytes()}; + + match (<$param_type as PathParam>::parse(p1), Item1::parse(req)) { + (Ok(p1), Ok(item1)) => Box::pin(self(c, p1, item1)), + (Err(e), _) | (_, Err(e)) => __bad_request(&c, e), } - ) + }) } } @@ -233,73 +154,16 @@ const _: (/* single PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| + Handler::new(false, move |req, c, params| { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code - match <$param_type as PathParam>::parse(unsafe {params.assume_init_first().as_bytes()}) { - Ok(p1) => match Item1::parse(&req) { - Ok(item1) => match Item2::parse(&req) { - Ok(item2) => Box::pin({ - let res = self(c, p1, item1, item2); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), - } - ) - } - } + let p1 = unsafe {params.assume_init_first().as_bytes()}; - impl IntoHandler<(Context, $param_type, Item1, Item2, Item3)> for F - where - F: Fn(Context, $param_type, Item1, Item2, Item3) -> Fut + Send + Sync + 'static, - Fut: Future + Send + Sync + 'static, - { - fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| - // SAFETY: Due to the architecture of `Router`, - // `params` has already `append`ed once before this code - match <$param_type as PathParam>::parse(unsafe {params.assume_init_first().as_bytes()}) { - Ok(p1) => match Item1::parse(&req) { - Ok(item1) => match Item2::parse(&req) { - Ok(item2) => match Item3::parse(&req) { - Ok(item3) => Box::pin({ - let res = self(c, p1, item1, item2, item3); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), + match (<$param_type as PathParam>::parse(p1), Item1::parse(req), Item2::parse(req)) { + (Ok(p1), Ok(item1), Ok(item2)) => Box::pin(self(c, p1, item1, item2)), + (Err(e),_,_)|(_,Err(e),_)|(_,_,Err(e)) => __bad_request(&c, e), } - ) + }) } } )*}; @@ -315,26 +179,16 @@ const _: (/* one PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| + Handler::new(false, move |req, c, params| { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code - match P1::parse(unsafe {params.assume_init_first().as_bytes()}) { - Ok(p1) => match Item1::parse(&req) { - Ok(item1) => Box::pin({ - let res = self(c, (p1,), item1); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), + let p1 = unsafe {params.assume_init_first().as_bytes()}; + + match (P1::parse(p1), Item1::parse(req)) { + (Ok(p1), Ok(item1)) => Box::pin(self(c, (p1,), item1)), + (Err(e),_)|(_,Err(e)) => __bad_request(&c, e) } - ) + }) } } @@ -344,73 +198,16 @@ const _: (/* one PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| + Handler::new(false, move |req, c, params| { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code - match P1::parse(unsafe {params.assume_init_first().as_bytes()}) { - Ok(p1) => match Item1::parse(&req) { - Ok(item1) => match Item2::parse(&req) { - Ok(item2) => Box::pin({ - let res = self(c, (p1,), item1, item2); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), - } - ) - } - } + let p1 = unsafe {params.assume_init_first().as_bytes()}; - impl IntoHandler<(Context, (P1,), Item1, Item2, Item3)> for F - where - F: Fn(Context, (P1,), Item1, Item2, Item3) -> Fut + Send + Sync + 'static, - Fut: Future + Send + Sync + 'static, - { - fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| - // SAFETY: Due to the architecture of `Router`, - // `params` has already `append`ed once before this code - match P1::parse(unsafe {params.assume_init_first().as_bytes()}) { - Ok(p1) => match Item1::parse(&req) { - Ok(item1) => match Item2::parse(&req) { - Ok(item2) => match Item3::parse(&req) { - Ok(item3) => Box::pin({ - let res = self(c, (p1,), item1, item2, item3); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), + match (P1::parse(p1), Item1::parse(req), Item2::parse(req)) { + (Ok(p1), Ok(item1), Ok(item2)) => Box::pin(self(c, (p1,), item1, item2)), + (Err(e),_,_)|(_,Err(e),_)|(_,_,Err(e)) => __bad_request(&c, e), } - ) + }) } } }; @@ -423,31 +220,14 @@ const _: (/* two PathParams and FromRequest items */) = { { fn into_handler(self) -> Handler { Handler::new(false, move |req, c, params| { - let (p1_range, p2_range) = params.assume_init_extract(); - // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed twice before this code - match ::parse(unsafe {p1_range.as_bytes()}) { - Ok(p1) => match ::parse(unsafe {p2_range.as_bytes()}) { - Ok(p2) => match Item1::parse(&req) { - Ok(item1) => Box::pin({ - let res = self(c, (p1, p2), item1); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), + let (p1, p2) = params.assume_init_extract(); + let (p1, p2) = unsafe {(p1.as_bytes(), p2.as_bytes())}; + + match (P1::parse(p1), P2::parse(p2), Item1::parse(req)) { + (Ok(p1), Ok(p2), Ok(item1)) => Box::pin(self(c, (p1, p2), item1)), + (Err(e),_,_)|(_,Err(e),_)|(_,_,Err(e)) => __bad_request(&c, e), } }) } @@ -460,86 +240,14 @@ const _: (/* two PathParams and FromRequest items */) = { { fn into_handler(self) -> Handler { Handler::new(false, move |req, c, params| { - let (p1_range, p2_range) = params.assume_init_extract(); - // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed twice before this code - match ::parse(unsafe {p1_range.as_bytes()}) { - Ok(p1) => match ::parse(unsafe {p2_range.as_bytes()}) { - Ok(p2) => match Item1::parse(&req) { - Ok(item1) => match Item2::parse(&req) { - Ok(item2) => Box::pin({ - let res = self(c, (p1, p2), item1, item2); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), - } - }) - } - } - - impl IntoHandler<(Context, (P1, P2), Item1, Item2, Item3)> for F - where - F: Fn(Context, (P1, P2), Item1, Item2, Item3) -> Fut + Send + Sync + 'static, - Fut: Future + Send + Sync + 'static, - { - fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| { - let (p1_range, p2_range) = params.assume_init_extract(); + let (p1, p2) = params.assume_init_extract(); + let (p1, p2) = unsafe {(p1.as_bytes(), p2.as_bytes())}; - // SAFETY: Due to the architecture of `Router`, - // `params` has already `append`ed twice before this code - match ::parse(unsafe {p1_range.as_bytes()}) { - Ok(p1) => match ::parse(unsafe {p2_range.as_bytes()}) { - Ok(p2) => match Item1::parse(&req) { - Ok(item1) => match Item2::parse(&req) { - Ok(item2) => match Item3::parse(&req) { - Ok(item3) => Box::pin({ - let res = self(c, (p1, p2), item1, item2, item3); - async {res.await} - }), - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }) - } - Err(e) => Box::pin({ - let res = c.BadRequest().text(e.to_string()); - async {res} - }), + match (P1::parse(p1), P2::parse(p2), Item1::parse(req), Item2::parse(req)) { + (Ok(p1), Ok(p2), Ok(item1), Ok(item2)) => Box::pin(self(c, (p1, p2), item1, item2)), + (Err(e),_,_,_)|(_,Err(e),_,_)|(_,_,Err(e),_)|(_,_,_,Err(e)) => __bad_request(&c, e), } }) } @@ -557,7 +265,82 @@ const _: (/* requires upgrade to websocket */) = { Handler::new(true, move |req, c, _| { match WebSocketContext::new(c, req) { Ok(wsc) => Box::pin(self(wsc)), - Err(res) => Box::pin(async {res}), + Err(res) => (|| Box::pin(async {res}))(), + } + }) + } + } + + impl IntoHandler<(WebSocketContext, P1)> for F + where + F: Fn(WebSocketContext, P1) -> Fut + Send + Sync + 'static, + Fut: Future + Send + Sync + 'static, + { + fn into_handler(self) -> Handler { + Handler::new(true, move |req, c, params| { + let p1 = unsafe {params.assume_init_first().as_bytes()}; + match P1::parse(p1) { + Ok(p1) => match WebSocketContext::new(c, req) { + Ok(wsc) => Box::pin(self(wsc, p1)), + Err(res) => (|| Box::pin(async {res}))(), + } + Err(e) => __bad_request(&c, e), + } + }) + } + } + impl IntoHandler<(WebSocketContext, P1, P2)> for F + where + F: Fn(WebSocketContext, P1, P2) -> Fut + Send + Sync + 'static, + Fut: Future + Send + Sync + 'static, + { + fn into_handler(self) -> Handler { + Handler::new(true, move |req, c, params| { + let (p1, p2) = params.assume_init_extract(); + let (p1, p2) = unsafe {(p1.as_bytes(), p2.as_bytes())}; + match (P1::parse(p1), P2::parse(p2)) { + (Ok(p1), Ok(p2)) => match WebSocketContext::new(c, req) { + Ok(wsc) => Box::pin(self(wsc, p1, p2)), + Err(res) => (|| Box::pin(async {res}))(), + } + (Err(e),_)|(_,Err(e)) => __bad_request(&c, e), + } + }) + } + } + impl IntoHandler<(WebSocketContext, (P1,))> for F + where + F: Fn(WebSocketContext, (P1,)) -> Fut + Send + Sync + 'static, + Fut: Future + Send + Sync + 'static, + { + fn into_handler(self) -> Handler { + Handler::new(true, move |req, c, params| { + let p1 = unsafe {params.assume_init_first().as_bytes()}; + match P1::parse(p1) { + Ok(p1) => match WebSocketContext::new(c, req) { + Ok(wsc) => Box::pin(self(wsc, (p1,))), + Err(res) => (|| Box::pin(async {res}))(), + } + Err(e) => __bad_request(&c, e), + } + }) + } + } + impl IntoHandler<(WebSocketContext, (P1, P2))> for F + where + F: Fn(WebSocketContext, (P1, P2)) -> Fut + Send + Sync + 'static, + Fut: Future + Send + Sync + 'static, + { + fn into_handler(self) -> Handler { + Handler::new(true, move |req, c, params| { + let (p1, p2) = params.assume_init_extract(); + let (p1, p2) = unsafe {(p1.as_bytes(), p2.as_bytes())}; + match (P1::parse(p1), P2::parse(p2)) { + (Ok(p1), Ok(p2)) => match WebSocketContext::new(c, req) { + Ok(wsc) => Box::pin(self(wsc, (p1, p2))), + Err(res) => (|| Box::pin(async {res}))(), + } + (Err(e),_)|(_,Err(e)) => __bad_request(&c, e), } }) } diff --git a/ohkami/src/layer4_router/radix.rs b/ohkami/src/layer4_router/radix.rs index c46909f2..3084042f 100644 --- a/ohkami/src/layer4_router/radix.rs +++ b/ohkami/src/layer4_router/radix.rs @@ -150,8 +150,8 @@ impl RadixRouter { Err(err_res) => return (err_res, None), }; - #[cfg(not(test))] {target.handle(c, req, params).await} - #[cfg(test)] {(target.handle_discarding_upgrade(c, req, params).await, None)} + #[cfg(feature="websocket")] {target.handle(c, req, params).await} + #[cfg(not(feature="websocket"))] {(target.handle_discarding_upgrade(c, req, params).await, None)} } } diff --git a/ohkami/src/layer5_ohkami/howl.rs b/ohkami/src/layer5_ohkami/howl.rs index d088a7be..ab84a13d 100644 --- a/ohkami/src/layer5_ohkami/howl.rs +++ b/ohkami/src/layer5_ohkami/howl.rs @@ -76,17 +76,27 @@ impl Ohkami { let mut req = unsafe {Pin::new_unchecked(&mut req)}; req.as_mut().read(stream).await; + #[cfg(not(feature="websocket"))] + let res = router.handle_discarding_upgrade(Context::new(), req.get_mut()).await; + #[cfg(feature="websocket")] let (res, upgrade_id) = router.handle(Context::new(), req.get_mut()).await; + res.send(stream).await; + #[cfg(feature="websocket")] upgrade_id } }).await { + #[cfg(not(feature="websocket"))] + Ok(_) => (), + + #[cfg(feature="websocket")] Ok(upgrade_id) => { if let Some(id) = upgrade_id { reserve_upgrade(id, stream).await } } + Err(e) => (|| async { println!("Fatal error: {e}"); let res = Context::new().InternalServerError(); diff --git a/ohkami/src/layer6_testing/mod.rs b/ohkami/src/layer6_testing/mod.rs index 517da100..6059849d 100644 --- a/ohkami/src/layer6_testing/mod.rs +++ b/ohkami/src/layer6_testing/mod.rs @@ -8,9 +8,6 @@ use byte_reader::Reader; use crate::{Response, Request, Ohkami, Context}; use crate::layer0_lib::{IntoCows, Status, Method, ContentType}; -#[cfg(feature="websocket")] -use {std::sync::Arc, crate::__rt__::Mutex}; - pub trait Testing { fn oneshot(&self, req: TestRequest) -> TestFuture; @@ -24,53 +21,6 @@ impl Future for TestFuture { } } -#[cfg(feature="websocket")] -pub struct TestWebSocket(Vec); -#[cfg(feature="websocket")] const _: () = { - impl TestWebSocket { - fn new(size: usize) -> Self { - Self(Vec::with_capacity(size)) - } - } - - #[cfg(feature="rt_tokio")] const _: () = { - impl tokio::io::AsyncRead for TestWebSocket { - fn poll_read( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - let mut pin_inner = unsafe {self.map_unchecked_mut(|this| &mut this.0)}; - - let amt = std::cmp::min(pin_inner.len(), buf.remaining()); - let (a, b) = pin_inner.split_at(amt); - buf.put_slice(a); - *pin_inner = b.to_vec(); - - std::task::Poll::Ready(Ok(())) - } - } - impl tokio::io::AsyncWrite for TestWebSocket { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - unsafe {self.map_unchecked_mut(|this| &mut this.0)} - .poll_write(cx, buf) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - unsafe {self.map_unchecked_mut(|this| &mut this.0)} - .poll_flush(cx) - } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - unsafe {self.map_unchecked_mut(|this| &mut this.0)} - .poll_shutdown(cx) - } - } - }; -}; - impl Testing for Ohkami { fn oneshot(&self, request: TestRequest) -> TestFuture { let router = { @@ -94,7 +44,6 @@ impl Testing for Ohkami { } } - pub struct TestRequest { method: Method, path: Cow<'static, str>, diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index 73bc3f26..590fec97 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,7 +1,7 @@ use std::{future::Future, borrow::Cow}; use super::{WebSocket, sign, assume_upgraded}; use crate::{Response, Context, Request}; -use crate::__rt__::{task, TcpStream}; +use crate::__rt__::{task}; use crate::http::{Method}; diff --git a/ohkami/src/x_websocket/upgrade.rs b/ohkami/src/x_websocket/upgrade.rs index 17ab56af..06ba8e2f 100644 --- a/ohkami/src/x_websocket/upgrade.rs +++ b/ohkami/src/x_websocket/upgrade.rs @@ -66,12 +66,6 @@ struct UpgradeStream { stream: Option>>, } const _: () = { impl UpgradeStream { - fn new(arc_stream: Arc>) -> Self { - Self { - reserved: false, - stream: Some(arc_stream), - } - } fn is_empty(&self) -> bool { self.stream.is_none() && !self.reserved } From bf09249c00ebeef2a2cac4a4bdce38128be89cde Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sun, 12 Nov 2023 02:14:32 +0900 Subject: [PATCH 27/37] @2023-11-12 02:14+9:00 --- ohkami/src/x_websocket/context.rs | 27 +------ ohkami/src/x_websocket/frame.rs | 17 +++-- ohkami/src/x_websocket/message.rs | 20 +++-- ohkami/src/x_websocket/websocket.rs | 111 ++++++++++++++++++++++++---- 4 files changed, 126 insertions(+), 49 deletions(-) diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index 590fec97..48b0d731 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,4 +1,5 @@ use std::{future::Future, borrow::Cow}; +use super::websocket::Config; use super::{WebSocket, sign, assume_upgraded}; use crate::{Response, Context, Request}; use crate::__rt__::{task}; @@ -17,26 +18,6 @@ pub struct WebSocketContext>, } -pub struct Config { - write_buffer_size: usize, - max_write_buffer_size: usize, - max_message_size: Option, - max_frame_size: Option, - accept_unmasked_frames: bool, -} const _: () = { - impl Default for Config { - fn default() -> Self { - Self { - write_buffer_size: 128 * 1024, // 128 KiB - max_write_buffer_size: usize::MAX, - max_message_size: Some(64 << 20), - max_frame_size: Some(16 << 20), - accept_unmasked_frames: false, - } - } - } -}; - pub trait UpgradeFailureHandler { fn handle(self, error: UpgradeError); } @@ -131,17 +112,17 @@ impl WebSocketContext { on_failed_upgrade, selected_protocol, sec_websocket_key, - sec_websocket_protocol, + .. } = self; task::spawn({ async move { let stream = match c.upgrade_id { - None => return on_failed_upgrade.handle(UpgradeError::NotRequestedUpgrade), Some(id) => assume_upgraded(id).await, + None => return on_failed_upgrade.handle(UpgradeError::NotRequestedUpgrade), }; - let ws = WebSocket::new(stream); + let ws = WebSocket::new(stream, config); handler(ws).await } }); diff --git a/ohkami/src/x_websocket/frame.rs b/ohkami/src/x_websocket/frame.rs index 187fd305..65511fae 100644 --- a/ohkami/src/x_websocket/frame.rs +++ b/ohkami/src/x_websocket/frame.rs @@ -1,5 +1,6 @@ use std::io::{Error, ErrorKind}; use crate::__rt__::{AsyncReader, AsyncWriter}; +use super::websocket::Config; #[derive(PartialEq)] @@ -53,7 +54,10 @@ pub struct Frame { pub mask: Option<[u8; 4]>, pub payload: Vec, } impl Frame { - pub(super) async fn read_from(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { + pub(super) async fn read_from( + stream: &mut (impl AsyncReader + Unpin), + config: &Config, + ) -> Result, Error> { let [first, second] = { let mut head = [0; 2]; stream.read_exact(&mut head).await?; @@ -65,7 +69,7 @@ pub struct Frame { let payload_len = { let payload_len_byte = second & 0x7F; - let len_part_size = match payload_len_byte {126=>2, 127=>8, _=>0}; + let len_part_size = match payload_len_byte {127=>8, 126=>2, _=>0}; match len_part_size { 0 => payload_len_byte as usize, _ => { @@ -101,7 +105,10 @@ pub struct Frame { Ok(Some(Self { is_final, opcode, mask, payload })) } - pub(super) async fn write_to(self, stream: &mut (impl AsyncWriter + Unpin)) -> Result<(), Error> { + pub(super) async fn write_to(self, + stream: &mut (impl AsyncWriter + Unpin), + config: &Config, + ) -> Result { fn into_bytes(frame: Frame) -> Vec { let Frame { is_final, opcode, mask, payload } = frame; @@ -119,12 +126,12 @@ pub struct Frame { header_bytes.append(&mut payload_len_bytes) } if let Some(mask_bytes) = mask { - header_bytes.extend_from_slice(&mask_bytes) + header_bytes.extend(mask_bytes) } [header_bytes, payload].concat() } - stream.write_all(&into_bytes(self)).await + stream.write(&into_bytes(self)).await } } diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs index 287fb033..1d15063c 100644 --- a/ohkami/src/x_websocket/message.rs +++ b/ohkami/src/x_websocket/message.rs @@ -1,6 +1,6 @@ use std::{borrow::Cow, io::{Error, ErrorKind}}; use crate::{__rt__::{AsyncReader, AsyncWriter}}; -use super::frame::{Frame, OpCode}; +use super::{frame::{Frame, OpCode}, websocket::Config}; pub enum Message { @@ -43,7 +43,10 @@ const _: (/* `From` impls */) = { }; impl Message { - pub(super) async fn send(self, stream: &mut (impl AsyncWriter + Unpin)) -> Result<(), Error> { + pub(super) async fn write(self, + stream: &mut (impl AsyncWriter + Unpin), + config: &Config, + ) -> Result { fn into_frame(message: Message) -> Frame { let (opcode, payload) = match message { Message::Text (text) => (OpCode::Text, text.into_bytes()), @@ -61,13 +64,16 @@ impl Message { Frame { is_final: false, mask: None, opcode, payload } } - into_frame(self).write_to(stream).await + into_frame(self).write_to(stream, config).await } } impl Message { - pub(super) async fn read_from(stream: &mut (impl AsyncReader + Unpin)) -> Result, Error> { - let head_frame = match Frame::read_from(stream).await? { + pub(super) async fn read_from( + stream: &mut (impl AsyncReader + Unpin), + config: &Config, + ) -> Result, Error> { + let head_frame = match Frame::read_from(stream, config).await? { Some(frame) => frame, None => return Ok(None), }; @@ -77,7 +83,7 @@ impl Message { let mut payload = String::from_utf8(head_frame.payload) .map_err(|_| Error::new(ErrorKind::InvalidData, "Text frame's payload is not valid UTF-8"))?; if !head_frame.is_final { - while let Ok(Some(next_frame)) = Frame::read_from(stream).await { + while let Ok(Some(next_frame)) = Frame::read_from(stream, config).await { if next_frame.opcode != OpCode::Continue { return Err(Error::new(ErrorKind::InvalidData, "Expected continue frame")); } @@ -94,7 +100,7 @@ impl Message { OpCode::Binary => { let mut payload = head_frame.payload; if !head_frame.is_final { - while let Ok(Some(mut next_frame)) = Frame::read_from(stream).await { + while let Ok(Some(mut next_frame)) = Frame::read_from(stream, config).await { if next_frame.opcode != OpCode::Continue { return Err(Error::new(ErrorKind::InvalidData, "Expected continue frame")); } diff --git a/ohkami/src/x_websocket/websocket.rs b/ohkami/src/x_websocket/websocket.rs index 5fde6062..d0fc48f1 100644 --- a/ohkami/src/x_websocket/websocket.rs +++ b/ohkami/src/x_websocket/websocket.rs @@ -1,52 +1,135 @@ use std::io::Error; -use super::Message; -use crate::__rt__::{TcpStream}; +use super::{Message}; +use crate::__rt__::{TcpStream, AsyncWriter}; pub struct WebSocket { - stream: TcpStream + stream: TcpStream, + config: Config, + + n_buffered: usize, } +// :fields may set through `WebSocketContext`'s methods +pub struct Config { + pub(crate) write_buffer_size: usize, + pub(crate) max_write_buffer_size: usize, + pub(crate) max_message_size: Option, + pub(crate) max_frame_size: Option, + pub(crate) accept_unmasked_frames: bool, +} const _: () = { + impl Default for Config { + fn default() -> Self { + Self { + write_buffer_size: 128 * 1024, // 128 KiB + max_write_buffer_size: usize::MAX, + max_message_size: Some(64 << 20), + max_frame_size: Some(16 << 20), + accept_unmasked_frames: false, + } + } + } +}; impl WebSocket { - pub(crate) fn new(stream: TcpStream) -> Self { - Self { stream } + pub(crate) fn new(stream: TcpStream, config: Config) -> Self { + Self { stream, config, n_buffered:0 } } } impl WebSocket { pub async fn recv(&mut self) -> Result, Error> { - Message::read_from(&mut self.stream).await + Message::read_from(&mut self.stream, &self.config).await } +} + +// ============================================================================= +async fn send(message:Message, + stream: &mut (impl AsyncWriter + Unpin), + config: &Config, + n_buffered: &mut usize, +) -> Result<(), Error> { + message.write(stream, config).await?; + flush(stream, n_buffered).await?; + Ok(()) +} +async fn write(message:Message, + stream: &mut (impl AsyncWriter + Unpin), + config: &Config, + n_buffered: &mut usize, +) -> Result { + let n = message.write(stream, config).await?; + + *n_buffered += n; + if *n_buffered > config.write_buffer_size { + if *n_buffered > config.max_write_buffer_size { + panic!("Buffered messages is larger than `max_write_buffer_size`"); + } else { + flush(stream, n_buffered).await? + } + } + + Ok(n) +} +async fn flush( + stream: &mut (impl AsyncWriter + Unpin), + n_buffered: &mut usize, +) -> Result<(), Error> { + stream.flush().await + .map(|_| *n_buffered = 0) +} +// ============================================================================= + +impl WebSocket { pub async fn send(&mut self, message: Message) -> Result<(), Error> { - message.send(&mut self.stream).await + send(message, &mut self.stream, &self.config, &mut self.n_buffered).await + } + pub async fn write(&mut self, message: Message) -> Result { + write(message, &mut self.stream, &self.config, &mut self.n_buffered).await + } + pub async fn flush(&mut self) -> Result<(), Error> { + flush(&mut self.stream, &mut self.n_buffered).await } } #[cfg(feature="rt_tokio")] const _: () = { impl WebSocket { pub fn split(&mut self) -> (ReadHalf, WriteHalf) { - let (rh, wh) = self.stream.split(); - (ReadHalf(rh), WriteHalf(wh)) + let (rh, wh) = self.stream.split(); + let config = &self.config; + let n_buffered = self.n_buffered; + (ReadHalf {config, stream:rh}, WriteHalf {config, n_buffered, stream:wh}) } } - use crate::__rt__::{ ReadHalf as TcpReadHalf, WriteHalf as TcpWriteHalf, }; - pub struct ReadHalf<'ws>(TcpReadHalf<'ws>); + pub struct ReadHalf<'ws> { + stream: TcpReadHalf<'ws>, + config: &'ws Config, + } impl<'ws> ReadHalf<'ws> { pub async fn recv(&mut self) -> Result, Error> { - Message::read_from(&mut self.0).await + Message::read_from(&mut self.stream, &self.config).await } } - pub struct WriteHalf<'ws>(TcpWriteHalf<'ws>); + pub struct WriteHalf<'ws> { + stream: TcpWriteHalf<'ws>, + config: &'ws Config, + n_buffered: usize, + } impl<'ws> WriteHalf<'ws> { pub async fn send(&mut self, message: Message) -> Result<(), Error> { - message.send(&mut self.0).await + send(message, &mut self.stream, &self.config, &mut self.n_buffered).await + } + pub async fn write(&mut self, message: Message) -> Result { + write(message, &mut self.stream, &self.config, &mut self.n_buffered).await + } + pub async fn flush(&mut self) -> Result<(), Error> { + flush(&mut self.stream, &mut self.n_buffered).await } } }; From 5ef55e49624b5e34a8c71ba61eebf155ef377b65 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Tue, 14 Nov 2023 23:54:56 +0900 Subject: [PATCH 28/37] @2023-11-14 23:54+9:00 --- ohkami/src/x_websocket/frame.rs | 49 ++++++++++---- ohkami/src/x_websocket/message.rs | 106 ++++++++++++++++++++++-------- 2 files changed, 114 insertions(+), 41 deletions(-) diff --git a/ohkami/src/x_websocket/frame.rs b/ohkami/src/x_websocket/frame.rs index 65511fae..9f82eac3 100644 --- a/ohkami/src/x_websocket/frame.rs +++ b/ohkami/src/x_websocket/frame.rs @@ -36,16 +36,24 @@ pub enum OpCode { pub enum CloseCode { Normal, Away, Protocol, Unsupported, Status, Abnormal, Invalid, - Policy, Size, Extension, Error, Restart,Again, Tls, Reserved, + Policy, Size, Extension, Error, Restart, Again, Tls, Reserved, Iana(u16), Library(u16), Bad(u16), -} impl From for CloseCode { - fn from(code: u16) -> Self {match code { - 1000 => Self::Normal, 1001 => Self::Away, 1002 => Self::Protocol, 1003 => Self::Unsupported, - 1005 => Self::Status, 1006 => Self::Abnormal, 1007 => Self::Invalid, 1008 => Self::Policy, - 1009 => Self::Size, 1010 => Self::Extension, 1011 => Self::Error, 1012 => Self::Restart, - 1013 => Self::Again, 1015 => Self::Tls, 1016..=2999 => Self::Reserved, - 3000..=3999 => Self::Iana(code), 4000..=4999 => Self::Library(code), _ => Self::Bad(code), - }} +} impl CloseCode { + pub(super) fn from_bytes(bytes: [u8; 2]) -> Self { + let code = u16::from_be_bytes(bytes); + match code { + 1000 => Self::Normal, 1001 => Self::Away, 1002 => Self::Protocol, 1003 => Self::Unsupported, + 1005 => Self::Status, 1006 => Self::Abnormal, 1007 => Self::Invalid, 1008 => Self::Policy, + 1009 => Self::Size, 1010 => Self::Extension, 1011 => Self::Error, 1012 => Self::Restart, + 1013 => Self::Again, 1015 => Self::Tls, 1016..=2999 => Self::Reserved, + 3000..=3999 => Self::Iana(code), 4000..=4999 => Self::Library(code), _ => Self::Bad(code), + } + } + pub(super) fn into_bytes(self) -> [u8; 2] { + match self { + + } + } } pub struct Frame { @@ -70,7 +78,8 @@ pub struct Frame { let payload_len = { let payload_len_byte = second & 0x7F; let len_part_size = match payload_len_byte {127=>8, 126=>2, _=>0}; - match len_part_size { + + let len = match len_part_size { 0 => payload_len_byte as usize, _ => { let mut bytes = [0; 8]; @@ -82,10 +91,24 @@ pub struct Frame { } usize::from_be_bytes(bytes) } + }; if let Some(limit) = &config.max_frame_size { + (&len <= limit).then_some(()) + .ok_or_else(|| Error::new( + ErrorKind::InvalidData, + "Incoming frame is too large" + ))?; } + + len }; - let mask = if second & 0x80 == 0 {None} else { + let mask = if second & 0x80 == 0 { + (config.accept_unmasked_frames).then_some(None) + .ok_or_else(|| Error::new( + ErrorKind::InvalidData, + "Client frame is unmasked" + ))? + } else { let mut mask_bytes = [0; 4]; if let Err(e) = stream.read_exact(&mut mask_bytes).await { return match e.kind() { @@ -106,8 +129,8 @@ pub struct Frame { } pub(super) async fn write_to(self, - stream: &mut (impl AsyncWriter + Unpin), - config: &Config, + stream: &mut (impl AsyncWriter + Unpin), + _config: &Config, ) -> Result { fn into_bytes(frame: Frame) -> Vec { let Frame { is_final, opcode, mask, payload } = frame; diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs index 1d15063c..0bf3ea85 100644 --- a/ohkami/src/x_websocket/message.rs +++ b/ohkami/src/x_websocket/message.rs @@ -1,21 +1,18 @@ use std::{borrow::Cow, io::{Error, ErrorKind}}; use crate::{__rt__::{AsyncReader, AsyncWriter}}; -use super::{frame::{Frame, OpCode}, websocket::Config}; +use super::{frame::{Frame, OpCode, CloseCode}, websocket::Config}; +const PING_PONG_PAYLOAD_LIMIT: usize = 125; pub enum Message { Text (String), Binary(Vec), - Ping (PingPongFrame), - Pong (PingPongFrame), - Close (CloseFrame), -} -pub struct PingPongFrame { - buf: [u8; 125], - len: usize/* less than 125 */ + Ping (Vec), + Pong (Vec), + Close (Option), } pub struct CloseFrame { - pub code: u16, + pub code: CloseCode, pub reason: Option>, } @@ -49,18 +46,32 @@ impl Message { ) -> Result { fn into_frame(message: Message) -> Frame { let (opcode, payload) = match message { - Message::Text (text) => (OpCode::Text, text.into_bytes()), - Message::Binary(vec) => (OpCode::Binary, vec), - Message::Ping (PingPongFrame { buf, len }) => (OpCode::Ping, buf[..len].to_vec()), - Message::Pong (PingPongFrame { buf, len }) => (OpCode::Pong, buf[..len].to_vec()), - Message::Close (CloseFrame { code, reason }) => { - let mut payload = code.to_be_bytes().to_vec(); - if let Some(reason_text) = reason { - payload.extend_from_slice(reason_text.as_bytes()) - } + Message::Text (text) => (OpCode::Text, text.into_bytes()), + Message::Binary(bytes) => (OpCode::Binary, bytes), + + Message::Ping(mut bytes) => { + bytes.truncate(PING_PONG_PAYLOAD_LIMIT); + (OpCode::Ping, bytes) + } + Message::Pong(mut bytes) => { + bytes.truncate(PING_PONG_PAYLOAD_LIMIT); + (OpCode::Ping, bytes) + } + + Message::Close(close_frame) => { + let payload = close_frame + .map(|CloseFrame { code, reason }| { + let mut bytes = code.to_be_bytes().to_vec(); + if let Some(reason_text) = reason { + bytes.extend_from_slice(reason_text.as_bytes()) + } + bytes + }).unwrap_or(Vec::new()); + (OpCode::Close, payload) } }; + Frame { is_final: false, mask: None, opcode, payload } } @@ -73,16 +84,16 @@ impl Message { stream: &mut (impl AsyncReader + Unpin), config: &Config, ) -> Result, Error> { - let head_frame = match Frame::read_from(stream, config).await? { + let first_frame = match Frame::read_from(stream, config).await? { Some(frame) => frame, None => return Ok(None), }; - match &head_frame.opcode { + match &first_frame.opcode { OpCode::Text => { - let mut payload = String::from_utf8(head_frame.payload) + let mut payload = String::from_utf8(first_frame.payload) .map_err(|_| Error::new(ErrorKind::InvalidData, "Text frame's payload is not valid UTF-8"))?; - if !head_frame.is_final { + if !first_frame.is_final { while let Ok(Some(next_frame)) = Frame::read_from(stream, config).await { if next_frame.opcode != OpCode::Continue { return Err(Error::new(ErrorKind::InvalidData, "Expected continue frame")); @@ -95,11 +106,20 @@ impl Message { } } } + + if let Some(limit) = &config.max_message_size { + (&payload.len() <= limit).then_some(()) + .ok_or_else(|| Error::new( + ErrorKind::InvalidData, + "Incoming message is too large" + ))?; + } + Ok(Some(Message::Text(payload))) } OpCode::Binary => { - let mut payload = head_frame.payload; - if !head_frame.is_final { + let mut payload = first_frame.payload; + if !first_frame.is_final { while let Ok(Some(mut next_frame)) = Frame::read_from(stream, config).await { if next_frame.opcode != OpCode::Continue { return Err(Error::new(ErrorKind::InvalidData, "Expected continue frame")); @@ -112,14 +132,44 @@ impl Message { } } } + + if let Some(limit) = &config.max_message_size { + (&payload.len() <= limit).then_some(()) + .ok_or_else(|| Error::new( + ErrorKind::InvalidData, + "Incoming message is too large" + ))?; + } + Ok(Some(Message::Binary(payload))) } + OpCode::Ping => { - todo!() + let payload = first_frame.payload; + (payload.len() <= PING_PONG_PAYLOAD_LIMIT) + .then_some(Some(Message::Ping(payload))) + .ok_or_else(|| Error::new(ErrorKind::InvalidData, "Incoming ping payload is too large")) } - OpCode::Close => return Ok(None), - OpCode::Pong => return Err(Error::new(ErrorKind::InvalidData, "Unexpected pong frame")), - OpCode::Continue => return Err(Error::new(ErrorKind::InvalidData, "Unexpected continue frame")), + OpCode::Pong => { + let payload = first_frame.payload; + (payload.len() <= PING_PONG_PAYLOAD_LIMIT) + .then_some(Some(Message::Pong(payload))) + .ok_or_else(|| Error::new(ErrorKind::InvalidData, "Incoming pong payload is too large")) + } + + OpCode::Close => { + let payload = first_frame.payload; + Ok(Some(Message::Close( + (! payload.is_empty()).then(|| { + let (code_bytes, rem) = payload.split_at(2); + let code = CloseCode::from_bytes(unsafe {(code_bytes.as_ptr() as *const [u8; 2]).read()}); + + todo!() + }) + ))) + } + + OpCode::Continue => Err(Error::new(ErrorKind::InvalidData, "Unexpected continue frame")) } } } From b832f5a32c4a9d06e76fdca6179edcafe15aee04 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Wed, 15 Nov 2023 15:50:26 +0900 Subject: [PATCH 29/37] @2023-11-15 16:50+9:00 --- ohkami/src/x_websocket/frame.rs | 10 +++++++--- ohkami/src/x_websocket/message.rs | 15 ++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/ohkami/src/x_websocket/frame.rs b/ohkami/src/x_websocket/frame.rs index 9f82eac3..6e024a67 100644 --- a/ohkami/src/x_websocket/frame.rs +++ b/ohkami/src/x_websocket/frame.rs @@ -50,9 +50,13 @@ pub enum CloseCode { } } pub(super) fn into_bytes(self) -> [u8; 2] { - match self { - - } + u16::to_be_bytes(match self { + Self::Normal => 1000, Self::Away => 1001, Self::Protocol => 1002, Self::Unsupported => 1003, + Self::Status => 1005, Self::Abnormal => 1006, Self::Invalid => 1007, Self::Policy => 1008, + Self::Size => 1009, Self::Extension => 1010, Self::Error => 1011, Self::Restart => 1012, + Self::Again => 1013, Self::Tls => 1015, + Self::Reserved => 1016, Self::Iana(code) | Self::Library(code) | Self::Bad(code) => code, + }) } } diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs index 0bf3ea85..cbab9e9c 100644 --- a/ohkami/src/x_websocket/message.rs +++ b/ohkami/src/x_websocket/message.rs @@ -61,13 +61,10 @@ impl Message { Message::Close(close_frame) => { let payload = close_frame .map(|CloseFrame { code, reason }| { - let mut bytes = code.to_be_bytes().to_vec(); - if let Some(reason_text) = reason { - bytes.extend_from_slice(reason_text.as_bytes()) - } - bytes + let code = code.into_bytes(); + let reason = reason.as_ref().map(|cow| cow.as_bytes()).unwrap_or(&[]); + [&code, reason].concat() }).unwrap_or(Vec::new()); - (OpCode::Close, payload) } }; @@ -162,9 +159,9 @@ impl Message { Ok(Some(Message::Close( (! payload.is_empty()).then(|| { let (code_bytes, rem) = payload.split_at(2); - let code = CloseCode::from_bytes(unsafe {(code_bytes.as_ptr() as *const [u8; 2]).read()}); - - todo!() + let code = CloseCode::from_bytes(unsafe {(code_bytes.as_ptr() as *const [u8; 2]).read()}); + let reason = (! rem.is_empty()).then(|| Cow::Owned(String::from_utf8(rem.to_vec()).unwrap())); + CloseFrame { code, reason } }) ))) } From 49b2fd013c0fdbe9439651fe23ff8d4d1eae0423 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sun, 19 Nov 2023 13:12:07 +0900 Subject: [PATCH 30/37] fin impl --- ohkami/src/layer5_ohkami/howl.rs | 4 +- ohkami/src/x_websocket/context.rs | 4 +- ohkami/src/x_websocket/mod.rs | 2 +- ohkami/src/x_websocket/upgrade.rs | 167 +++++++++++++++++++----------- 4 files changed, 113 insertions(+), 64 deletions(-) diff --git a/ohkami/src/layer5_ohkami/howl.rs b/ohkami/src/layer5_ohkami/howl.rs index ab84a13d..fec49001 100644 --- a/ohkami/src/layer5_ohkami/howl.rs +++ b/ohkami/src/layer5_ohkami/howl.rs @@ -1,6 +1,6 @@ use std::{sync::Arc, pin::Pin}; use super::{Ohkami}; -use crate::{__rt__, Request, Context, websocket::reserve_upgrade}; +use crate::{__rt__, Request, Context, websocket::{reserve_upgrade}}; #[cfg(feature="rt_async-std")] use crate::__rt__::StreamExt; @@ -93,7 +93,7 @@ impl Ohkami { #[cfg(feature="websocket")] Ok(upgrade_id) => { if let Some(id) = upgrade_id { - reserve_upgrade(id, stream).await + unsafe{reserve_upgrade(id, stream)} } } diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index 48b0d731..ec9be272 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,6 +1,6 @@ use std::{future::Future, borrow::Cow}; use super::websocket::Config; -use super::{WebSocket, sign, assume_upgraded}; +use super::{WebSocket, sign, assume_upgradable}; use crate::{Response, Context, Request}; use crate::__rt__::{task}; use crate::http::{Method}; @@ -118,7 +118,7 @@ impl WebSocketContext { task::spawn({ async move { let stream = match c.upgrade_id { - Some(id) => assume_upgraded(id).await, + Some(id) => assume_upgradable(id).await, None => return on_failed_upgrade.handle(UpgradeError::NotRequestedUpgrade), }; diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index df11a0e7..eb2de761 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -14,5 +14,5 @@ pub use { context::{WebSocketContext, UpgradeError}, }; pub(crate) use { - upgrade::{UpgradeID, request_upgrade_id, reserve_upgrade, assume_upgraded}, + upgrade::{UpgradeID, request_upgrade_id, reserve_upgrade, assume_upgradable}, }; diff --git a/ohkami/src/x_websocket/upgrade.rs b/ohkami/src/x_websocket/upgrade.rs index 06ba8e2f..21190fe2 100644 --- a/ohkami/src/x_websocket/upgrade.rs +++ b/ohkami/src/x_websocket/upgrade.rs @@ -1,88 +1,137 @@ -use std::{sync::{Arc, OnceLock}, future::Future, pin::Pin}; +use std::{ + sync::{Arc, OnceLock, atomic::{AtomicBool, Ordering}}, + pin::Pin, cell::UnsafeCell, + future::Future, +}; use crate::__rt__::{TcpStream, Mutex}; -type UpgradeLock = Mutex; - -pub static UPGRADE_STREAMS: OnceLock = OnceLock::new(); pub async fn request_upgrade_id() -> UpgradeID { - UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) - .reserve().await + struct ReserveUpgrade; + impl Future for ReserveUpgrade { + type Output = UpgradeID; + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + let Some(mut streams) = UpgradeStreams().request_reservation() + else {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; + + let id = UpgradeID(match streams.iter().position(|cell| cell.is_empty()) { + Some(i) => i, + None => {streams.push(StreamCell::new()); streams.len() - 1}, + }); + + streams[id.as_usize()].reserved = true; + + std::task::Poll::Ready(id) + } + } + + ReserveUpgrade.await } -pub async fn reserve_upgrade(id: UpgradeID, stream: Arc>) { - UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) - .set(id, stream).await + +/// SAFETY: This must be called after the corresponded `reserve_upgrade` +pub unsafe fn reserve_upgrade(id: UpgradeID, stream: Arc>) { + #[cfg(debug_assertions)] assert!( + UpgradeStreams().get().get(id.as_usize()) + .is_some_and(|cell| cell.reserved && cell.stream.is_some()), + "Cell not reserved" + ); + + (UpgradeStreams().get_mut())[id.as_usize()].stream = Some(stream); } -//pub async fn cancel_upgrade -pub async fn assume_upgraded(id: UpgradeID) -> TcpStream { - UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) - .get(id).await + +pub async fn assume_upgradable(id: UpgradeID) -> TcpStream { + struct AssumeUpgraded{id: UpgradeID} + impl Future for AssumeUpgraded { + type Output = TcpStream; + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + let Some(StreamCell { reserved, stream }) = (unsafe {UpgradeStreams().get_mut()}).get_mut(self.id.as_usize()) + else {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; + + if !stream.as_ref().is_some_and(|arc| Arc::strong_count(arc) == 1) + {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; + + *reserved = false; + std::task::Poll::Ready(unsafe { + Mutex::into_inner( + Arc::into_inner( + Option::take(stream) + .unwrap_unchecked()) + .unwrap_unchecked())}) + } + } + + AssumeUpgraded{id}.await } -#[derive(Clone, Copy)] -pub struct UpgradeID(usize); -pub struct UpgradeStreams { - streams: UpgradeLock>, +static UPGRADE_STREAMS: OnceLock = OnceLock::new(); +#[allow(non_snake_case)] fn UpgradeStreams() -> &'static UpgradeStreams { + UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) +} + +struct UpgradeStreams { + in_scanning: AtomicBool, + streams: UnsafeCell>, } const _: () = { + unsafe impl Sync for UpgradeStreams {} + impl UpgradeStreams { fn new() -> Self { Self { - streams: UpgradeLock::new(Vec::new()), + in_scanning: AtomicBool::new(false), + streams: UnsafeCell::new(Vec::new()), } } - } - - impl UpgradeStreams { - async fn reserve(&self) -> UpgradeID { - let mut this = self.streams.lock().await; - match this.iter().position(UpgradeStream::is_empty) { - Some(i) => { - this[i].reserved = true; - UpgradeID(i) - } - None => { - this.push(UpgradeStream { - reserved: true, - stream: None, - }); - UpgradeID(this.len() - 1) - }, - } + fn get(&self) -> &Vec { + unsafe {&*self.streams.get()} } - async fn set(&self, id: UpgradeID, stream: Arc>) { - let mut this = self.streams.lock().await; - this[id.0].stream = Some(stream) + unsafe fn get_mut(&self) -> &mut Vec { + &mut *self.streams.get() } - async fn get(&self, id: UpgradeID) -> TcpStream { - let mut this = self.streams.lock().await; - Pin::new(this.get_mut(id.0).unwrap()).await + fn request_reservation(&self) -> Option> { + self.in_scanning.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed) + .ok().and(Some(ReservationLock(unsafe {self.get_mut()}))) } } -}; -struct UpgradeStream { + struct ReservationLock<'scan>(&'scan mut Vec); + impl<'scan> Drop for ReservationLock<'scan> { + fn drop(&mut self) { + UpgradeStreams().in_scanning.store(false, Ordering::Release) + } + } + impl<'scan> std::ops::Deref for ReservationLock<'scan> { + type Target = Vec; + fn deref(&self) -> &Self::Target {&*self.0} + } + impl<'scan> std::ops::DerefMut for ReservationLock<'scan> { + fn deref_mut(&mut self) -> &mut Self::Target {self.0} + } +}; + +struct StreamCell { reserved: bool, stream: Option>>, } const _: () = { - impl UpgradeStream { - fn is_empty(&self) -> bool { - self.stream.is_none() && !self.reserved + impl StreamCell { + fn new() -> Self { + Self { + reserved: false, + stream: None, + } } - } - impl Default for UpgradeStream { - fn default() -> Self { - Self { reserved: false, stream: None } + fn is_empty(&self) -> bool { + (!self.reserved) && self.stream.is_none() } } - impl Future for UpgradeStream { - type Output = TcpStream; - fn poll(self: std::pin::Pin<&mut Self>, _: &mut std::task::Context<'_>) -> std::task::Poll { - match Arc::strong_count(self.stream.as_ref().unwrap()) { - 1 => std::task::Poll::Ready(Arc::into_inner(self.get_mut().stream.take().unwrap()).unwrap().into_inner()), - _ => std::task::Poll::Pending, - } +}; + +#[derive(Clone, Copy)] +pub struct UpgradeID(usize); +const _: () = { + impl UpgradeID { + fn as_usize(&self) -> usize { + self.0 } } }; - From 8a66b4bbfa187e62efbdf1c0a07476a3c91195fe Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sun, 19 Nov 2023 14:45:12 +0900 Subject: [PATCH 31/37] @2023-11-19 14:45+9:00 --- .github/workflows/check.yml | 33 ++----- .github/workflows/test.yml | 33 ++----- ohkami/Cargo.toml | 10 +- .../handler/into_handler.rs | 36 +++---- ohkami/src/layer3_fang_handler/handler/mod.rs | 23 +++-- ohkami/src/layer4_router/radix.rs | 94 +++++++++++++------ ohkami/src/layer5_ohkami/howl.rs | 20 +++- ohkami/src/layer6_testing/mod.rs | 2 +- ohkami/src/lib.rs | 2 - ohkami/src/x_websocket/upgrade.rs | 35 +++++-- 10 files changed, 166 insertions(+), 122 deletions(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index d000a527..efa18eef 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -6,44 +6,25 @@ on: - main jobs: - check-stable: + test: runs-on: ubuntu-latest strategy: matrix: - rt: [tokio, async-std] + rt: ["tokio", "async-std"] + x: [",websocket", ""] + toolchain: ["stable", "nightly"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: - toolchain: stable + toolchain: ${{ matrix.toolchain }} profile: minimal override: true - uses: actions-rs/cargo@v1 with: command: check - args: --features rt_${{ matrix.rt }} - - check-nighlt: - runs-on: ubuntu-latest - - strategy: - matrix: - rt: [tokio, async-std] - - steps: - - uses: actions/checkout@v3 - - - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly - profile: minimal - override: true - - - uses: actions-rs/cargo@v1 - with: - command: check - args: --features nightly,rt_${{ matrix.rt }} + args: --features rt_${{ matrix.rt }}${{ matrix.x }}${{ matrix.toolchain == "nightly" && ",nightly" || "" }},DEBUG diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dac55ad3..1172b421 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,44 +6,25 @@ on: - main jobs: - test-stable: + test: runs-on: ubuntu-latest strategy: matrix: - rt: [tokio, async-std] + rt: ["tokio", "async-std"] + x: [",websocket", ""] + toolchain: ["stable", "nightly"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: - toolchain: stable + toolchain: ${{ matrix.toolchain }} profile: minimal override: true - uses: actions-rs/cargo@v1 with: command: test - args: --features rt_${{ matrix.rt }},DEBUG - - test-nightly: - runs-on: ubuntu-latest - - strategy: - matrix: - rt: [tokio, async-std] - - steps: - - uses: actions/checkout@v3 - - - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly - profile: minimal - override: true - - - uses: actions-rs/cargo@v1 - with: - command: test - args: --features nightly,rt_${{ matrix.rt }},DEBUG \ No newline at end of file + args: --features rt_${{ matrix.rt }}${{ matrix.x }}${{ matrix.toolchain == "nightly" && ",nightly" || "" }},DEBUG diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 2842ab44..5bd80b2b 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -32,9 +32,13 @@ websocket = [] nightly = [] ##### DEBUG ##### -DEBUG = ["websocket", "serde/derive", "tokio?/macros", "async-std?/attributes"] +DEBUG = [ + #"websocket", + "serde/derive", "tokio?/macros", "async-std?/attributes" +] default = [ - "rt_tokio", + # "rt_tokio", + "rt_async-std", "DEBUG", - #"nightly" + # "nightly" ] \ No newline at end of file diff --git a/ohkami/src/layer3_fang_handler/handler/into_handler.rs b/ohkami/src/layer3_fang_handler/handler/into_handler.rs index 629fa380..6e114151 100644 --- a/ohkami/src/layer3_fang_handler/handler/into_handler.rs +++ b/ohkami/src/layer3_fang_handler/handler/into_handler.rs @@ -30,7 +30,7 @@ const _: (/* only Context */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |_, c, _| + Handler::new(move |_, c, _| Box::pin(self(c)) ) } @@ -46,7 +46,7 @@ const _: (/* PathParam */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |_, c, params| + Handler::new(move |_, c, params| match <$param_type as PathParam>::parse(unsafe {params.assume_init_first().as_bytes()}) { Ok(p1) => Box::pin(self(c, p1)), Err(e) => __bad_request(&c, e) @@ -65,7 +65,7 @@ const _: (/* PathParam */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |_, c, params| + Handler::new(move |_, c, params| // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code match ::parse(unsafe {params.assume_init_first().as_bytes()}) { @@ -82,7 +82,7 @@ const _: (/* PathParam */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |_, c, params| { + Handler::new(move |_, c, params| { let (p1, p2) = params.assume_init_extract(); let (p1, p2) = unsafe {(p1.as_bytes(), p2.as_bytes())}; match (::parse(p1), ::parse(p2)) { @@ -101,7 +101,7 @@ const _: (/* FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, _| + Handler::new(move |req, c, _| match Item1::parse(req) { Ok(item1) => Box::pin(self(c, item1)), Err(e) => __bad_request(&c, e) @@ -116,7 +116,7 @@ const _: (/* FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, _| + Handler::new(move |req, c, _| match (Item1::parse(req), Item2::parse(req)) { (Ok(item1), Ok(item2)) => Box::pin(self(c, item1, item2)), (Err(e), _) | (_, Err(e)) => __bad_request(&c, e), @@ -135,7 +135,7 @@ const _: (/* single PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| { + Handler::new(move |req, c, params| { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code let p1 = unsafe {params.assume_init_first().as_bytes()}; @@ -154,7 +154,7 @@ const _: (/* single PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| { + Handler::new(move |req, c, params| { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code let p1 = unsafe {params.assume_init_first().as_bytes()}; @@ -179,7 +179,7 @@ const _: (/* one PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| { + Handler::new(move |req, c, params| { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code let p1 = unsafe {params.assume_init_first().as_bytes()}; @@ -198,7 +198,7 @@ const _: (/* one PathParam and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| { + Handler::new(move |req, c, params| { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed once before this code let p1 = unsafe {params.assume_init_first().as_bytes()}; @@ -219,7 +219,7 @@ const _: (/* two PathParams and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| { + Handler::new(move |req, c, params| { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed twice before this code let (p1, p2) = params.assume_init_extract(); @@ -239,7 +239,7 @@ const _: (/* two PathParams and FromRequest items */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(false, move |req, c, params| { + Handler::new(move |req, c, params| { // SAFETY: Due to the architecture of `Router`, // `params` has already `append`ed twice before this code let (p1, p2) = params.assume_init_extract(); @@ -295,7 +295,7 @@ const _: (/* requires upgrade to websocket */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(true, move |req, c, params| { + Handler::new(move |req, c, params| { let (p1, p2) = params.assume_init_extract(); let (p1, p2) = unsafe {(p1.as_bytes(), p2.as_bytes())}; match (P1::parse(p1), P2::parse(p2)) { @@ -305,7 +305,7 @@ const _: (/* requires upgrade to websocket */) = { } (Err(e),_)|(_,Err(e)) => __bad_request(&c, e), } - }) + }).requires_upgrade() } } impl IntoHandler<(WebSocketContext, (P1,))> for F @@ -314,7 +314,7 @@ const _: (/* requires upgrade to websocket */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(true, move |req, c, params| { + Handler::new(move |req, c, params| { let p1 = unsafe {params.assume_init_first().as_bytes()}; match P1::parse(p1) { Ok(p1) => match WebSocketContext::new(c, req) { @@ -323,7 +323,7 @@ const _: (/* requires upgrade to websocket */) = { } Err(e) => __bad_request(&c, e), } - }) + }).requires_upgrade() } } impl IntoHandler<(WebSocketContext, (P1, P2))> for F @@ -332,7 +332,7 @@ const _: (/* requires upgrade to websocket */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(true, move |req, c, params| { + Handler::new(move |req, c, params| { let (p1, p2) = params.assume_init_extract(); let (p1, p2) = unsafe {(p1.as_bytes(), p2.as_bytes())}; match (P1::parse(p1), P2::parse(p2)) { @@ -342,7 +342,7 @@ const _: (/* requires upgrade to websocket */) = { } (Err(e),_)|(_,Err(e)) => __bad_request(&c, e), } - }) + }).requires_upgrade() } } }; diff --git a/ohkami/src/layer3_fang_handler/handler/mod.rs b/ohkami/src/layer3_fang_handler/handler/mod.rs index 4dfbde69..4de443ec 100644 --- a/ohkami/src/layer3_fang_handler/handler/mod.rs +++ b/ohkami/src/layer3_fang_handler/handler/mod.rs @@ -19,7 +19,7 @@ pub(crate) type PathParams = List; #[cfg(not(test))] pub struct Handler { - pub(crate) requires_upgrade: bool, + #[cfg(feature="websocket")] pub(crate) requires_upgrade: bool, pub(crate) proc: Box Pin< Box Pin< Box Pin< + fn new( + proc: (impl Fn(&mut Request, Context, PathParams) -> Pin< Box + Send + 'static @@ -55,7 +55,18 @@ impl Handler { > + Send + Sync + 'static ) ) -> Self { - #[cfg(not(test))] {Self { requires_upgrade, proc: Box::new(proc) }} - #[cfg(test)] {Self { requires_upgrade, proc: Arc::new(proc) }} + #[cfg(not(test))] {Self { + #[cfg(feature="websocket")] requires_upgrade: false, + proc: Box::new(proc), + }} + #[cfg(test)] {Self { + #[cfg(feature="websocket")] requires_upgrade: false, + proc: Arc::new(proc), + }} + } + + #[cfg(feature="websocket")] fn requires_upgrade(mut self) -> Self { + self.requires_upgrade = true; + self } } diff --git a/ohkami/src/layer4_router/radix.rs b/ohkami/src/layer4_router/radix.rs index 3084042f..ac2c3ce2 100644 --- a/ohkami/src/layer4_router/radix.rs +++ b/ohkami/src/layer4_router/radix.rs @@ -4,9 +4,11 @@ use crate::{ Response, layer0_lib::{Method, Status, Slice}, layer3_fang_handler::{Handler, FrontFang, PathParams, BackFang}, - websocket::{request_upgrade_id, UpgradeID}, }; +#[cfg(feature="websocket")] +use crate::websocket::{request_upgrade_id, UpgradeID}; + /*===== defs =====*/ pub(crate) struct RadixRouter { @@ -44,12 +46,15 @@ pub(super) enum Pattern { /*===== impls =====*/ +#[cfg(feature="websocket")] type HandleResult = (Response, Option); +#[cfg(not(feature="websocket"))] type HandleResult = Response; + impl RadixRouter { pub(crate) async fn handle( &self, mut c: Context, req: &mut Request, - ) -> (Response, Option) { + ) -> HandleResult { let mut params = PathParams::new(); let search_result = match req.method() { Method::GET => self.GET .search(&mut c, req/*.path_bytes()*/, &mut params), @@ -63,14 +68,21 @@ impl RadixRouter { for ff in front { if let Err(err_res) = ff.0(&mut c, req) { - return (err_res, None) + {#[cfg(feature="websocket")] return (err_res, None);} + {#[cfg(not(feature="websocket"))] return err_res;} } } let target = match self.GET.search(&mut c, req/*.path_bytes()*/, &mut params) { Ok(Some(node)) => node, - Ok(None) => return (c.NotFound(), None), - Err(err_res) => return (err_res, None), + Ok(None) => return { + {#[cfg(feature="websocket")] (c.NotFound(), None)} + #[cfg(not(feature="websocket"))] c.NotFound() + }, + Err(err_res) => return { + {#[cfg(feature="websocket")] (err_res, None)} + #[cfg(not(feature="websocket"))] err_res + }, }; let Response { headers, .. } = target.handle_discarding_upgrade(c, req, params).await; @@ -84,18 +96,21 @@ impl RadixRouter { res = bf.0(res) } - return (res, None) + #[cfg(feature="websocket")] return (res, None); + #[cfg(not(feature="websocket"))] return res; } Method::OPTIONS => { let Some((cors_str, cors)) = crate::layer3_fang_handler::builtin::CORS.get() else { - return (c.InternalServerError(), None) + #[cfg(feature="websocket")] return (c.InternalServerError(), None); + #[cfg(not(feature="websocket"))] return c.InternalServerError(); }; let (front, back) = self.OPTIONSfangs; for ff in front { if let Err(err_res) = ff.0(&mut c, req) { - return (err_res, None) + #[cfg(feature="websocket")] return (err_res, None); + #[cfg(not(feature="websocket"))] return err_res; } } @@ -103,33 +118,40 @@ impl RadixRouter { { let Some(origin) = req.header("Origin") else { - return (c.BadRequest(), None) + #[cfg(feature="websocket")] return (c.BadRequest(), None); + #[cfg(not(feature="websocket"))] return c.BadRequest(); }; if !cors.AllowOrigin.matches(origin) { - return (c.Forbidden(), None) + #[cfg(feature="websocket")] return (c.Forbidden(), None); + #[cfg(not(feature="websocket"))] return c.Forbidden(); } if req.header("Authorization").is_some() && !cors.AllowCredentials { - return (c.Forbidden(), None) + #[cfg(feature="websocket")] return (c.Forbidden(), None); + #[cfg(not(feature="websocket"))] return c.Forbidden(); } if let Some(request_method) = req.header("Access-Control-Request-Method") { let request_method = Method::from_bytes(request_method.as_bytes()); let Some(allow_methods) = cors.AllowMethods.as_ref() else { - return (c.Forbidden(), None) + #[cfg(feature="websocket")] return (c.Forbidden(), None); + #[cfg(not(feature="websocket"))] return c.Forbidden(); }; if !allow_methods.contains(&request_method) { - return (c.Forbidden(), None) + #[cfg(feature="websocket")] return (c.Forbidden(), None); + #[cfg(not(feature="websocket"))] return c.Forbidden(); } } if let Some(request_headers) = req.header("Access-Control-Request-Headers") { let mut request_headers = request_headers.split(',').map(|h| h.trim_matches(' ')); let Some(allow_headers) = cors.AllowHeaders.as_ref() else { - return (c.Forbidden(), None) + #[cfg(feature="websocket")] return (c.Forbidden(), None); + #[cfg(not(feature="websocket"))] return c.Forbidden(); }; if !request_headers.all(|h| allow_headers.contains(&h)) { - return (c.Forbidden(), None) + #[cfg(feature="websocket")] return (c.Forbidden(), None); + #[cfg(not(feature="websocket"))] return c.Forbidden(); } } } @@ -140,53 +162,67 @@ impl RadixRouter { res = bf.0(res) } - return (res, None) + #[cfg(feature="websocket")] return (res, None); + #[cfg(not(feature="websocket"))] return res; } }; let target = match search_result { Ok(Some(node)) => node, - Ok(None) => return (c.NotFound(), None), - Err(err_res) => return (err_res, None), + Ok(None) => { + #[cfg(feature="websocket")] return (c.NotFound(), None); + #[cfg(not(feature="websocket"))] return c.NotFound(); + } + Err(err_res) => { + #[cfg(feature="websocket")] return (err_res, None); + #[cfg(not(feature="websocket"))] return err_res; + } }; - #[cfg(feature="websocket")] {target.handle(c, req, params).await} - #[cfg(not(feature="websocket"))] {(target.handle_discarding_upgrade(c, req, params).await, None)} + target.handle(c, req, params).await } } impl Node { #[inline] pub(super) async fn handle(&self, - mut c: Context, + #[allow(unused_mut)] mut c: Context, req: &mut Request, params: PathParams, - ) -> (Response, Option) { + ) -> HandleResult { match &self.handler { - Some(Handler { requires_upgrade, proc }) => { - let upgrade_id = match (*requires_upgrade).then(|| async { + Some(handler) => { + #[cfg(feature="websocket")] + let upgrade_id = match (handler.requires_upgrade).then(|| async { let id = request_upgrade_id().await; c.upgrade_id = Some(id); id }) {None => None, Some(id) => Some(id.await)}; - let mut res = proc(req, c, params).await; + let mut res = (handler.proc)(req, c, params).await; for b in self.back { res = b.0(res); } - (res, upgrade_id) + #[cfg(feature="websocket")] + {(res, upgrade_id)} + #[cfg(not(feature="websocket"))] + {res} } - None => (c.NotFound(), None) + #[cfg(feature="websocket")] + None => (c.NotFound(), None), + #[cfg(not(feature="websocket"))] + None => c.NotFound(), } } + #[inline] pub(super) async fn handle_discarding_upgrade(&self, c: Context, req: &mut Request, params: PathParams, ) -> Response { match &self.handler { - Some(Handler { requires_upgrade:_, proc }) => { - let mut res = proc(req, c, params).await; + Some(handler) => { + let mut res = (handler.proc)(req, c, params).await; for b in self.back { res = b.0(res); } diff --git a/ohkami/src/layer5_ohkami/howl.rs b/ohkami/src/layer5_ohkami/howl.rs index fec49001..f79090d4 100644 --- a/ohkami/src/layer5_ohkami/howl.rs +++ b/ohkami/src/layer5_ohkami/howl.rs @@ -1,7 +1,9 @@ use std::{sync::Arc, pin::Pin}; use super::{Ohkami}; -use crate::{__rt__, Request, Context, websocket::{reserve_upgrade}}; +use crate::{__rt__, Request, Context}; + #[cfg(feature="rt_async-std")] use crate::__rt__::StreamExt; +#[cfg(feature="websocket")] use crate::websocket::reserve_upgrade; pub trait TCPAddress { @@ -44,15 +46,23 @@ impl Ohkami { #[cfg(feature="rt_async-std")] while let Some(Ok(mut stream)) = listener.incoming().next().await { let router = Arc::clone(&router); - let c = Context::new(); __rt__::task::spawn(async move { let mut req = Request::init(); let mut req = unsafe {Pin::new_unchecked(&mut req)}; req.as_mut().read(&mut stream).await; - let res = router.handle(c, req.get_mut()).await; - res.send(&mut stream).await + #[cfg(not(feature="websocket"))] + let res = router.handle(Context::new(), req.get_mut()).await; + #[cfg(feature="websocket")] + let (res, upgrade_id) = router.handle(Context::new(), req.get_mut()).await; + + res.send(&mut stream).await; + + #[cfg(feature="websocket")] + if let Some(id) = upgrade_id { + unsafe{reserve_upgrade(id, stream)} + } }).await } @@ -77,7 +87,7 @@ impl Ohkami { req.as_mut().read(stream).await; #[cfg(not(feature="websocket"))] - let res = router.handle_discarding_upgrade(Context::new(), req.get_mut()).await; + let res = router.handle(Context::new(), req.get_mut()).await; #[cfg(feature="websocket")] let (res, upgrade_id) = router.handle(Context::new(), req.get_mut()).await; diff --git a/ohkami/src/layer6_testing/mod.rs b/ohkami/src/layer6_testing/mod.rs index 6059849d..67a1d959 100644 --- a/ohkami/src/layer6_testing/mod.rs +++ b/ohkami/src/layer6_testing/mod.rs @@ -36,7 +36,7 @@ impl Testing for Ohkami { let mut req = unsafe {Pin::new_unchecked(&mut req)}; req.as_mut().read(&mut &request.encode_request()[..]).await; - let (res, _) = router.handle(Context::new(), &mut req).await; + let res = router.handle_discarding_upgrade(Context::new(), &mut req).await; TestResponse::new(res) }; diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index ba079c18..3752c856 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -222,8 +222,6 @@ mod __rt__ { #[cfg(feature="rt_tokio")] pub(crate) use tokio::sync::Mutex; - #[cfg(feature="rt_async-std")] - pub(crate) use async_std::sync::Mutex; #[cfg(feature="rt_tokio")] pub(crate) use tokio::net::TcpListener; diff --git a/ohkami/src/x_websocket/upgrade.rs b/ohkami/src/x_websocket/upgrade.rs index 21190fe2..1cb0ec2b 100644 --- a/ohkami/src/x_websocket/upgrade.rs +++ b/ohkami/src/x_websocket/upgrade.rs @@ -1,9 +1,14 @@ use std::{ - sync::{Arc, OnceLock, atomic::{AtomicBool, Ordering}}, + sync::{OnceLock, atomic::{AtomicBool, Ordering}}, pin::Pin, cell::UnsafeCell, future::Future, }; -use crate::__rt__::{TcpStream, Mutex}; +use crate::__rt__::{TcpStream}; + +#[cfg(feature="rt")] use { + std::sync::Arc, + __rt__::Mutex, +}; pub async fn request_upgrade_id() -> UpgradeID { @@ -29,10 +34,14 @@ pub async fn request_upgrade_id() -> UpgradeID { } /// SAFETY: This must be called after the corresponded `reserve_upgrade` -pub unsafe fn reserve_upgrade(id: UpgradeID, stream: Arc>) { +pub unsafe fn reserve_upgrade( + id: UpgradeID, + #[cfg(feature="rt_tokio")] stream: Arc>, + #[cfg(feature="rt_async-std")] stream: TcpStream, +) { #[cfg(debug_assertions)] assert!( - UpgradeStreams().get().get(id.as_usize()) - .is_some_and(|cell| cell.reserved && cell.stream.is_some()), + UpgradeStreams().get().get(id.as_usize()).is_some_and( + |cell| cell.reserved && cell.stream.is_some()), "Cell not reserved" ); @@ -47,16 +56,28 @@ pub async fn assume_upgradable(id: UpgradeID) -> TcpStream { let Some(StreamCell { reserved, stream }) = (unsafe {UpgradeStreams().get_mut()}).get_mut(self.id.as_usize()) else {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; + #[cfg(feature="rt_tokio")] if !stream.as_ref().is_some_and(|arc| Arc::strong_count(arc) == 1) {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; + #[cfg(feature="rt_async-std")] + if !stream.is_some() + {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; *reserved = false; + + #[cfg(feature="rt_tokio")] { std::task::Poll::Ready(unsafe { Mutex::into_inner( Arc::into_inner( Option::take(stream) .unwrap_unchecked()) .unwrap_unchecked())}) + } + #[cfg(feature="rt_async-std")] { + std::task::Poll::Ready(unsafe { + Option::take(stream) + .unwrap_unchecked()}) + } } } @@ -111,7 +132,9 @@ struct UpgradeStreams { struct StreamCell { reserved: bool, - stream: Option>>, + + #[cfg(feature="rt_tokio")] stream: Option>>, + #[cfg(feature="rt_async-std")] stream: Option, } const _: () = { impl StreamCell { fn new() -> Self { From 2507e50644bffa486645004f6c85300d27702dec Mon Sep 17 00:00:00 2001 From: kana-rus Date: Wed, 22 Nov 2023 17:02:12 +0900 Subject: [PATCH 32/37] TODO: abstract `Stream` --- ohkami/Cargo.toml | 4 +- .../handler/into_handler.rs | 8 +- ohkami/src/layer4_router/radix.rs | 27 +++-- ohkami/src/layer6_testing/mod.rs | 11 +- ohkami/src/layer6_testing/x_websocket.rs | 105 ++++++++++++++++++ ohkami/src/lib.rs | 2 + ohkami/src/x_websocket/upgrade.rs | 4 +- 7 files changed, 143 insertions(+), 18 deletions(-) create mode 100644 ohkami/src/layer6_testing/x_websocket.rs diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 5bd80b2b..0d587f8a 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -33,11 +33,11 @@ nightly = [] ##### DEBUG ##### DEBUG = [ - #"websocket", + "websocket", "serde/derive", "tokio?/macros", "async-std?/attributes" ] default = [ - # "rt_tokio", + #"rt_tokio", "rt_async-std", "DEBUG", # "nightly" diff --git a/ohkami/src/layer3_fang_handler/handler/into_handler.rs b/ohkami/src/layer3_fang_handler/handler/into_handler.rs index 6e114151..10720030 100644 --- a/ohkami/src/layer3_fang_handler/handler/into_handler.rs +++ b/ohkami/src/layer3_fang_handler/handler/into_handler.rs @@ -262,12 +262,12 @@ const _: (/* requires upgrade to websocket */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(true, move |req, c, _| { + Handler::new(move |req, c, _| { match WebSocketContext::new(c, req) { Ok(wsc) => Box::pin(self(wsc)), Err(res) => (|| Box::pin(async {res}))(), } - }) + }).requires_upgrade() } } @@ -277,7 +277,7 @@ const _: (/* requires upgrade to websocket */) = { Fut: Future + Send + Sync + 'static, { fn into_handler(self) -> Handler { - Handler::new(true, move |req, c, params| { + Handler::new(move |req, c, params| { let p1 = unsafe {params.assume_init_first().as_bytes()}; match P1::parse(p1) { Ok(p1) => match WebSocketContext::new(c, req) { @@ -286,7 +286,7 @@ const _: (/* requires upgrade to websocket */) = { } Err(e) => __bad_request(&c, e), } - }) + }).requires_upgrade() } } impl IntoHandler<(WebSocketContext, P1, P2)> for F diff --git a/ohkami/src/layer4_router/radix.rs b/ohkami/src/layer4_router/radix.rs index ac2c3ce2..ca5e6e69 100644 --- a/ohkami/src/layer4_router/radix.rs +++ b/ohkami/src/layer4_router/radix.rs @@ -7,7 +7,11 @@ use crate::{ }; #[cfg(feature="websocket")] -use crate::websocket::{request_upgrade_id, UpgradeID}; +use crate::websocket::{ + UpgradeID, +}; +#[cfg(all(feature="websocket", not(test)))] +use crate::websocket::request_upgrade_id; /*===== defs =====*/ @@ -68,20 +72,20 @@ impl RadixRouter { for ff in front { if let Err(err_res) = ff.0(&mut c, req) { - {#[cfg(feature="websocket")] return (err_res, None);} - {#[cfg(not(feature="websocket"))] return err_res;} + #[cfg(feature="websocket")] return (err_res, None); + #[cfg(not(feature="websocket"))] return err_res; } } let target = match self.GET.search(&mut c, req/*.path_bytes()*/, &mut params) { Ok(Some(node)) => node, - Ok(None) => return { - {#[cfg(feature="websocket")] (c.NotFound(), None)} - #[cfg(not(feature="websocket"))] c.NotFound() + Ok(None) => { + #[cfg(feature="websocket")] return (c.NotFound(), None); + #[cfg(not(feature="websocket"))] return c.NotFound(); }, - Err(err_res) => return { - {#[cfg(feature="websocket")] (err_res, None)} - #[cfg(not(feature="websocket"))] err_res + Err(err_res) => { + #[cfg(feature="websocket")] return (err_res, None); + #[cfg(not(feature="websocket"))] return err_res; }, }; @@ -192,12 +196,17 @@ impl Node { match &self.handler { Some(handler) => { #[cfg(feature="websocket")] + #[cfg(not(test))] let upgrade_id = match (handler.requires_upgrade).then(|| async { let id = request_upgrade_id().await; c.upgrade_id = Some(id); id }) {None => None, Some(id) => Some(id.await)}; + #[cfg(feature="websocket")] + #[cfg(test)] + let upgrade_id = None; + let mut res = (handler.proc)(req, c, params).await; for b in self.back { res = b.0(res); diff --git a/ohkami/src/layer6_testing/mod.rs b/ohkami/src/layer6_testing/mod.rs index 67a1d959..b25bf859 100644 --- a/ohkami/src/layer6_testing/mod.rs +++ b/ohkami/src/layer6_testing/mod.rs @@ -1,4 +1,7 @@ mod _test; +mod x_websocket; + +pub(crate) use x_websocket::TestWebSocket; use std::borrow::Cow; use std::collections::HashMap; @@ -11,7 +14,9 @@ use crate::layer0_lib::{IntoCows, Status, Method, ContentType}; pub trait Testing { fn oneshot(&self, req: TestRequest) -> TestFuture; + //fn oneshot_and_upgraded(&self, req: TestRequest) -> (TestFuture, TestWebSocket); } + pub struct TestFuture( Box>); impl Future for TestFuture { @@ -36,7 +41,11 @@ impl Testing for Ohkami { let mut req = unsafe {Pin::new_unchecked(&mut req)}; req.as_mut().read(&mut &request.encode_request()[..]).await; - let res = router.handle_discarding_upgrade(Context::new(), &mut req).await; + #[cfg(not(feature="websocket"))] + let res = router.handle(Context::new(), &mut req).await; + #[cfg(feature="websocket")] + let (res, _) = router.handle(Context::new(), &mut req).await; + TestResponse::new(res) }; diff --git a/ohkami/src/layer6_testing/x_websocket.rs b/ohkami/src/layer6_testing/x_websocket.rs new file mode 100644 index 00000000..f0753661 --- /dev/null +++ b/ohkami/src/layer6_testing/x_websocket.rs @@ -0,0 +1,105 @@ +use std::cell::UnsafeCell; +use std::pin::Pin; +use std::sync::Arc; +use crate::__rt__::Mutex; + + +pub struct TestWebSocket { + client: TestStream, +} impl TestWebSocket { + pub(crate) fn new(stream: TestStream) -> Self { + Self { client: stream } + } +} + + +pub(crate) struct TestStream { + read: Arc>>, + write: Arc>>, +} +/// SAFETY: Only one of the client - server needs `&mut _` to `write` into +/// each `Arc>>` at a time +impl TestStream { + fn read_half(&self) -> &[u8] { + unsafe {&*self.read.get()} + } + fn write_half(&self) -> &mut Vec { + unsafe {&mut *self.write.get()} + } +} +impl Clone for TestStream { + fn clone(&self) -> Self { + Self { + read: self.read.clone(), + write: self.write.clone(), + } + } +} +#[cfg(feature="rt_tokio")] const _: () = { + impl tokio::io::AsyncRead for TestStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + let mut read = &self.read[..]; + let read = unsafe {Pin::new_unchecked(&mut read)}; + read.poll_read(cx, buf) + } + } + + impl tokio::io::AsyncWrite for TestStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + unsafe {self.map_unchecked_mut(|this| &mut this.write)} + .poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + unsafe {self.map_unchecked_mut(|this| &mut this.write)} + .poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + unsafe {self.map_unchecked_mut(|this| &mut this.write)} + .poll_shutdown(cx) + } + } +}; +#[cfg(feature="rt_async-std")] const _: () = { + impl async_std::io::Read for TestStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> std::task::Poll> { + let mut read = self.read_half(); + let read = unsafe {Pin::new_unchecked(&mut read)}; + read.poll_read(cx, buf) + } + } + + impl async_std::io::Write for TestStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + unsafe {self.map_unchecked_mut(|this| this.write_half())} + .poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + unsafe {self.map_unchecked_mut(|this| this.write_half())} + .poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + unsafe {self.map_unchecked_mut(|this| this.write_half())} + .poll_close(cx) + } + } +}; diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 3752c856..ba079c18 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -222,6 +222,8 @@ mod __rt__ { #[cfg(feature="rt_tokio")] pub(crate) use tokio::sync::Mutex; + #[cfg(feature="rt_async-std")] + pub(crate) use async_std::sync::Mutex; #[cfg(feature="rt_tokio")] pub(crate) use tokio::net::TcpListener; diff --git a/ohkami/src/x_websocket/upgrade.rs b/ohkami/src/x_websocket/upgrade.rs index 1cb0ec2b..1ff7f724 100644 --- a/ohkami/src/x_websocket/upgrade.rs +++ b/ohkami/src/x_websocket/upgrade.rs @@ -5,9 +5,9 @@ use std::{ }; use crate::__rt__::{TcpStream}; -#[cfg(feature="rt")] use { +#[cfg(feature="rt_tokio")] use { std::sync::Arc, - __rt__::Mutex, + crate::__rt__::Mutex, }; From 2b7e55d5b0d8bf7d6ef4f9d79dc1d27e5a2ccd4c Mon Sep 17 00:00:00 2001 From: kana-rus Date: Thu, 23 Nov 2023 00:36:12 +0900 Subject: [PATCH 33/37] @2023-11-23 00:36+9:00 --- ohkami/src/layer6_testing/x_websocket.rs | 84 ++++++++++++++++-------- 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/ohkami/src/layer6_testing/x_websocket.rs b/ohkami/src/layer6_testing/x_websocket.rs index f0753661..cff10282 100644 --- a/ohkami/src/layer6_testing/x_websocket.rs +++ b/ohkami/src/layer6_testing/x_websocket.rs @@ -1,7 +1,9 @@ use std::cell::UnsafeCell; +use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use crate::__rt__::Mutex; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::Poll; pub struct TestWebSocket { @@ -13,28 +15,47 @@ pub struct TestWebSocket { } +/// +/// ```txt +/// client ------------- server +/// | | +/// [read ============= write] : TestStream +/// [write ============= read] : TestStream +/// | | +/// TestWebSocket TestWebSocket +/// ``` pub(crate) struct TestStream { - read: Arc>>, - write: Arc>>, + locked: AtomicBool, // It could be more efficient, but now use very simple lock + buf: Arc>>, } -/// SAFETY: Only one of the client - server needs `&mut _` to `write` into -/// each `Arc>>` at a time -impl TestStream { - fn read_half(&self) -> &[u8] { - unsafe {&*self.read.get()} +const _: () = { + impl TestStream { + fn lock(self: Pin<&mut Self>) -> Poll> { + match self.locked.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed) { + Ok(_) => Poll::Ready(Lock(self.get_mut())), + Err(_) => Poll::Pending, + } + } } - fn write_half(&self) -> &mut Vec { - unsafe {&mut *self.write.get()} + + struct Lock<'stream>(&'stream mut TestStream); + impl<'stream> Drop for Lock<'stream> { + fn drop(&mut self) { + self.0.locked.store(false, Ordering::Release); + } } -} -impl Clone for TestStream { - fn clone(&self) -> Self { - Self { - read: self.read.clone(), - write: self.write.clone(), + impl<'stream> std::ops::Deref for Lock<'stream> { + type Target = Vec; + fn deref(&self) -> &Self::Target { + unsafe {&*self.0.buf.get()} } } -} + impl<'stream> std::ops::DerefMut for Lock<'stream> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe {&mut *self.0.buf.get()} + } + } +}; #[cfg(feature="rt_tokio")] const _: () = { impl tokio::io::AsyncRead for TestStream { fn poll_read( @@ -76,9 +97,15 @@ impl Clone for TestStream { cx: &mut std::task::Context<'_>, buf: &mut [u8], ) -> std::task::Poll> { - let mut read = self.read_half(); - let read = unsafe {Pin::new_unchecked(&mut read)}; - read.poll_read(cx, buf) + let Poll::Ready(mut this) = self.lock() + else {cx.waker().wake_by_ref(); return Poll::Pending}; + + let size = (this.len()).min(buf.len()); + let (a, b) = this.split_at(size); + buf.copy_from_slice(a); + *this = b.to_vec(); + + Poll::Ready(Ok(size)) } } @@ -88,18 +115,19 @@ impl Clone for TestStream { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { - unsafe {self.map_unchecked_mut(|this| this.write_half())} - .poll_write(cx, buf) + let Poll::Ready(mut this) = self.lock() + else {cx.waker().wake_by_ref(); return Poll::Pending}; + + this.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - unsafe {self.map_unchecked_mut(|this| this.write_half())} - .poll_flush(cx) + fn poll_flush(self: Pin<&mut Self>, _: &mut std::task::Context<'_>) -> std::task::Poll> { + Poll::Ready(Ok(())) } - fn poll_close(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - unsafe {self.map_unchecked_mut(|this| this.write_half())} - .poll_close(cx) + fn poll_close(self: Pin<&mut Self>, _: &mut std::task::Context<'_>) -> std::task::Poll> { + Poll::Ready(Ok(())) } } }; From 355e087043262ba49cee6d9dd3e42956ff2d8d4e Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sat, 25 Nov 2023 20:18:29 +0900 Subject: [PATCH 34/37] @2023-11-25 20:18+9:00 --- ohkami/Cargo.toml | 4 +- ohkami/src/layer4_router/radix.rs | 79 +++++--------- ohkami/src/layer5_ohkami/howl.rs | 1 + ohkami/src/layer6_testing/mod.rs | 61 +++++++++-- ohkami/src/layer6_testing/x_websocket.rs | 130 +++++++++++++++++------ ohkami/src/lib.rs | 10 +- ohkami/src/x_websocket/mod.rs | 3 + ohkami/src/x_websocket/upgrade.rs | 101 ++++++++++-------- ohkami/src/x_websocket/websocket.rs | 2 +- 9 files changed, 244 insertions(+), 147 deletions(-) diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 0d587f8a..2b3da70a 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -37,8 +37,8 @@ DEBUG = [ "serde/derive", "tokio?/macros", "async-std?/attributes" ] default = [ - #"rt_tokio", - "rt_async-std", + "rt_tokio", + #"rt_async-std", "DEBUG", # "nightly" ] \ No newline at end of file diff --git a/ohkami/src/layer4_router/radix.rs b/ohkami/src/layer4_router/radix.rs index ca5e6e69..8aa9dd50 100644 --- a/ohkami/src/layer4_router/radix.rs +++ b/ohkami/src/layer4_router/radix.rs @@ -9,9 +9,8 @@ use crate::{ #[cfg(feature="websocket")] use crate::websocket::{ UpgradeID, + request_upgrade_id, }; -#[cfg(all(feature="websocket", not(test)))] -use crate::websocket::request_upgrade_id; /*===== defs =====*/ @@ -50,8 +49,16 @@ pub(super) enum Pattern { /*===== impls =====*/ -#[cfg(feature="websocket")] type HandleResult = (Response, Option); +#[cfg(feature="websocket")] type HandleResult = (Response, Option); +#[cfg(feature="websocket")] fn __no_upgrade(res: Response) -> HandleResult { + (res, None) +} + #[cfg(not(feature="websocket"))] type HandleResult = Response; +#[cfg(not(feature="websocket"))] fn __no_upgrade(res: Response) -> HandleResult { + res +} + impl RadixRouter { pub(crate) async fn handle( @@ -72,21 +79,14 @@ impl RadixRouter { for ff in front { if let Err(err_res) = ff.0(&mut c, req) { - #[cfg(feature="websocket")] return (err_res, None); - #[cfg(not(feature="websocket"))] return err_res; + return __no_upgrade(err_res) } } let target = match self.GET.search(&mut c, req/*.path_bytes()*/, &mut params) { Ok(Some(node)) => node, - Ok(None) => { - #[cfg(feature="websocket")] return (c.NotFound(), None); - #[cfg(not(feature="websocket"))] return c.NotFound(); - }, - Err(err_res) => { - #[cfg(feature="websocket")] return (err_res, None); - #[cfg(not(feature="websocket"))] return err_res; - }, + Ok(None) => return __no_upgrade(c.NotFound()), + Err(err_res) => return __no_upgrade(err_res), }; let Response { headers, .. } = target.handle_discarding_upgrade(c, req, params).await; @@ -100,21 +100,18 @@ impl RadixRouter { res = bf.0(res) } - #[cfg(feature="websocket")] return (res, None); - #[cfg(not(feature="websocket"))] return res; + return __no_upgrade(res); } Method::OPTIONS => { let Some((cors_str, cors)) = crate::layer3_fang_handler::builtin::CORS.get() else { - #[cfg(feature="websocket")] return (c.InternalServerError(), None); - #[cfg(not(feature="websocket"))] return c.InternalServerError(); + return __no_upgrade(c.InternalServerError()); }; let (front, back) = self.OPTIONSfangs; for ff in front { if let Err(err_res) = ff.0(&mut c, req) { - #[cfg(feature="websocket")] return (err_res, None); - #[cfg(not(feature="websocket"))] return err_res; + return __no_upgrade(err_res); } } @@ -122,40 +119,33 @@ impl RadixRouter { { let Some(origin) = req.header("Origin") else { - #[cfg(feature="websocket")] return (c.BadRequest(), None); - #[cfg(not(feature="websocket"))] return c.BadRequest(); + return __no_upgrade(c.BadRequest()); }; if !cors.AllowOrigin.matches(origin) { - #[cfg(feature="websocket")] return (c.Forbidden(), None); - #[cfg(not(feature="websocket"))] return c.Forbidden(); + return __no_upgrade(c.Forbidden()); } if req.header("Authorization").is_some() && !cors.AllowCredentials { - #[cfg(feature="websocket")] return (c.Forbidden(), None); - #[cfg(not(feature="websocket"))] return c.Forbidden(); + return __no_upgrade(c.Forbidden()); } if let Some(request_method) = req.header("Access-Control-Request-Method") { let request_method = Method::from_bytes(request_method.as_bytes()); let Some(allow_methods) = cors.AllowMethods.as_ref() else { - #[cfg(feature="websocket")] return (c.Forbidden(), None); - #[cfg(not(feature="websocket"))] return c.Forbidden(); + return __no_upgrade(c.Forbidden()); }; if !allow_methods.contains(&request_method) { - #[cfg(feature="websocket")] return (c.Forbidden(), None); - #[cfg(not(feature="websocket"))] return c.Forbidden(); + return __no_upgrade(c.Forbidden()); } } if let Some(request_headers) = req.header("Access-Control-Request-Headers") { let mut request_headers = request_headers.split(',').map(|h| h.trim_matches(' ')); let Some(allow_headers) = cors.AllowHeaders.as_ref() else { - #[cfg(feature="websocket")] return (c.Forbidden(), None); - #[cfg(not(feature="websocket"))] return c.Forbidden(); + return __no_upgrade(c.Forbidden()); }; if !request_headers.all(|h| allow_headers.contains(&h)) { - #[cfg(feature="websocket")] return (c.Forbidden(), None); - #[cfg(not(feature="websocket"))] return c.Forbidden(); + return __no_upgrade(c.Forbidden()); } } } @@ -166,21 +156,14 @@ impl RadixRouter { res = bf.0(res) } - #[cfg(feature="websocket")] return (res, None); - #[cfg(not(feature="websocket"))] return res; + return __no_upgrade(res); } }; let target = match search_result { Ok(Some(node)) => node, - Ok(None) => { - #[cfg(feature="websocket")] return (c.NotFound(), None); - #[cfg(not(feature="websocket"))] return c.NotFound(); - } - Err(err_res) => { - #[cfg(feature="websocket")] return (err_res, None); - #[cfg(not(feature="websocket"))] return err_res; - } + Ok(None) => return __no_upgrade(c.NotFound()), + Err(err_res) => return __no_upgrade(err_res), }; target.handle(c, req, params).await @@ -196,17 +179,12 @@ impl Node { match &self.handler { Some(handler) => { #[cfg(feature="websocket")] - #[cfg(not(test))] let upgrade_id = match (handler.requires_upgrade).then(|| async { let id = request_upgrade_id().await; c.upgrade_id = Some(id); id }) {None => None, Some(id) => Some(id.await)}; - #[cfg(feature="websocket")] - #[cfg(test)] - let upgrade_id = None; - let mut res = (handler.proc)(req, c, params).await; for b in self.back { res = b.0(res); @@ -217,10 +195,7 @@ impl Node { #[cfg(not(feature="websocket"))] {res} } - #[cfg(feature="websocket")] - None => (c.NotFound(), None), - #[cfg(not(feature="websocket"))] - None => c.NotFound(), + None => __no_upgrade(c.NotFound()), } } diff --git a/ohkami/src/layer5_ohkami/howl.rs b/ohkami/src/layer5_ohkami/howl.rs index f79090d4..9d739228 100644 --- a/ohkami/src/layer5_ohkami/howl.rs +++ b/ohkami/src/layer5_ohkami/howl.rs @@ -103,6 +103,7 @@ impl Ohkami { #[cfg(feature="websocket")] Ok(upgrade_id) => { if let Some(id) = upgrade_id { + let stream = __rt__::Mutex::into_inner(Arc::into_inner(stream).unwrap()); unsafe{reserve_upgrade(id, stream)} } } diff --git a/ohkami/src/layer6_testing/mod.rs b/ohkami/src/layer6_testing/mod.rs index b25bf859..66cbb9be 100644 --- a/ohkami/src/layer6_testing/mod.rs +++ b/ohkami/src/layer6_testing/mod.rs @@ -1,7 +1,7 @@ mod _test; mod x_websocket; -pub(crate) use x_websocket::TestWebSocket; +pub(crate) use x_websocket::{TestWebSocket, TestStream}; use std::borrow::Cow; use std::collections::HashMap; @@ -13,21 +13,31 @@ use crate::layer0_lib::{IntoCows, Status, Method, ContentType}; pub trait Testing { - fn oneshot(&self, req: TestRequest) -> TestFuture; - //fn oneshot_and_upgraded(&self, req: TestRequest) -> (TestFuture, TestWebSocket); + fn oneshot(&self, req: TestRequest) -> Oneshot; + + #[cfg(feature="websocket")] + fn oneshot_and_upgraded(&self, req: TestRequest) -> OneshotAndUpgraded; } -pub struct TestFuture( - Box>); -impl Future for TestFuture { +pub struct Oneshot( + Box> +); impl Future for Oneshot { type Output = TestResponse; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { unsafe {self.map_unchecked_mut(|this| this.0.as_mut())}.poll(cx) } } +pub struct OneshotAndUpgraded( + Box)>> +); impl Future for OneshotAndUpgraded { + type Output = (TestResponse, Option); + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + unsafe {self.map_unchecked_mut(|this| this.0.as_mut())}.poll(cx) + } +} impl Testing for Ohkami { - fn oneshot(&self, request: TestRequest) -> TestFuture { + fn oneshot(&self, request: TestRequest) -> Oneshot { let router = { let mut router = self.routes.clone(); for (methods, fang) in &self.fangs { @@ -36,7 +46,7 @@ impl Testing for Ohkami { router.into_radix() }; - let test_res = async move { + let res = async move { let mut req = Request::init(); let mut req = unsafe {Pin::new_unchecked(&mut req)}; req.as_mut().read(&mut &request.encode_request()[..]).await; @@ -49,7 +59,40 @@ impl Testing for Ohkami { TestResponse::new(res) }; - TestFuture(Box::new(test_res)) + Oneshot(Box::new(res)) + } + + #[cfg(feature="websocket")] + fn oneshot_and_upgraded(&self, request: TestRequest) -> OneshotAndUpgraded { + use crate::websocket::{reserve_upgrade_in_test, assume_upgradable_in_test}; + + let router = { + let mut router = self.routes.clone(); + for (methods, fang) in &self.fangs { + router = router.apply_fang(methods, fang.clone()) + } + router.into_radix() + }; + + let res_and_socket = async move { + let mut req = Request::init(); + let mut req = unsafe {Pin::new_unchecked(&mut req)}; + req.as_mut().read(&mut &request.encode_request()[..]).await; + + let (res, upgrade_id) = router.handle(Context::new(), &mut req).await; + match upgrade_id { + None => (TestResponse::new(res), None), + Some(id) => { + let (client, server) = TestWebSocket::new_pair(); + unsafe {reserve_upgrade_in_test(id, client.stream)}; + let _ = assume_upgradable_in_test(id).await; + + __TODO__ + }, + } + }; + + OneshotAndUpgraded(Box::new(res_and_socket)) } } diff --git a/ohkami/src/layer6_testing/x_websocket.rs b/ohkami/src/layer6_testing/x_websocket.rs index cff10282..51412e92 100644 --- a/ohkami/src/layer6_testing/x_websocket.rs +++ b/ohkami/src/layer6_testing/x_websocket.rs @@ -1,5 +1,4 @@ use std::cell::UnsafeCell; -use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; @@ -7,10 +6,32 @@ use std::task::Poll; pub struct TestWebSocket { - client: TestStream, + pub(crate) stream: TestStream, } impl TestWebSocket { - pub(crate) fn new(stream: TestStream) -> Self { - Self { client: stream } + pub(crate) fn new_pair() -> (Self, Self) { + let (server_read, server_write) = ( + Arc::new( + HalfStream { + locked: AtomicBool::new(false), + buf: UnsafeCell::new(Vec::new()), + } + ), + Arc::new( + HalfStream { + locked: AtomicBool::new(false), + buf: UnsafeCell::new(Vec::new()), + } + ), + ); + let (client_write, client_read) = ( + server_read.clone(), + server_write.clone(), + ); + + ( + Self {stream: TestStream::new(client_read, client_write)}, + Self {stream: TestStream::new(server_read, server_write)} + ) } } @@ -19,40 +40,78 @@ pub struct TestWebSocket { /// ```txt /// client ------------- server /// | | -/// [read ============= write] : TestStream -/// [write ============= read] : TestStream +/// [read ============= write] | +/// =========================== | : TestStream +/// [write ============= read] | /// | | /// TestWebSocket TestWebSocket /// ``` -pub(crate) struct TestStream { - locked: AtomicBool, // It could be more efficient, but now use very simple lock - buf: Arc>>, +pub struct TestStream { + read: Arc, + write: Arc, +} +pub struct HalfStream { + locked: AtomicBool, // It could be more efficient, but now using very simple lock + buf: UnsafeCell> } const _: () = { + unsafe impl Sync for TestStream {} + unsafe impl Send for TestStream {} + impl TestStream { - fn lock(self: Pin<&mut Self>) -> Poll> { - match self.locked.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed) { - Ok(_) => Poll::Ready(Lock(self.get_mut())), + pub(crate) fn new(read: Arc, write: Arc) -> Self { + Self { read, write } + } + } + + impl TestStream { + fn read_lock(self: Pin<&mut Self>) -> Poll> { + match self.read.locked.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed) { + Ok(_) => Poll::Ready(ReadLock(self.get_mut())), + Err(_) => Poll::Pending, + } + } + fn write_lock(self: Pin<&mut Self>) -> Poll> { + match self.write.locked.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed) { + Ok(_) => Poll::Ready(WriteLock(self.get_mut())), Err(_) => Poll::Pending, } } } - struct Lock<'stream>(&'stream mut TestStream); - impl<'stream> Drop for Lock<'stream> { + struct ReadLock<'stream>(&'stream mut TestStream); + impl<'stream> Drop for ReadLock<'stream> { fn drop(&mut self) { - self.0.locked.store(false, Ordering::Release); + self.0.read.locked.store(false, Ordering::Release); } } - impl<'stream> std::ops::Deref for Lock<'stream> { + impl<'stream> std::ops::Deref for ReadLock<'stream> { type Target = Vec; fn deref(&self) -> &Self::Target { - unsafe {&*self.0.buf.get()} + unsafe {&*self.0.read.buf.get()} } } - impl<'stream> std::ops::DerefMut for Lock<'stream> { + impl<'stream> std::ops::DerefMut for ReadLock<'stream> { fn deref_mut(&mut self) -> &mut Self::Target { - unsafe {&mut *self.0.buf.get()} + unsafe {&mut *self.0.read.buf.get()} + } + } + + struct WriteLock<'stream>(&'stream mut TestStream); + impl<'stream> Drop for WriteLock<'stream> { + fn drop(&mut self) { + self.0.write.locked.store(false, Ordering::Release); + } + } + impl<'stream> std::ops::Deref for WriteLock<'stream> { + type Target = Vec; + fn deref(&self) -> &Self::Target { + unsafe {&*self.0.write.buf.get()} + } + } + impl<'stream> std::ops::DerefMut for WriteLock<'stream> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe {&mut *self.0.write.buf.get()} } } }; @@ -63,9 +122,15 @@ const _: () = { cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { - let mut read = &self.read[..]; - let read = unsafe {Pin::new_unchecked(&mut read)}; - read.poll_read(cx, buf) + let Poll::Ready(mut this) = self.read_lock() + else {cx.waker().wake_by_ref(); return Poll::Pending}; + + let size = (this.len()).min(buf.remaining()); + let (a, b) = this.split_at(size); + buf.put_slice(a); + *this = b.to_vec(); + + Poll::Ready(Ok(())) } } @@ -75,18 +140,19 @@ const _: () = { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { - unsafe {self.map_unchecked_mut(|this| &mut this.write)} - .poll_write(cx, buf) + let Poll::Ready(mut this) = self.write_lock() + else {cx.waker().wake_by_ref(); return Poll::Pending}; + + this.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - unsafe {self.map_unchecked_mut(|this| &mut this.write)} - .poll_flush(cx) + fn poll_flush(self: Pin<&mut Self>, _: &mut std::task::Context<'_>) -> std::task::Poll> { + Poll::Ready(Ok(())) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - unsafe {self.map_unchecked_mut(|this| &mut this.write)} - .poll_shutdown(cx) + fn poll_shutdown(self: Pin<&mut Self>, _: &mut std::task::Context<'_>) -> std::task::Poll> { + Poll::Ready(Ok(())) } } }; @@ -97,7 +163,7 @@ const _: () = { cx: &mut std::task::Context<'_>, buf: &mut [u8], ) -> std::task::Poll> { - let Poll::Ready(mut this) = self.lock() + let Poll::Ready(mut this) = self.read_lock() else {cx.waker().wake_by_ref(); return Poll::Pending}; let size = (this.len()).min(buf.len()); @@ -115,7 +181,7 @@ const _: () = { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { - let Poll::Ready(mut this) = self.lock() + let Poll::Ready(mut this) = self.write_lock() else {cx.waker().wake_by_ref(); return Poll::Pending}; this.extend_from_slice(buf); diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index ba079c18..3b1b5e98 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -221,9 +221,9 @@ mod __rt__ { pub(crate) use async_std::test; #[cfg(feature="rt_tokio")] - pub(crate) use tokio::sync::Mutex; + pub(crate) use tokio::net::TcpStream; #[cfg(feature="rt_async-std")] - pub(crate) use async_std::sync::Mutex; + pub(crate) use async_std::net::TcpStream; #[cfg(feature="rt_tokio")] pub(crate) use tokio::net::TcpListener; @@ -231,9 +231,9 @@ mod __rt__ { pub(crate) use async_std::net::TcpListener; #[cfg(feature="rt_tokio")] - pub(crate) use tokio::net::TcpStream; + pub(crate) use tokio::sync::Mutex; #[cfg(feature="rt_async-std")] - pub(crate) use async_std::net::TcpStream; + pub(crate) use async_std::sync::Mutex; #[cfg(all(feature="rt_tokio", feature="websocket"))] pub(crate) use tokio::net::tcp::{ReadHalf, WriteHalf}; @@ -325,6 +325,7 @@ pub mod __internal__ { // fangs struct AppendHeader; impl IntoFang for AppendHeader { + //const METHODS: &'static [Method] = &[Method::GET]; fn bite(self) -> Fang { Fang(|c: &mut Context, _: &mut Request| { c.headers.Server("ohkami"); @@ -334,6 +335,7 @@ pub mod __internal__ { struct Log; impl IntoFang for Log { + //const METHODS: &'static [Method] = &[]; fn bite(self) -> Fang { Fang(|res: Response| { println!("{res:?}"); diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index eb2de761..465d75fd 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -16,3 +16,6 @@ pub use { pub(crate) use { upgrade::{UpgradeID, request_upgrade_id, reserve_upgrade, assume_upgradable}, }; +#[cfg(test)] pub(crate) use { + upgrade::{reserve_upgrade_in_test, assume_upgradable_in_test}, +}; diff --git a/ohkami/src/x_websocket/upgrade.rs b/ohkami/src/x_websocket/upgrade.rs index 1ff7f724..316ce3f5 100644 --- a/ohkami/src/x_websocket/upgrade.rs +++ b/ohkami/src/x_websocket/upgrade.rs @@ -4,12 +4,7 @@ use std::{ future::Future, }; use crate::__rt__::{TcpStream}; - -#[cfg(feature="rt_tokio")] use { - std::sync::Arc, - crate::__rt__::Mutex, -}; - +#[cfg(test)] use crate::layer6_testing::TestStream; pub async fn request_upgrade_id() -> UpgradeID { struct ReserveUpgrade; @@ -33,12 +28,8 @@ pub async fn request_upgrade_id() -> UpgradeID { ReserveUpgrade.await } -/// SAFETY: This must be called after the corresponded `reserve_upgrade` -pub unsafe fn reserve_upgrade( - id: UpgradeID, - #[cfg(feature="rt_tokio")] stream: Arc>, - #[cfg(feature="rt_async-std")] stream: TcpStream, -) { +/// SAFETY: This must be called after the corresponded `request_upgrade_id` +pub unsafe fn reserve_upgrade(id: UpgradeID, stream: TcpStream) { #[cfg(debug_assertions)] assert!( UpgradeStreams().get().get(id.as_usize()).is_some_and( |cell| cell.reserved && cell.stream.is_some()), @@ -47,6 +38,16 @@ pub unsafe fn reserve_upgrade( (UpgradeStreams().get_mut())[id.as_usize()].stream = Some(stream); } +/// SAFETY: This must be called after the corresponded `request_upgrade_id_in_test` +#[cfg(test)] pub unsafe fn reserve_upgrade_in_test(id: UpgradeID, stream: TestStream) { + #[cfg(debug_assertions)] assert!( + UpgradeStreams().get().get(id.as_usize()).is_some_and( + |cell| cell.reserved && cell.stream.is_some()), + "Cell not reserved" + ); + + (UpgradeStreamsInTest().get_mut())[id.as_usize()].stream = Some(stream); +} pub async fn assume_upgradable(id: UpgradeID) -> TcpStream { struct AssumeUpgraded{id: UpgradeID} @@ -56,87 +57,93 @@ pub async fn assume_upgradable(id: UpgradeID) -> TcpStream { let Some(StreamCell { reserved, stream }) = (unsafe {UpgradeStreams().get_mut()}).get_mut(self.id.as_usize()) else {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; - #[cfg(feature="rt_tokio")] - if !stream.as_ref().is_some_and(|arc| Arc::strong_count(arc) == 1) - {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; - #[cfg(feature="rt_async-std")] - if !stream.is_some() + if stream.is_none() {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; *reserved = false; - #[cfg(feature="rt_tokio")] { - std::task::Poll::Ready(unsafe { - Mutex::into_inner( - Arc::into_inner( - Option::take(stream) - .unwrap_unchecked()) - .unwrap_unchecked())}) - } - #[cfg(feature="rt_async-std")] { - std::task::Poll::Ready(unsafe { - Option::take(stream) - .unwrap_unchecked()}) - } + std::task::Poll::Ready(unsafe {stream.take().unwrap_unchecked()}) } } AssumeUpgraded{id}.await } +#[cfg(test)] pub async fn assume_upgradable_in_test(id: UpgradeID) -> TestStream { + struct AssumeUpgradedInTest{id: UpgradeID} + impl Future for AssumeUpgradedInTest { + type Output = TestStream; + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + let Some(StreamCell { reserved, stream }) = (unsafe {UpgradeStreamsInTest().get_mut()}).get_mut(self.id.as_usize()) + else {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; + + if stream.is_none() + {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; + + *reserved = false; + + std::task::Poll::Ready(unsafe {stream.take().unwrap_unchecked()}) + } + } + + AssumeUpgradedInTest{id}.await +} static UPGRADE_STREAMS: OnceLock = OnceLock::new(); +#[cfg(test)] static UPGRADE_STREAMS_IN_TEST: OnceLock> = OnceLock::new(); + #[allow(non_snake_case)] fn UpgradeStreams() -> &'static UpgradeStreams { UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) } +#[cfg(test)] #[allow(non_snake_case)] fn UpgradeStreamsInTest() -> &'static UpgradeStreams { + UPGRADE_STREAMS_IN_TEST.get_or_init(UpgradeStreams::::new) +} -struct UpgradeStreams { +struct UpgradeStreams { in_scanning: AtomicBool, - streams: UnsafeCell>, + streams: UnsafeCell>>, } const _: () = { - unsafe impl Sync for UpgradeStreams {} + unsafe impl Sync for UpgradeStreams {} - impl UpgradeStreams { + impl UpgradeStreams { fn new() -> Self { Self { in_scanning: AtomicBool::new(false), streams: UnsafeCell::new(Vec::new()), } } - fn get(&self) -> &Vec { + fn get(&self) -> &Vec> { unsafe {&*self.streams.get()} } - unsafe fn get_mut(&self) -> &mut Vec { + unsafe fn get_mut(&self) -> &mut Vec> { &mut *self.streams.get() } - fn request_reservation(&self) -> Option> { + fn request_reservation(&self) -> Option> { self.in_scanning.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed) .ok().and(Some(ReservationLock(unsafe {self.get_mut()}))) } } - struct ReservationLock<'scan>(&'scan mut Vec); - impl<'scan> Drop for ReservationLock<'scan> { + struct ReservationLock<'scan, Stream = TcpStream>(&'scan mut Vec>); + impl<'scan, Stream> Drop for ReservationLock<'scan, Stream> { fn drop(&mut self) { UpgradeStreams().in_scanning.store(false, Ordering::Release) } } - impl<'scan> std::ops::Deref for ReservationLock<'scan> { - type Target = Vec; + impl<'scan, Stream> std::ops::Deref for ReservationLock<'scan, Stream> { + type Target = Vec>; fn deref(&self) -> &Self::Target {&*self.0} } - impl<'scan> std::ops::DerefMut for ReservationLock<'scan> { + impl<'scan, Stream> std::ops::DerefMut for ReservationLock<'scan, Stream> { fn deref_mut(&mut self) -> &mut Self::Target {self.0} } }; -struct StreamCell { +struct StreamCell { reserved: bool, - - #[cfg(feature="rt_tokio")] stream: Option>>, - #[cfg(feature="rt_async-std")] stream: Option, + stream: Option, } const _: () = { - impl StreamCell { + impl StreamCell { fn new() -> Self { Self { reserved: false, diff --git a/ohkami/src/x_websocket/websocket.rs b/ohkami/src/x_websocket/websocket.rs index d0fc48f1..1ba53154 100644 --- a/ohkami/src/x_websocket/websocket.rs +++ b/ohkami/src/x_websocket/websocket.rs @@ -91,7 +91,7 @@ impl WebSocket { } } -#[cfg(feature="rt_tokio")] const _: () = { +#[cfg(all(not(test), feature="rt_tokio"))] const _: () = { impl WebSocket { pub fn split(&mut self) -> (ReadHalf, WriteHalf) { let (rh, wh) = self.stream.split(); From 0cb6c30665dafb29aa7df2e1a41947c03a6f2b79 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sun, 26 Nov 2023 00:51:46 +0900 Subject: [PATCH 35/37] @2023-11-26 00:51+9:00 --- ohkami/src/layer6_testing/mod.rs | 18 ++-- ohkami/src/layer6_testing/x_websocket.rs | 86 ++++++++++++------- ohkami/src/lib.rs | 6 +- ohkami/src/x_websocket/context.rs | 12 ++- ohkami/src/x_websocket/message.rs | 2 +- ohkami/src/x_websocket/mod.rs | 17 +++- ohkami/src/x_websocket/upgrade.rs | 14 ++-- ohkami/src/x_websocket/websocket.rs | 102 ++++++++++++----------- 8 files changed, 152 insertions(+), 105 deletions(-) diff --git a/ohkami/src/layer6_testing/mod.rs b/ohkami/src/layer6_testing/mod.rs index 66cbb9be..1f45c0d6 100644 --- a/ohkami/src/layer6_testing/mod.rs +++ b/ohkami/src/layer6_testing/mod.rs @@ -1,16 +1,17 @@ mod _test; mod x_websocket; -pub(crate) use x_websocket::{TestWebSocket, TestStream}; +#[cfg(feature="websocket")] +pub(crate) use x_websocket::{TestStream, TestWebSocket}; + +use crate::{Response, Request, Ohkami, Context}; +use crate::layer0_lib::{IntoCows, Status, Method, ContentType}; use std::borrow::Cow; use std::collections::HashMap; use std::{pin::Pin, future::Future, format as f}; use byte_reader::Reader; -use crate::{Response, Request, Ohkami, Context}; -use crate::layer0_lib::{IntoCows, Status, Method, ContentType}; - pub trait Testing { fn oneshot(&self, req: TestRequest) -> Oneshot; @@ -64,7 +65,7 @@ impl Testing for Ohkami { #[cfg(feature="websocket")] fn oneshot_and_upgraded(&self, request: TestRequest) -> OneshotAndUpgraded { - use crate::websocket::{reserve_upgrade_in_test, assume_upgradable_in_test}; + use crate::websocket::{reserve_upgrade_in_test}; let router = { let mut router = self.routes.clone(); @@ -83,11 +84,10 @@ impl Testing for Ohkami { match upgrade_id { None => (TestResponse::new(res), None), Some(id) => { - let (client, server) = TestWebSocket::new_pair(); - unsafe {reserve_upgrade_in_test(id, client.stream)}; - let _ = assume_upgradable_in_test(id).await; + let (client, server) = TestStream::new_pair(); + unsafe {reserve_upgrade_in_test(id, server)}; - __TODO__ + (TestResponse::new(res), Some(TestWebSocket::new(client))) }, } }; diff --git a/ohkami/src/layer6_testing/x_websocket.rs b/ohkami/src/layer6_testing/x_websocket.rs index 51412e92..081150fb 100644 --- a/ohkami/src/layer6_testing/x_websocket.rs +++ b/ohkami/src/layer6_testing/x_websocket.rs @@ -1,37 +1,40 @@ +use crate::x_websocket::{Message}; +use crate::x_websocket::{Config, send, write, flush}; + use std::cell::UnsafeCell; use std::pin::Pin; use std::sync::Arc; +use std::io::{Error}; use std::sync::atomic::{AtomicBool, Ordering}; use std::task::Poll; +/// Web socket client for test with upgrade pub struct TestWebSocket { - pub(crate) stream: TestStream, -} impl TestWebSocket { - pub(crate) fn new_pair() -> (Self, Self) { - let (server_read, server_write) = ( - Arc::new( - HalfStream { - locked: AtomicBool::new(false), - buf: UnsafeCell::new(Vec::new()), - } - ), - Arc::new( - HalfStream { - locked: AtomicBool::new(false), - buf: UnsafeCell::new(Vec::new()), - } - ), - ); - let (client_write, client_read) = ( - server_read.clone(), - server_write.clone(), - ); + stream: TestStream, + n_buffered: usize, +} +impl TestWebSocket { + pub(crate) fn new(client_stream: TestStream) -> Self { + Self { + stream: client_stream, + n_buffered: 0, + } + } +} +impl TestWebSocket { + pub async fn recv(&mut self) -> Result, Error> { + Message::read_from(&mut self.stream, &Config::default()).await + } - ( - Self {stream: TestStream::new(client_read, client_write)}, - Self {stream: TestStream::new(server_read, server_write)} - ) + pub async fn send(&mut self, message: Message) -> Result<(), Error> { + send(message, &mut self.stream, &Config::default(), &mut self.n_buffered).await + } + pub async fn write(&mut self, message: Message) -> Result { + write(message, &mut self.stream, &Config::default(), &mut self.n_buffered).await + } + pub async fn flush(&mut self) -> Result<(), Error> { + flush(&mut self.stream, &mut self.n_buffered).await } } @@ -44,7 +47,6 @@ pub struct TestWebSocket { /// =========================== | : TestStream /// [write ============= read] | /// | | -/// TestWebSocket TestWebSocket /// ``` pub struct TestStream { read: Arc, @@ -54,16 +56,36 @@ pub struct HalfStream { locked: AtomicBool, // It could be more efficient, but now using very simple lock buf: UnsafeCell> } + +impl TestStream { + pub(crate) fn new_pair() -> (Self, Self) { + let (client_read, client_write) = ( + Arc::new(HalfStream { + locked: AtomicBool::new(false), + buf: UnsafeCell::new(Vec::new()), + }), + Arc::new(HalfStream { + locked: AtomicBool::new(false), + buf: UnsafeCell::new(Vec::new()), + }), + ); + + let (server_write, server_read) = ( + client_read.clone(), + client_write.clone(), + ); + + ( + Self { read:client_read, write:client_write }, + Self { read:server_read, write:server_write }, + ) + } +} + const _: () = { unsafe impl Sync for TestStream {} unsafe impl Send for TestStream {} - impl TestStream { - pub(crate) fn new(read: Arc, write: Arc) -> Self { - Self { read, write } - } - } - impl TestStream { fn read_lock(self: Pin<&mut Self>) -> Poll> { match self.read.locked.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed) { diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 3b1b5e98..68b5e17d 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -235,9 +235,9 @@ mod __rt__ { #[cfg(feature="rt_async-std")] pub(crate) use async_std::sync::Mutex; - #[cfg(all(feature="rt_tokio", feature="websocket"))] - pub(crate) use tokio::net::tcp::{ReadHalf, WriteHalf}; - /* async-std doesn't have `split` */ + // #[cfg(all(feature="rt_tokio", feature="websocket"))] + // pub(crate) use tokio::net::tcp::{ReadHalf, WriteHalf}; + // /* async-std doesn't have `split` */ #[cfg(feature="rt_tokio")] pub(crate) use tokio::task; diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index ec9be272..ac11fe2f 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,10 +1,13 @@ use std::{future::Future, borrow::Cow}; use super::websocket::Config; -use super::{WebSocket, sign, assume_upgradable}; +use super::{WebSocket, sign}; use crate::{Response, Context, Request}; use crate::__rt__::{task}; use crate::http::{Method}; +#[cfg(test)] use crate::websocket::assume_upgradable_in_test; +#[cfg(not(test))] use super::assume_upgradable; + pub struct WebSocketContext { c: Context, @@ -118,8 +121,13 @@ impl WebSocketContext { task::spawn({ async move { let stream = match c.upgrade_id { + None => return on_failed_upgrade.handle(UpgradeError::NotRequestedUpgrade), + + #[cfg(not(test))] Some(id) => assume_upgradable(id).await, - None => return on_failed_upgrade.handle(UpgradeError::NotRequestedUpgrade), + + #[cfg(test)] + Some(id) => assume_upgradable_in_test(id).await, }; let ws = WebSocket::new(stream, config); diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs index cbab9e9c..26735575 100644 --- a/ohkami/src/x_websocket/message.rs +++ b/ohkami/src/x_websocket/message.rs @@ -77,7 +77,7 @@ impl Message { } impl Message { - pub(super) async fn read_from( + pub(crate) async fn read_from( stream: &mut (impl AsyncReader + Unpin), config: &Config, ) -> Result, Error> { diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index 465d75fd..ddcb0089 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -13,9 +13,20 @@ pub use { websocket::{WebSocket}, context::{WebSocketContext, UpgradeError}, }; +#[cfg(test)] pub(crate) use websocket::{ + Config, + send, + write, + flush, +}; + pub(crate) use { - upgrade::{UpgradeID, request_upgrade_id, reserve_upgrade, assume_upgradable}, + upgrade::{UpgradeID, request_upgrade_id, reserve_upgrade}, +}; +#[cfg(not(test))] pub(crate) use upgrade::{ + assume_upgradable, }; -#[cfg(test)] pub(crate) use { - upgrade::{reserve_upgrade_in_test, assume_upgradable_in_test}, +#[cfg(test)] pub(crate) use upgrade::{ + reserve_upgrade_in_test, + assume_upgradable_in_test, }; diff --git a/ohkami/src/x_websocket/upgrade.rs b/ohkami/src/x_websocket/upgrade.rs index 316ce3f5..94f128b1 100644 --- a/ohkami/src/x_websocket/upgrade.rs +++ b/ohkami/src/x_websocket/upgrade.rs @@ -49,9 +49,9 @@ pub unsafe fn reserve_upgrade(id: UpgradeID, stream: TcpStream) { (UpgradeStreamsInTest().get_mut())[id.as_usize()].stream = Some(stream); } -pub async fn assume_upgradable(id: UpgradeID) -> TcpStream { - struct AssumeUpgraded{id: UpgradeID} - impl Future for AssumeUpgraded { +#[cfg(not(test))] pub async fn assume_upgradable(id: UpgradeID) -> TcpStream { + struct AssumeUpgradable{id: UpgradeID} + impl Future for AssumeUpgradable { type Output = TcpStream; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { let Some(StreamCell { reserved, stream }) = (unsafe {UpgradeStreams().get_mut()}).get_mut(self.id.as_usize()) @@ -66,11 +66,11 @@ pub async fn assume_upgradable(id: UpgradeID) -> TcpStream { } } - AssumeUpgraded{id}.await + AssumeUpgradable{id}.await } #[cfg(test)] pub async fn assume_upgradable_in_test(id: UpgradeID) -> TestStream { - struct AssumeUpgradedInTest{id: UpgradeID} - impl Future for AssumeUpgradedInTest { + struct AssumeUpgradableInTest{id: UpgradeID} + impl Future for AssumeUpgradableInTest { type Output = TestStream; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { let Some(StreamCell { reserved, stream }) = (unsafe {UpgradeStreamsInTest().get_mut()}).get_mut(self.id.as_usize()) @@ -85,7 +85,7 @@ pub async fn assume_upgradable(id: UpgradeID) -> TcpStream { } } - AssumeUpgradedInTest{id}.await + AssumeUpgradableInTest{id}.await } diff --git a/ohkami/src/x_websocket/websocket.rs b/ohkami/src/x_websocket/websocket.rs index 1ba53154..39366407 100644 --- a/ohkami/src/x_websocket/websocket.rs +++ b/ohkami/src/x_websocket/websocket.rs @@ -1,14 +1,19 @@ use std::io::Error; use super::{Message}; -use crate::__rt__::{TcpStream, AsyncWriter}; +use crate::__rt__::{AsyncWriter}; +#[cfg(test)] use crate::layer6_testing::TestStream as Stream; +#[cfg(not(test))] use crate::__rt__::TcpStream as Stream; + +/// In current version, `split` to read / write halves is not supported pub struct WebSocket { - stream: TcpStream, + stream: Stream, config: Config, n_buffered: usize, } + // :fields may set through `WebSocketContext`'s methods pub struct Config { pub(crate) write_buffer_size: usize, @@ -31,7 +36,7 @@ pub struct Config { }; impl WebSocket { - pub(crate) fn new(stream: TcpStream, config: Config) -> Self { + pub(crate) fn new(stream: Stream, config: Config) -> Self { Self { stream, config, n_buffered:0 } } } @@ -43,7 +48,7 @@ impl WebSocket { } // ============================================================================= -async fn send(message:Message, +pub(crate) async fn send(message:Message, stream: &mut (impl AsyncWriter + Unpin), config: &Config, n_buffered: &mut usize, @@ -52,7 +57,7 @@ async fn send(message:Message, flush(stream, n_buffered).await?; Ok(()) } -async fn write(message:Message, +pub(crate) async fn write(message:Message, stream: &mut (impl AsyncWriter + Unpin), config: &Config, n_buffered: &mut usize, @@ -70,7 +75,7 @@ async fn write(message:Message, Ok(n) } -async fn flush( +pub(crate) async fn flush( stream: &mut (impl AsyncWriter + Unpin), n_buffered: &mut usize, ) -> Result<(), Error> { @@ -91,45 +96,46 @@ impl WebSocket { } } -#[cfg(all(not(test), feature="rt_tokio"))] const _: () = { - impl WebSocket { - pub fn split(&mut self) -> (ReadHalf, WriteHalf) { - let (rh, wh) = self.stream.split(); - let config = &self.config; - let n_buffered = self.n_buffered; - (ReadHalf {config, stream:rh}, WriteHalf {config, n_buffered, stream:wh}) - } - } - - use crate::__rt__::{ - ReadHalf as TcpReadHalf, - WriteHalf as TcpWriteHalf, - }; - - pub struct ReadHalf<'ws> { - stream: TcpReadHalf<'ws>, - config: &'ws Config, - } - impl<'ws> ReadHalf<'ws> { - pub async fn recv(&mut self) -> Result, Error> { - Message::read_from(&mut self.stream, &self.config).await - } - } - - pub struct WriteHalf<'ws> { - stream: TcpWriteHalf<'ws>, - config: &'ws Config, - n_buffered: usize, - } - impl<'ws> WriteHalf<'ws> { - pub async fn send(&mut self, message: Message) -> Result<(), Error> { - send(message, &mut self.stream, &self.config, &mut self.n_buffered).await - } - pub async fn write(&mut self, message: Message) -> Result { - write(message, &mut self.stream, &self.config, &mut self.n_buffered).await - } - pub async fn flush(&mut self) -> Result<(), Error> { - flush(&mut self.stream, &mut self.n_buffered).await - } - } -}; +// #[cfg(feature="rt_tokio")] const _: () = { +// impl WebSocket { +// pub fn split(&mut self) -> (ReadHalf<'_, Stream>, WriteHalf<'_, Stream>) { +// let (rh, wh) = self.stream.split(); +// let config = &self.config; +// let n_buffered = self.n_buffered; +// (ReadHalf {config, stream:rh}, WriteHalf {config, n_buffered, stream:wh}) +// } +// } +// +// use crate::__rt__::{ +// ReadHalf as TcpReadHalf, +// WriteHalf as TcpWriteHalf, +// }; +// +// pub struct ReadHalf<'ws, Stream: AsyncReader + AsyncWriter + Unpin> { +// stream: &'ws Stream, +// config: &'ws Config, +// } +// impl<'ws, Stream: AsyncReader + AsyncWriter + Unpin> ReadHalf<'ws, Stream> { +// pub async fn recv(&mut self) -> Result, Error> { +// Message::read_from(&mut self.stream, &self.config).await +// } +// } +// +// pub struct WriteHalf<'ws, Stream: AsyncReader + AsyncWriter + Unpin> { +// stream: &'ws Stream, +// config: &'ws Config, +// n_buffered: usize, +// } +// impl<'ws, Stream: AsyncReader + AsyncWriter + Unpin> WriteHalf<'ws, Stream> { +// pub async fn send(&mut self, message: Message) -> Result<(), Error> { +// send(message, &mut self.stream, &self.config, &mut self.n_buffered).await +// } +// pub async fn write(&mut self, message: Message) -> Result { +// write(message, &mut self.stream, &self.config, &mut self.n_buffered).await +// } +// pub async fn flush(&mut self) -> Result<(), Error> { +// flush(&mut self.stream, &mut self.n_buffered).await +// } +// } +// }; +// \ No newline at end of file From 0389c8608ec1e937888ce6b6bb7d2c124ac7e03f Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sun, 26 Nov 2023 00:54:11 +0900 Subject: [PATCH 36/37] @2023-11-26 00:54+9:00 --- ohkami/src/layer6_testing/x_websocket.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ohkami/src/layer6_testing/x_websocket.rs b/ohkami/src/layer6_testing/x_websocket.rs index 081150fb..94a9dfca 100644 --- a/ohkami/src/layer6_testing/x_websocket.rs +++ b/ohkami/src/layer6_testing/x_websocket.rs @@ -10,6 +10,8 @@ use std::task::Poll; /// Web socket client for test with upgrade +/// +/// In current version, `split` to read / write halves is not supported. pub struct TestWebSocket { stream: TestStream, n_buffered: usize, From 3e4ab49d82b67d64363b279fea2f8940c92bd062 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sun, 26 Nov 2023 16:01:05 +0900 Subject: [PATCH 37/37] impl mask --- ohkami/src/layer6_testing/x_websocket.rs | 35 +++++++-- ohkami/src/x_websocket/frame.rs | 90 ++++++++++++++++++++++-- ohkami/src/x_websocket/message.rs | 68 ++++++++++-------- ohkami/src/x_websocket/mod.rs | 3 - ohkami/src/x_websocket/websocket.rs | 8 ++- 5 files changed, 157 insertions(+), 47 deletions(-) diff --git a/ohkami/src/layer6_testing/x_websocket.rs b/ohkami/src/layer6_testing/x_websocket.rs index 94a9dfca..f5fe5c3c 100644 --- a/ohkami/src/layer6_testing/x_websocket.rs +++ b/ohkami/src/layer6_testing/x_websocket.rs @@ -1,5 +1,7 @@ +use tokio::io::AsyncWriteExt; + use crate::x_websocket::{Message}; -use crate::x_websocket::{Config, send, write, flush}; +use crate::x_websocket::{Config}; use std::cell::UnsafeCell; use std::pin::Pin; @@ -26,17 +28,40 @@ impl TestWebSocket { } impl TestWebSocket { pub async fn recv(&mut self) -> Result, Error> { - Message::read_from(&mut self.stream, &Config::default()).await + // ========================= // + let config = Config::default(); + // ========================= // + + Message::read_from(&mut self.stream, &config).await } pub async fn send(&mut self, message: Message) -> Result<(), Error> { - send(message, &mut self.stream, &Config::default(), &mut self.n_buffered).await + self.write(message).await?; + self.flush().await?; + Ok(()) } pub async fn write(&mut self, message: Message) -> Result { - write(message, &mut self.stream, &Config::default(), &mut self.n_buffered).await + // ========================= // + let config = Config::default(); + let mask = [12, 34, 56, 78]; + // ========================= // + + let n = message.masking_write(&mut self.stream, &config, mask).await?; + + self.n_buffered += n; + if self.n_buffered > config.write_buffer_size { + if self.n_buffered > config.max_write_buffer_size { + panic!("Buffered messages is larger than `max_write_buffer_size`"); + } else { + self.flush().await?; + } + } + + Ok(n) } pub async fn flush(&mut self) -> Result<(), Error> { - flush(&mut self.stream, &mut self.n_buffered).await + self.stream.flush().await + .map(|_| self.n_buffered = 0) } } diff --git a/ohkami/src/x_websocket/frame.rs b/ohkami/src/x_websocket/frame.rs index 6e024a67..33fb224b 100644 --- a/ohkami/src/x_websocket/frame.rs +++ b/ohkami/src/x_websocket/frame.rs @@ -66,7 +66,7 @@ pub struct Frame { pub mask: Option<[u8; 4]>, pub payload: Vec, } impl Frame { - pub(super) async fn read_from( + pub(crate) async fn read_from( stream: &mut (impl AsyncReader + Unpin), config: &Config, ) -> Result, Error> { @@ -126,18 +126,31 @@ pub struct Frame { let payload = { let mut payload = Vec::with_capacity(payload_len); stream.read_exact(&mut payload).await?; + + if let Some(masking_bytes) = mask { + let mut i = 0; + for b in &mut payload { + *b = *b ^ masking_bytes[i]; + + /* + i = if i == 3 {0} else {i + 1}; + */ + i = (i + 1) & 0b00000011; + } + } + payload }; Ok(Some(Self { is_final, opcode, mask, payload })) } - pub(super) async fn write_to(self, + pub(crate) async fn write_unmasked(self, stream: &mut (impl AsyncWriter + Unpin), _config: &Config, ) -> Result { fn into_bytes(frame: Frame) -> Vec { - let Frame { is_final, opcode, mask, payload } = frame; + let Frame { is_final, opcode, payload, mask:_ } = frame; let (payload_len_byte, payload_len_bytes) = match payload.len() { ..=125 => (payload.len() as u8, None), @@ -145,15 +158,78 @@ pub struct Frame { _ => (127, Some((|| (payload.len() as u64).to_be_bytes().to_vec())())), }; - let first = is_final.then_some(1).unwrap_or(0) << 7 + opcode.into_byte(); - let second = mask.is_some().then_some(1).unwrap_or(0) << 7 + payload_len_byte; + let first = (is_final as u8) << 7 + opcode.into_byte(); + let second = 0 << 7 + payload_len_byte; let mut header_bytes = vec![first, second]; if let Some(mut payload_len_bytes) = payload_len_bytes { header_bytes.append(&mut payload_len_bytes) } - if let Some(mask_bytes) = mask { - header_bytes.extend(mask_bytes) + + [header_bytes, payload].concat() + } + + stream.write(&into_bytes(self)).await + } + + #[cfg(test) /* used in `Message::masking_write` */ ] + pub(crate) async fn write_masked(self, + stream: &mut (impl AsyncWriter + Unpin), + _config: &Config, + ) -> Result { + fn into_bytes(frame: Frame) -> Vec { + let Frame { is_final, opcode, mask, mut payload } = frame; + + let (payload_len_byte, payload_len_bytes) = match payload.len() { + ..=125 => (payload.len() as u8, None), + 126..=65535 => (126, Some((|| (payload.len() as u16).to_be_bytes().to_vec())())), + _ => (127, Some((|| (payload.len() as u64).to_be_bytes().to_vec())())), + }; + + let first = (is_final as u8) << 7 + opcode.into_byte(); + let second = (mask.is_some() as u8) << 7 + payload_len_byte; + + let mut header_bytes = vec![first, second]; + if let Some(mut payload_len_bytes) = payload_len_bytes { + header_bytes.append(&mut payload_len_bytes) + } + if let Some(masking_bytes) = mask { + header_bytes.extend_from_slice(&masking_bytes); + + // mask the payload + let mut i = 0; + for b in &mut payload { + /* + ``` + a = b xor c + ----------------- + 0 0 0 + 1 0 1 + 1 1 0 + 0 1 1 + + if + a = b xor c + then + b = a xor c + + ``` + + When client or server get masked payload, they perform: + + for each {decoded byte} + = {payload byte} xor {masking key byte} + + Here {decoded byte} is `b` and {maksing key byte} is `masking_bytes[i]`, + So {payload byte} (that WAS sent as *masked byte*) is computed as follows. + */ + *b = *b ^ masking_bytes[i]; + + /* + i = if i == 3 {0} else {i + 1}; + */ + i = (i + 1) & 0b00000011; + } } [header_bytes, payload].concat() diff --git a/ohkami/src/x_websocket/message.rs b/ohkami/src/x_websocket/message.rs index 26735575..2183a1a7 100644 --- a/ohkami/src/x_websocket/message.rs +++ b/ohkami/src/x_websocket/message.rs @@ -4,6 +4,7 @@ use super::{frame::{Frame, OpCode, CloseCode}, websocket::Config}; const PING_PONG_PAYLOAD_LIMIT: usize = 125; + pub enum Message { Text (String), Binary(Vec), @@ -40,39 +41,48 @@ const _: (/* `From` impls */) = { }; impl Message { - pub(super) async fn write(self, + pub(crate) fn into_frame(self) -> Frame { + let (opcode, payload) = match self { + Message::Text (text) => (OpCode::Text, text.into_bytes()), + Message::Binary(bytes) => (OpCode::Binary, bytes), + Message::Ping(mut bytes) => { + bytes.truncate(PING_PONG_PAYLOAD_LIMIT); + (OpCode::Ping, bytes) + } + Message::Pong(mut bytes) => { + bytes.truncate(PING_PONG_PAYLOAD_LIMIT); + (OpCode::Ping, bytes) + } + Message::Close(close_frame) => { + let payload = close_frame + .map(|CloseFrame { code, reason }| { + let code = code.into_bytes(); + let reason = reason.as_ref().map(|cow| cow.as_bytes()).unwrap_or(&[]); + [&code, reason].concat() + }).unwrap_or(Vec::new()); + (OpCode::Close, payload) + } + }; + + Frame { is_final: false, mask: None, opcode, payload } + } + + pub(crate) async fn write(self, stream: &mut (impl AsyncWriter + Unpin), config: &Config, ) -> Result { - fn into_frame(message: Message) -> Frame { - let (opcode, payload) = match message { - Message::Text (text) => (OpCode::Text, text.into_bytes()), - Message::Binary(bytes) => (OpCode::Binary, bytes), - - Message::Ping(mut bytes) => { - bytes.truncate(PING_PONG_PAYLOAD_LIMIT); - (OpCode::Ping, bytes) - } - Message::Pong(mut bytes) => { - bytes.truncate(PING_PONG_PAYLOAD_LIMIT); - (OpCode::Ping, bytes) - } - - Message::Close(close_frame) => { - let payload = close_frame - .map(|CloseFrame { code, reason }| { - let code = code.into_bytes(); - let reason = reason.as_ref().map(|cow| cow.as_bytes()).unwrap_or(&[]); - [&code, reason].concat() - }).unwrap_or(Vec::new()); - (OpCode::Close, payload) - } - }; - - Frame { is_final: false, mask: None, opcode, payload } - } + self.into_frame().write_unmasked(stream, config).await + } - into_frame(self).write_to(stream, config).await + #[cfg(test) /* used in `crate::layer6_testing::x_websokcket::TestWebSocket::write` */ ] + pub(crate) async fn masking_write(self, + stream: &mut (impl AsyncWriter + Unpin), + config: &Config, + mask: [u8; 4], + ) -> Result { + let mut frame = self.into_frame(); + frame.mask = Some(mask); + frame.write_masked(stream, config).await } } diff --git a/ohkami/src/x_websocket/mod.rs b/ohkami/src/x_websocket/mod.rs index ddcb0089..ec874885 100644 --- a/ohkami/src/x_websocket/mod.rs +++ b/ohkami/src/x_websocket/mod.rs @@ -15,9 +15,6 @@ pub use { }; #[cfg(test)] pub(crate) use websocket::{ Config, - send, - write, - flush, }; pub(crate) use { diff --git a/ohkami/src/x_websocket/websocket.rs b/ohkami/src/x_websocket/websocket.rs index 39366407..36e9238c 100644 --- a/ohkami/src/x_websocket/websocket.rs +++ b/ohkami/src/x_websocket/websocket.rs @@ -48,7 +48,8 @@ impl WebSocket { } // ============================================================================= -pub(crate) async fn send(message:Message, +pub(super) async fn send( + message: Message, stream: &mut (impl AsyncWriter + Unpin), config: &Config, n_buffered: &mut usize, @@ -57,7 +58,8 @@ pub(crate) async fn send(message:Message, flush(stream, n_buffered).await?; Ok(()) } -pub(crate) async fn write(message:Message, +pub(super) async fn write( + message: Message, stream: &mut (impl AsyncWriter + Unpin), config: &Config, n_buffered: &mut usize, @@ -75,7 +77,7 @@ pub(crate) async fn write(message:Message, Ok(n) } -pub(crate) async fn flush( +pub(super) async fn flush( stream: &mut (impl AsyncWriter + Unpin), n_buffered: &mut usize, ) -> Result<(), Error> {