diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 0afdaa37..d84d9f7d 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -36,9 +36,9 @@ DEBUG = [ "tokio?/macros", "async-std?/attributes", ] -#default = [ -# "rt_tokio", -# #"rt_async-std", -# "DEBUG", -# "nightly", -#] \ No newline at end of file +default = [ + "rt_tokio", + #"rt_async-std", + "DEBUG", + "nightly", +] \ No newline at end of file diff --git a/ohkami/src/x_utils/jwt.rs b/ohkami/src/x_utils/jwt.rs index deb7e434..6eb43f4b 100644 --- a/ohkami/src/x_utils/jwt.rs +++ b/ohkami/src/x_utils/jwt.rs @@ -1,9 +1,11 @@ -#![allow(non_snake_case)] +#![allow(non_snake_case, non_camel_case_types)] /// --- /// /// ## Fang and generator for JWT. /// +/// **NOTE**:In current version, this only supports `HMAC-SHA256` as the verifying algorithm (and select it by default). +/// ///
/// /// #### Example, a tiny project @@ -20,7 +22,7 @@ /// // Get secret key from somewhere, `.env` file for example /// let secret = "MY_VERY_SECRET_KEY"; /// -/// JWT(secret) +/// JWT(secret) // Using HMAC-SHA256 (by default) /// } /// ``` ///
@@ -71,8 +73,6 @@ /// use crate::config::my_jwt_config; // <-- used as a fang /// /// fn profile_ohkami() -> Ohkami { -/// let my_secret_key = todo!(); -/// /// Ohkami::with(( /// // Verifies JWT in requests' `Authorization` header /// // and early returns error response if it's missing or malformed. @@ -103,65 +103,24 @@ pub use internal::JWT; mod internal { use crate::layer0_lib::{base64, HMAC_SHA256}; - use crate::{IntoFang, Fang, Context, Request}; + use crate::{IntoFang, Fang, Context, Request, Response}; pub struct JWT { secret: String, - alg: VerifyingAlgorithm, } impl JWT { pub fn new(secret: impl Into) -> Self { Self { secret: secret.into(), - alg: VerifyingAlgorithm::default(), } } } - macro_rules! VerifyingAlgorithm { - { $( $alg:ident, )+ @default: $default:ident } => { - enum VerifyingAlgorithm { - $( - $alg, - )* - } - impl Default for VerifyingAlgorithm { - fn default() -> Self { - VerifyingAlgorithm::$default - } - } - impl VerifyingAlgorithm { - const fn as_str(&self) -> &'static str { - match self { - $( - Self::$alg => stringify!($alg), - )* - } - } - } - - impl JWT { - $( - pub fn $alg(mut self) -> Self { - self.alg = VerifyingAlgorithm::$alg; - self - } - )* - } - }; - } VerifyingAlgorithm! { - HS256, - HS384, - HS512, - - @default: HS256 - } - impl JWT { pub fn issue(self, payload: impl ::serde::Serialize) -> String { let unsigned_token = { - let mut ut = base64::encode_url([b"{\"typ\":\"JWT\",\"alg\":\"", self.alg.as_str().as_bytes(), b"\"}"].concat()); + let mut ut = base64::encode_url("{\"typ\":\"JWT\",\"alg\":\"HS256\""); ut.push('.'); ut.push_str(&base64::encode_url(::serde_json::to_vec(&payload).expect("Failed to serialze payload"))); ut @@ -180,64 +139,82 @@ mod internal { } } - impl IntoFang for JWT { - fn into_fang(self) -> Fang { + impl JWT { + pub(crate/* for test */) fn verify(&self, c: &Context, req: &Request) -> Result<(), Response> { const UNAUTHORIZED_MESSAGE: &str = "missing or malformed jwt"; type Header = ::serde_json::Value; type Payload = ::serde_json::Value; - Fang(move |c: &Context, req: &Request| { - let mut parts = req - .headers.Authorization().ok_or_else(|| c.Unauthorized().text(UNAUTHORIZED_MESSAGE))? - .strip_prefix("Bearer ").ok_or_else(|| c.BadRequest())? - .split('.'); - - let header_part = parts.next() - .ok_or_else(|| c.BadRequest())?; - let header: Header = ::serde_json::from_slice(&base64::decode_url(header_part)) - .map_err(|_| c.InternalServerError())?; - if header.get("typ").is_some_and(|typ| typ.as_str().unwrap_or_default().eq_ignore_ascii_case("JWT")) { - return Err(c.BadRequest()) - } - if header.get("cty").is_some_and(|cty| cty.as_str().unwrap_or_default().eq_ignore_ascii_case("JWT")) { - return Err(c.BadRequest()) - } - if header.get("alg").ok_or_else(|| c.BadRequest())? != self.alg.as_str() { - return Err(c.BadRequest()) - } - - let payload_part = parts.next() - .ok_or_else(|| c.BadRequest())?; - let payload: Payload = ::serde_json::from_slice(&base64::decode_url(payload_part)) - .map_err(|_| c.InternalServerError())?; - let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); - if payload.get("nbf").is_some_and(|nbf| nbf.as_u64().unwrap_or_default() > now) { - return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) - } - if payload.get("exp").is_some_and(|exp| exp.as_u64().unwrap_or_default() <= now) { - return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) - } - if payload.get("iat").is_some_and(|iat| iat.as_u64().unwrap_or_default() > now) { - return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) - } - - let signature_part = parts.next() - .ok_or_else(|| c.BadRequest())?; - let requested_signature = base64::decode_url(signature_part); - let actual_signature = { - let mut hs = HMAC_SHA256::new(&self.secret); - hs.write(header_part.as_bytes()); - hs.write(b"."); - hs.write(payload_part.as_bytes()); - hs.sum() - }; - if requested_signature != actual_signature { - return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) - } - - Ok(()) - }) + let mut parts = req + .headers.Authorization().ok_or_else(|| c.Unauthorized().text(UNAUTHORIZED_MESSAGE))? + .strip_prefix("Bearer ").ok_or_else(|| c.BadRequest())? + .split('.'); + + let header_part = parts.next() + .ok_or_else(|| c.BadRequest())?; + let header: Header = ::serde_json::from_slice(&base64::decode_url(header_part)) + .map_err(|_| c.InternalServerError())?; + if header.get("typ").is_some_and(|typ| typ.as_str().unwrap_or_default().eq_ignore_ascii_case("JWT")) { + return Err(c.BadRequest()) + } + if header.get("cty").is_some_and(|cty| cty.as_str().unwrap_or_default().eq_ignore_ascii_case("JWT")) { + return Err(c.BadRequest()) + } + if header.get("alg").ok_or_else(|| c.BadRequest())? != "HS256" { + return Err(c.BadRequest()) + } + + let payload_part = parts.next() + .ok_or_else(|| c.BadRequest())?; + let payload: Payload = ::serde_json::from_slice(&base64::decode_url(payload_part)) + .map_err(|_| c.InternalServerError())?; + let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); + if payload.get("nbf").is_some_and(|nbf| nbf.as_u64().unwrap_or_default() > now) { + return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) + } + if payload.get("exp").is_some_and(|exp| exp.as_u64().unwrap_or_default() <= now) { + return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) + } + if payload.get("iat").is_some_and(|iat| iat.as_u64().unwrap_or_default() > now) { + return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) + } + + let signature_part = parts.next() + .ok_or_else(|| c.BadRequest())?; + let requested_signature = base64::decode_url(signature_part); + let actual_signature = { + let mut hs = HMAC_SHA256::new(&self.secret); + hs.write(header_part.as_bytes()); + hs.write(b"."); + hs.write(payload_part.as_bytes()); + hs.sum() + }; + if requested_signature != actual_signature { + return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) + } + + Ok(()) + } + } + impl IntoFang for JWT { + fn into_fang(self) -> Fang { + Fang(move |c: &Context, req: &Request| self.verify(c, req)) + } + } +} + + + + +#[cfg(test)] mod test { + use super::JWT; + use serde_json::json; + + #[test] fn test_jwt_issue() { + assert_eq! { + JWT("secret").issue(json!({"name":"kanarus","id":42,"iat":1516239022})), + "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJuYW1lIjoia2FuYXJ1cyIsImlkIjo0MiwiaWF0IjoxNTE2MjM5MDIyfQ.1zMqW4iyzBih6lVeUfKf_0mIgnvwSm1bxerypEhbxak" } } }