Skip to content

Commit

Permalink
introduce mews
Browse files Browse the repository at this point in the history
  • Loading branch information
kanarus committed Oct 29, 2024
1 parent 92e4458 commit 4cdad6e
Show file tree
Hide file tree
Showing 15 changed files with 195 additions and 1,063 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,12 @@ Currently, WebSocket on `rt_worker` is *not* supported.

```rust,no_run
use ohkami::prelude::*;
use ohkami::ws::{WebSocketContext, WebSocket, Message};
use ohkami::ws::{WebSocketContext, WebSocket, Message, Connection};
async fn echo_text(c: WebSocketContext<'_>) -> WebSocket {
c.connect(|mut conn| async move {
async fn echo_text(ctx: WebSocketContext<'_>) -> WebSocket {
ctx.upgrade(|mut conn: Connection| async move {
while let Ok(Some(Message::Text(text))) = conn.recv().await {
conn.send(Message::Text(text)).await.expect("Failed to send text");
conn.send(text).await.expect("failed to send text");
}
})
}
Expand Down
93 changes: 45 additions & 48 deletions examples/websocket/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,14 @@
use ohkami::prelude::*;
use ohkami::ws::{WebSocketContext, WebSocket, Message};


#[derive(Clone)]
struct Logger;
impl FangAction for Logger {
async fn fore<'a>(&'a self, req: &'a mut Request) -> Result<(), Response> {
Ok(println!("\n{req:#?}"))
}

async fn back<'a>(&'a self, res: &'a mut Response) {
println!("\n{res:#?}")
}
}
use ohkami::ws::{WebSocketContext, WebSocket, Message, Connection, ReadHalf, WriteHalf};


async fn echo_text(c: WebSocketContext<'_>) -> WebSocket {
c.connect(|mut ws| async move {
while let Ok(Some(Message::Text(text))) = ws.recv().await {
c.upgrade(|mut c: Connection| async move {
while let Ok(Some(Message::Text(text))) = c.recv().await {
if text == "close" {
break
}
ws.send(Message::Text(text)).await.expect("Failed to send text");
c.send(text).await.expect("Failed to send text");
}
})
}
Expand All @@ -40,15 +27,15 @@ struct EchoTextSession<'ws> {
}
impl IntoResponse for EchoTextSession<'_> {
fn into_response(self) -> Response {
self.ctx.connect(|mut ws| async move {
ws.send(Message::Text(format!("Hello, {}!", self.name))).await.expect("failed to send");
self.ctx.upgrade(|mut c: Connection| async move {
c.send(format!("Hello, {}!", self.name)).await.expect("failed to send");

while let Ok(Some(Message::Text(text))) = ws.recv().await {
if text == "close" {
break
}
ws.send(Message::Text(text)).await.expect("failed to send text");
while let Ok(Some(Message::Text(text))) = c.recv().await {
if text == "close" {
break
}
c.send(text).await.expect("failed to send text");
}
}).into_response()
}
}
Expand All @@ -57,8 +44,7 @@ impl IntoResponse for EchoTextSession<'_> {
async fn echo_text_3(name: String,
ctx: WebSocketContext<'_>
) -> WebSocket {
ctx.connect(|ws| async {
let (mut r, mut w) = ws.split();
ctx.upgrade(|mut r: ReadHalf, mut w: WriteHalf| async {
let incoming = std::sync::Arc::new(tokio::sync::RwLock::new(std::collections::VecDeque::new()));
let (close_tx, close_rx) = tokio::sync::watch::channel(());

Expand All @@ -74,7 +60,7 @@ async fn echo_text_3(name: String,
}
}
}),
tokio::task::spawn({
tokio::spawn({
let (mut close, incoming) = (close_rx.clone(), incoming.clone());
async move {
loop {
Expand All @@ -93,15 +79,15 @@ async fn echo_text_3(name: String,
}
}
}),
tokio::task::spawn({
tokio::spawn({
let (name, close, closer, incoming) = (name, close_rx.clone(), close_tx, incoming.clone());
async move {
w.send(Message::Text(format!("Hello, {name}!"))).await.expect("failed to send");

loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;

w.send(Message::Text(format!("tick"))).await.expect("failed to send");
w.send("tick").await.expect("failed to send");

let poped = {
let mut incoming = incoming.write().await;
Expand All @@ -111,7 +97,7 @@ async fn echo_text_3(name: String,
if let Some(text) = poped {
if text == "close" {closer.send(()).unwrap()}

w.send(Message::Text(text)).await.expect("failed to send");
w.send(text).await.expect("failed to send");
}

if !close.has_changed().is_ok_and(|yes|!yes) {println!("break 3"); break}
Expand All @@ -123,31 +109,42 @@ async fn echo_text_3(name: String,
}


#[tokio::main]
async fn main() {
Ohkami::with(Logger, (
"/".Dir("./template").omit_extensions([".html"]),
"/echo1".GET(echo_text),
"/echo2/:name".GET(echo_text_2),
"/echo3/:name".GET(echo_text_3),
"/echo4/:name".GET(echo4),
)).howl("localhost:3030").await
}


async fn echo4((name,): (String,), ws: WebSocketContext<'_>) -> WebSocket {
ws.connect(|mut c| async {
/* spawn but not await handle */
tokio::task::spawn(async move {
async fn echo4(name: String, ws: WebSocketContext<'_>) -> WebSocket {
ws.upgrade(|mut c: Connection| async {
/* spawn but not join the handle */
tokio::spawn(async move {
#[cfg(feature="DEBUG")] println!("\n{c:#?}");

c.send(Message::Text(name)).await.expect("failed to send");
c.send(name).await.expect("failed to send");
while let Ok(Some(Message::Text(text))) = c.recv().await {
#[cfg(feature="DEBUG")] println!("\n{c:#?}");

if dbg!(&text) == "close" {break}
c.send(Message::Text(text)).await.expect("failed to send");
c.send(text).await.expect("failed to send");
}
});
})
}


#[tokio::main]
async fn main() {
#[derive(Clone)]
struct Logger;
impl FangAction for Logger {
async fn fore<'a>(&'a self, req: &'a mut Request) -> Result<(), Response> {
Ok(println!("\n{req:#?}"))
}
async fn back<'a>(&'a self, res: &'a mut Response) {
println!("\n{res:#?}")
}
}

Ohkami::with(Logger, (
"/".Dir("./template").omit_extensions([".html"]),
"/echo1".GET(echo_text),
"/echo2/:name".GET(echo_text_2),
"/echo3/:name".GET(echo_text_3),
"/echo4/:name".GET(echo4),
)).howl("localhost:3030").await
}
13 changes: 7 additions & 6 deletions ohkami/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,28 @@ serde = { workspace = true }
serde_json = { version = "1.0" }
rustc-hash = { version = "2.0" }

base64 = { version = "0.22" }
hmac = { version = "0.12", default-features = false }
sha2 = { version = "0.10", default-features = false }
sha1 = { version = "0.10", optional = true, default-features = false }

num_cpus = { version = "1.16", optional = true }
futures-util = { version = "0.3", optional = true, default-features = false, features = ["io", "async-await-macro"] }
mews = { git = "https://github.com/ohkami-rs/mews", optional = true }


[features]
default = ["testing"]

rt_tokio = ["__rt__", "__rt_native__", "dep:tokio", "tokio/io-util", "tokio/macros", "ohkami_lib/signal"]
rt_async-std = ["__rt__", "__rt_native__", "dep:async-std", "dep:futures-util", "ohkami_lib/signal"]
rt_smol = ["__rt__", "__rt_native__", "dep:smol", "dep:futures-util", "ohkami_lib/signal"]
rt_glommio = ["__rt__", "__rt_native__", "dep:glommio", "dep:futures-util", "dep:num_cpus", "ohkami_lib/signal"]
rt_tokio = ["__rt__", "__rt_native__", "dep:tokio", "tokio/io-util", "tokio/macros", "ohkami_lib/signal", "mews?/tokio"]
rt_async-std = ["__rt__", "__rt_native__", "dep:async-std", "dep:futures-util", "ohkami_lib/signal", "mews?/async-std"]
rt_smol = ["__rt__", "__rt_native__", "dep:smol", "dep:futures-util", "ohkami_lib/signal", "mews?/smol"]
rt_glommio = ["__rt__", "__rt_native__", "dep:glommio", "dep:futures-util", "dep:num_cpus", "ohkami_lib/signal", "mews?/glommio"]
rt_worker = ["__rt__", "dep:worker", "ohkami_macros/worker"]

nightly = []
testing = []
sse = ["ohkami_lib/stream"]
ws = ["dep:sha1"]
ws = ["dep:mews"]

##### internal #####
__rt__ = []
Expand Down
3 changes: 2 additions & 1 deletion ohkami/src/fang/builtin/basicauth.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::prelude::*;
use ::base64::engine::{Engine as _, general_purpose::STANDARD as BASE64};


/// # Builtin fang for Basic Auth
Expand Down Expand Up @@ -74,7 +75,7 @@ const _: () = {
.strip_prefix("Basic ").ok_or_else(unauthorized)?;

let credential = String::from_utf8(
ohkami_lib::base64::decode(credential_base64.as_bytes())
BASE64.decode(credential_base64).map_err(|_| unauthorized())?
).map_err(|_| unauthorized())?;

Ok(credential)
Expand Down
38 changes: 22 additions & 16 deletions ohkami/src/fang/builtin/jwt.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#![allow(non_snake_case, non_camel_case_types)]

use crate::{Fang, FangProc, IntoResponse, Request, Response};
use std::{borrow::Cow, marker::PhantomData};
use serde::{Serialize, Deserialize};
use ohkami_lib::base64;
use crate::{Fang, FangProc, IntoResponse, Request, Response};
use base64::engine::{Engine as _, general_purpose::URL_SAFE as BASE64URL};


/// # Builtin fang and helper for JWT config
Expand Down Expand Up @@ -236,9 +236,9 @@ impl<Payload: Serialize> JWT<Payload> {
/// Build JWT token with the payload.
#[inline] pub fn issue(self, payload: Payload) -> JWTToken {
let unsigned_token = {
let mut ut = base64::encode_url(self.header_str());
let mut ut = BASE64URL.encode(self.header_str());
ut.push('.');
ut.push_str(&base64::encode_url(::serde_json::to_vec(&payload).expect("Failed to serialze payload")));
ut.push_str(&BASE64URL.encode(::serde_json::to_vec(&payload).expect("Failed to serialze payload")));
ut
};

Expand All @@ -247,17 +247,17 @@ impl<Payload: Serialize> JWT<Payload> {
use ::hmac::{Hmac, Mac};

match &self.alg {
VerifyingAlgorithm::HS256 => base64::encode_url({
VerifyingAlgorithm::HS256 => BASE64URL.encode({
let mut s = Hmac::<Sha256>::new_from_slice(self.secret.as_bytes()).unwrap();
s.update(unsigned_token.as_bytes());
s.finalize().into_bytes()
}),
VerifyingAlgorithm::HS384 => base64::encode_url({
VerifyingAlgorithm::HS384 => BASE64URL.encode({
let mut s = Hmac::<Sha384>::new_from_slice(self.secret.as_bytes()).unwrap();
s.update(unsigned_token.as_bytes());
s.finalize().into_bytes()
}),
VerifyingAlgorithm::HS512 => base64::encode_url({
VerifyingAlgorithm::HS512 => BASE64URL.encode({
let mut s = Hmac::<Sha512>::new_from_slice(self.secret.as_bytes()).unwrap();
s.update(unsigned_token.as_bytes());
s.finalize().into_bytes()
Expand Down Expand Up @@ -289,17 +289,22 @@ impl<Payload: for<'de> Deserialize<'de>> JWT<Payload> {

const UNAUTHORIZED_MESSAGE: &str = "missing or malformed jwt";

type Header = ::serde_json::Value;
type Payload = ::serde_json::Value;

let mut parts = (self.get_token)(req)
.ok_or_else(|| Response::Unauthorized().with_text(UNAUTHORIZED_MESSAGE))?
.split('.');

type Header = ::serde_json::Value;
type Payload = ::serde_json::Value;
fn part_value(part: &str) -> Result<::serde_json::Value, Response> {
let part = BASE64URL.decode(part)
.map_err(|_| Response::BadRequest().with_text("invalid base64"))?;
::serde_json::from_slice(&part)
.map_err(|_| Response::BadRequest().with_text("invalid json"))
}

let header_part = parts.next()
.ok_or_else(|| Response::BadRequest())?;
let header: Header = ::serde_json::from_slice(&base64::decode_url(header_part))
.map_err(|_| Response::InternalServerError())?;
let header: Header = part_value(header_part)?;
if header.get("typ").is_some_and(|typ| !typ.as_str().unwrap_or_default().eq_ignore_ascii_case("JWT")) {
return Err(Response::BadRequest())
}
Expand All @@ -312,8 +317,7 @@ impl<Payload: for<'de> Deserialize<'de>> JWT<Payload> {

let payload_part = parts.next()
.ok_or_else(|| Response::BadRequest())?;
let payload: Payload = ::serde_json::from_slice(&base64::decode_url(payload_part))
.map_err(|_| Response::InternalServerError())?;
let payload: Payload = part_value(payload_part)?;
let now = crate::util::unix_timestamp();
if payload.get("nbf").is_some_and(|nbf| nbf.as_u64().unwrap_or(0) > now) {
return Err(Response::Unauthorized().with_text(UNAUTHORIZED_MESSAGE))
Expand All @@ -325,8 +329,10 @@ impl<Payload: for<'de> Deserialize<'de>> JWT<Payload> {
return Err(Response::Unauthorized().with_text(UNAUTHORIZED_MESSAGE))
}

let signature_part = parts.next().ok_or_else(|| Response::BadRequest())?;
let requested_signature = base64::decode_url(signature_part);
let signature_part = parts.next()
.ok_or_else(|| Response::BadRequest())?;
let requested_signature = BASE64URL.decode(signature_part)
.map_err(|_| Response::BadRequest().with_text("invalid base64"))?;

let is_correct_signature = {
use ::sha2::{Sha256, Sha384, Sha512};
Expand Down
7 changes: 4 additions & 3 deletions ohkami/src/response/content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use ohkami_lib::CowSlice;
use ohkami_lib::Stream;

#[cfg(all(feature="ws", feature="__rt_native__"))]
use crate::ws::{Config, Handler};
use mews::WebSocket;


pub enum Content {
Expand All @@ -16,8 +16,9 @@ pub enum Content {
Stream(std::pin::Pin<Box<dyn Stream<Item = Result<String, String>> + Send>>),

#[cfg(all(feature="ws", feature="__rt_native__"))]
WebSocket((Config, Handler)),
} const _: () = {
WebSocket(WebSocket),
}
const _: () = {
impl Default for Content {
fn default() -> Self {
Self::None
Expand Down
Loading

0 comments on commit 4cdad6e

Please sign in to comment.