-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(websocket): Add jwt module to ws connection
- Loading branch information
1 parent
305433b
commit 731d5b5
Showing
10 changed files
with
950 additions
and
106 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
use reqwest::StatusCode; | ||
use thiserror::Error; | ||
|
||
#[derive(Error, Debug)] | ||
pub enum AuthError { | ||
#[error("Invalid token header: {0}")] | ||
InvalidTokenHeader(String), | ||
|
||
#[error("No `kid` found in token header")] | ||
NoKidInTokenHeader, | ||
|
||
#[error("No matching JWK found for kid: {0}")] | ||
NoMatchingJwk(String), | ||
|
||
#[error("Invalid JWK key: {0}")] | ||
InvalidJwkKey(String), | ||
|
||
#[error("Token validation failed: {0}")] | ||
TokenValidationFailed(String), | ||
|
||
#[error("Failed to fetch JWKS: {0}")] | ||
FailedToFetchJwks(String), | ||
|
||
#[error("HTTP request failed with status: {0}")] | ||
HttpRequestFailed(StatusCode), | ||
|
||
#[error("Unexpected error: {0}")] | ||
Unexpected(String), | ||
} | ||
|
||
pub type Result<T> = std::result::Result<T, AuthError>; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
use super::error::{AuthError, Result}; | ||
use super::types::{Claims, Jwk, Jwks}; | ||
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, TokenData, Validation}; | ||
use reqwest::Client; | ||
use std::collections::HashMap; | ||
use std::sync::Arc; | ||
use tokio::sync::RwLock; | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct JWTProvider { | ||
pub iss: String, | ||
pub jwks_uri: String, | ||
pub aud: Vec<String>, | ||
pub alg: Algorithm, | ||
pub ttl: i64, | ||
} | ||
|
||
impl JWTProvider { | ||
pub fn new(iss: String, jwks_uri: String, aud: Vec<String>, alg: Algorithm, ttl: i64) -> Self { | ||
Self { | ||
iss, | ||
jwks_uri, | ||
aud, | ||
alg, | ||
ttl, | ||
} | ||
} | ||
async fn fetch_jwks(&self) -> Result<HashMap<String, Jwk>> { | ||
let response = Client::new() | ||
.get(&self.jwks_uri) | ||
.send() | ||
.await | ||
.map_err(|e| AuthError::FailedToFetchJwks(e.to_string()))?; | ||
|
||
if !response.status().is_success() { | ||
return Err(AuthError::HttpRequestFailed(response.status())); | ||
} | ||
|
||
let jwks: Jwks = response | ||
.json() | ||
.await | ||
.map_err(|e| AuthError::FailedToFetchJwks(e.to_string()))?; | ||
|
||
Ok(jwks | ||
.keys | ||
.into_iter() | ||
.map(|key| (key.kid.clone(), key)) | ||
.collect()) | ||
} | ||
|
||
async fn validate_jwt( | ||
&self, | ||
token: &str, | ||
jwks: Arc<RwLock<HashMap<String, Jwk>>>, | ||
) -> Result<Claims> { | ||
// get token header | ||
let header = | ||
decode_header(token).map_err(|e| AuthError::InvalidTokenHeader(e.to_string()))?; | ||
let kid = header.kid.ok_or(AuthError::NoKidInTokenHeader)?; | ||
|
||
// use kid to get jwk | ||
let jwks_read = jwks.read().await; | ||
let jwk = jwks_read | ||
.get(&kid) | ||
.ok_or_else(|| AuthError::NoMatchingJwk(kid.clone()))?; | ||
|
||
// generate decoding key | ||
let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e) | ||
.map_err(|e| AuthError::InvalidJwkKey(e.to_string()))?; | ||
|
||
// validate token | ||
let mut validation = Validation::new(self.alg); | ||
validation.set_audience(&self.aud); | ||
validation.set_issuer(&[&self.iss]); | ||
|
||
let token_data: TokenData<Claims> = decode::<Claims>(token, &decoding_key, &validation) | ||
.map_err(|err| AuthError::TokenValidationFailed(err.to_string()))?; | ||
|
||
Ok(token_data.claims) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::super::types::CustomClaims; | ||
use super::*; | ||
use base64::{encode_config, URL_SAFE_NO_PAD}; | ||
use chrono::{Duration, Utc}; | ||
use jsonwebtoken::{encode, EncodingKey, Header}; | ||
use rand::rngs::OsRng; | ||
use rsa::traits::PublicKeyParts; | ||
use rsa::{pkcs1::LineEnding, pkcs8::EncodePrivateKey, RsaPrivateKey, RsaPublicKey}; | ||
use serde::{Deserialize, Serialize}; | ||
use wiremock::matchers::{method, path}; | ||
use wiremock::{Mock, MockServer, ResponseTemplate}; | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct TestClaims { | ||
pub sub: String, | ||
pub exp: usize, | ||
pub iss: String, | ||
pub aud: Vec<String>, | ||
pub iat: usize, | ||
pub custom_claims: CustomClaims, | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_fetch_jwks() { | ||
let mock_server = MockServer::start().await; | ||
|
||
let jwks_response = r#" | ||
{ | ||
"keys": [ | ||
{ | ||
"kty": "RSA", | ||
"kid": "m3eP7rPCsf3hgtWU1seMZ", | ||
"use": "sig", | ||
"n": "xokzqjjolOLqzFYs7bNVyTafNmfWT-9_i8dgV4OcnBaKP8Es5O5u4dHxx6pZoZYxUyZS0-IOc5w3Em3g2PYUS4TmwDD65nrNZEJwr-CbxwbUwEGkYtTNcQML1LHhkRnGvxdwl_3st5HFBPRHnU5y8hMgke_nMTqrUTLaS7Px7v5lpt_nNH60FnfPJqcgD6pG4TG510HhnLV0ELDCqD8F79omuuqwqgntQLr-XR7mw_2PfMV8QdMx-kcwtVVhBeM5hr-KdKAQ-56MbU5GAke7kZJt94_2DHv8wpmQtlmKuIOEBFJNoS3prisdXmlmP6qKDSGufRNg3x5wJ-di0IlIeQ", | ||
"e": "AQAB", | ||
"alg": "RS256" | ||
} | ||
] | ||
}"#; | ||
|
||
Mock::given(method("GET")) | ||
.and(path("/.well-known/jwks.json")) | ||
.respond_with(ResponseTemplate::new(200).set_body_string(jwks_response)) | ||
.mount(&mock_server) | ||
.await; | ||
|
||
let provider = JWTProvider { | ||
iss: "https://issuer.com".to_string(), | ||
jwks_uri: format!("{}/.well-known/jwks.json", &mock_server.uri()), | ||
aud: vec!["my_audience".to_string()], | ||
alg: Algorithm::RS256, | ||
ttl: 3600, | ||
}; | ||
|
||
let jwks = provider.fetch_jwks().await.unwrap(); | ||
|
||
assert_eq!(jwks.len(), 1); | ||
assert!(jwks.contains_key("m3eP7rPCsf3hgtWU1seMZ")); | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_validate_jwt() { | ||
let mut rng = OsRng; | ||
let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("Failed to generate a key"); | ||
let public_key = RsaPublicKey::from(&private_key); | ||
|
||
let n_base64 = encode_config(public_key.n().to_bytes_be(), URL_SAFE_NO_PAD); | ||
let e_base64 = encode_config(public_key.e().to_bytes_be(), URL_SAFE_NO_PAD); | ||
|
||
let mock_server = MockServer::start().await; | ||
let jwks_response = format!( | ||
r#" | ||
{{ | ||
"keys": [ | ||
{{ | ||
"kty": "RSA", | ||
"kid": "m3eP7rPCsf3hgtWU1seMZ", | ||
"use": "sig", | ||
"n": "{}", | ||
"e": "{}", | ||
"alg": "RS256" | ||
}} | ||
] | ||
}}"#, | ||
n_base64, e_base64 | ||
); | ||
|
||
Mock::given(method("GET")) | ||
.and(path("/.well-known/jwks.json")) | ||
.respond_with(ResponseTemplate::new(200).set_body_string(jwks_response)) | ||
.mount(&mock_server) | ||
.await; | ||
|
||
let now = Utc::now(); | ||
let my_claims = Claims { | ||
sub: "test_sub".to_string(), | ||
exp: (now + Duration::hours(1)).timestamp() as usize, | ||
iss: "https://issuer.com".to_string(), | ||
aud: vec!["my_audience".to_string()], | ||
iat: now.timestamp() as usize, | ||
custom_claims: CustomClaims { | ||
name: Some("test_user".to_string()), | ||
nickname: Some("test_nickname".to_string()), | ||
email: Some("[email protected]".to_string()), | ||
email_verified: Some(true), | ||
}, | ||
}; | ||
let encoding_key = | ||
EncodingKey::from_rsa_pem(private_key.to_pkcs8_pem(LineEnding::LF).unwrap().as_bytes()) | ||
.expect("Invalid private key"); | ||
let mut header = Header::new(Algorithm::RS256); | ||
header.kid = Some("m3eP7rPCsf3hgtWU1seMZ".to_string()); | ||
let token = encode(&header, &my_claims, &encoding_key).expect("Failed to encode token"); | ||
|
||
let decoding_key = DecodingKey::from_rsa_components(&n_base64, &e_base64) | ||
.expect("Failed to create decoding key"); | ||
let mut validation = Validation::new(Algorithm::RS256); | ||
validation.set_audience(&["my_audience"]); | ||
let token_data: TokenData<TestClaims> = | ||
decode::<TestClaims>(&token, &decoding_key, &validation) | ||
.expect("Failed to decode token"); | ||
|
||
assert_eq!( | ||
token_data.claims.custom_claims.name, | ||
Some("test_user".to_string()) | ||
); | ||
assert_eq!( | ||
token_data.claims.custom_claims.nickname, | ||
Some("test_nickname".to_string()) | ||
); | ||
assert_eq!( | ||
token_data.claims.custom_claims.email, | ||
Some("[email protected]".to_string()) | ||
); | ||
assert_eq!(token_data.claims.custom_claims.email_verified, Some(true)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
pub mod error; | ||
pub mod jwt; | ||
pub mod types; | ||
//todo: add auth module to ws connection |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
use serde::{Deserialize, Serialize}; | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct Claims { | ||
pub sub: String, | ||
pub exp: usize, | ||
pub iss: String, | ||
pub aud: Vec<String>, | ||
pub iat: usize, | ||
pub custom_claims: CustomClaims, | ||
} | ||
|
||
#[derive(Debug, Deserialize, Serialize)] | ||
pub struct CustomClaims { | ||
pub name: Option<String>, | ||
pub nickname: Option<String>, | ||
pub email: Option<String>, | ||
pub email_verified: Option<bool>, | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
pub struct Jwk { | ||
pub kid: String, | ||
pub kty: String, | ||
pub use_: Option<String>, | ||
pub alg: Option<String>, | ||
pub n: String, | ||
pub e: String, | ||
pub x5t: Option<String>, | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
pub struct Jwks { | ||
pub keys: Vec<Jwk>, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
pub mod auth; | ||
pub mod persistence; | ||
pub mod socket; |