-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(websocket): Add jwt module to ws connection
- Loading branch information
1 parent
305433b
commit df1d031
Showing
11 changed files
with
1,036 additions
and
106 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
use reqwest::StatusCode; | ||
use thiserror::Error; | ||
|
||
#[derive(Error, Debug)] | ||
pub enum AuthError { | ||
#[error("Invalid token header: {0}")] | ||
InvalidTokenHeader(String), | ||
|
||
#[error("No `kid` found in token header")] | ||
NoKidInTokenHeader, | ||
|
||
#[error("No matching JWK found for kid: {0}")] | ||
NoMatchingJwk(String), | ||
|
||
#[error("Invalid JWK key: {0}")] | ||
InvalidJwkKey(String), | ||
|
||
#[error("Token validation failed: {0}")] | ||
TokenValidationFailed(String), | ||
|
||
#[error("Failed to fetch JWKS: {0}")] | ||
FailedToFetchJwks(String), | ||
|
||
#[error("HTTP request failed with status: {0}")] | ||
HttpRequestFailed(StatusCode), | ||
|
||
#[error("Unexpected error: {0}")] | ||
Unexpected(String), | ||
} | ||
|
||
pub type Result<T> = std::result::Result<T, AuthError>; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
use super::error::{AuthError, Result}; | ||
use super::jwt_validator::JWTValidator; | ||
use super::types::{Jwk, Jwks}; | ||
use jsonwebtoken::Algorithm; | ||
use reqwest::Client; | ||
use std::collections::HashMap; | ||
use std::sync::Arc; | ||
use std::time::{Duration, Instant}; | ||
use tokio::sync::RwLock; | ||
|
||
pub struct JWTProvider { | ||
iss: String, | ||
aud: Vec<String>, | ||
alg: Algorithm, | ||
jwks_uri: String, | ||
ttl: Duration, | ||
jwks_cache: Arc<RwLock<JwksCache>>, | ||
} | ||
|
||
struct JwksCache { | ||
keys: HashMap<String, Jwk>, | ||
last_updated: Instant, // Last time the keys were updated | ||
} | ||
|
||
impl JWTProvider { | ||
// Create a new JWTProvider | ||
pub fn new(iss: String, aud: Vec<String>, alg: Algorithm, jwks_uri: String, ttl: u64) -> Self { | ||
JWTProvider { | ||
iss, | ||
aud, | ||
alg, | ||
jwks_uri, | ||
ttl: Duration::from_secs(ttl), | ||
jwks_cache: Arc::new(RwLock::new(JwksCache { | ||
keys: HashMap::new(), | ||
last_updated: Instant::now() - Duration::from_secs(ttl), | ||
})), | ||
} | ||
} | ||
|
||
// Fetch the JWKS and cache it, respecting TTL | ||
pub async fn fetch_jwks(&self) -> Result<()> { | ||
let mut cache = self.jwks_cache.write().await; | ||
|
||
// Check if cache is still valid | ||
if !cache.is_expired(self.ttl) { | ||
return Ok(()); | ||
} | ||
|
||
let response = Client::new() | ||
.get(&self.jwks_uri) | ||
.send() | ||
.await | ||
.map_err(|e| AuthError::FailedToFetchJwks(e.to_string()))?; | ||
|
||
if !response.status().is_success() { | ||
return Err(AuthError::HttpRequestFailed(response.status())); | ||
} | ||
|
||
let jwks: Jwks = response | ||
.json() | ||
.await | ||
.map_err(|e| AuthError::FailedToFetchJwks(e.to_string()))?; | ||
|
||
cache.keys.clear(); // Clear old cache | ||
for jwk in jwks.keys { | ||
cache.keys.insert(jwk.kid.clone(), jwk); | ||
} | ||
cache.last_updated = Instant::now(); // Update last_updated time | ||
|
||
Ok(()) | ||
} | ||
|
||
pub async fn get_validator(&self) -> JWTValidator { | ||
let keys = { | ||
let cache = self.jwks_cache.read().await; | ||
Arc::new(RwLock::new(cache.keys.clone())) // Clone only the keys | ||
}; | ||
|
||
JWTValidator::new(self.iss.clone(), self.aud.clone(), self.alg, keys) | ||
} | ||
} | ||
|
||
impl JwksCache { | ||
#[inline] | ||
pub fn is_expired(&self, ttl: Duration) -> bool { | ||
self.last_updated.elapsed() >= ttl | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::auth::jwt_validator::JWTValidatorTrait; | ||
use crate::auth::types::Claims; | ||
|
||
use super::super::types::CustomClaims; | ||
use super::*; | ||
use base64::{encode_config, URL_SAFE_NO_PAD}; | ||
use chrono::{Duration, Utc}; | ||
use jsonwebtoken::{encode, EncodingKey, Header}; | ||
use rand::rngs::OsRng; | ||
use rsa::traits::PublicKeyParts; | ||
use rsa::{pkcs1::LineEnding, pkcs8::EncodePrivateKey, RsaPrivateKey, RsaPublicKey}; | ||
use serde::{Deserialize, Serialize}; | ||
use wiremock::matchers::{method, path}; | ||
use wiremock::{Mock, MockServer, ResponseTemplate}; | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct TestClaims { | ||
pub sub: String, | ||
pub exp: usize, | ||
pub iss: String, | ||
pub aud: Vec<String>, | ||
pub iat: usize, | ||
pub custom_claims: CustomClaims, | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_fetch_jwks() { | ||
let mock_server = MockServer::start().await; | ||
|
||
let jwks_response = r#" | ||
{ | ||
"keys": [ | ||
{ | ||
"kty": "RSA", | ||
"kid": "m3eP7rPCsf3hgtWU1seMZ", | ||
"use": "sig", | ||
"n": "xokzqjjolOLqzFYs7bNVyTafNmfWT-9_i8dgV4OcnBaKP8Es5O5u4dHxx6pZoZYxUyZS0-IOc5w3Em3g2PYUS4TmwDD65nrNZEJwr-CbxwbUwEGkYtTNcQML1LHhkRnGvxdwl_3st5HFBPRHnU5y8hMgke_nMTqrUTLaS7Px7v5lpt_nNH60FnfPJqcgD6pG4TG510HhnLV0ELDCqD8F79omuuqwqgntQLr-XR7mw_2PfMV8QdMx-kcwtVVhBeM5hr-KdKAQ-56MbU5GAke7kZJt94_2DHv8wpmQtlmKuIOEBFJNoS3prisdXmlmP6qKDSGufRNg3x5wJ-di0IlIeQ", | ||
"e": "AQAB", | ||
"alg": "RS256" | ||
} | ||
] | ||
}"#; | ||
|
||
Mock::given(method("GET")) | ||
.and(path("/.well-known/jwks.json")) | ||
.respond_with(ResponseTemplate::new(200).set_body_string(jwks_response)) | ||
.mount(&mock_server) | ||
.await; | ||
|
||
let provider = JWTProvider::new( | ||
"https://issuer.com".to_string(), | ||
vec!["my_audience".to_string()], | ||
Algorithm::RS256, | ||
format!("{}/.well-known/jwks.json", &mock_server.uri()), | ||
3600, | ||
); | ||
|
||
provider.fetch_jwks().await.unwrap(); // Now we call fetch_jwks() directly | ||
|
||
let cache = provider.jwks_cache.read().await; | ||
assert_eq!(cache.keys.len(), 1); | ||
assert!(cache.keys.contains_key("m3eP7rPCsf3hgtWU1seMZ")); | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_validate_jwt() { | ||
let mut rng = OsRng; | ||
let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("Failed to generate a key"); | ||
let public_key = RsaPublicKey::from(&private_key); | ||
|
||
let n_base64 = encode_config(public_key.n().to_bytes_be(), URL_SAFE_NO_PAD); | ||
let e_base64 = encode_config(public_key.e().to_bytes_be(), URL_SAFE_NO_PAD); | ||
|
||
let mock_server = MockServer::start().await; | ||
let jwks_response = format!( | ||
r#" | ||
{{ | ||
"keys": [ | ||
{{ | ||
"kty": "RSA", | ||
"kid": "m3eP7rPCsf3hgtWU1seMZ", | ||
"use": "sig", | ||
"n": "{}", | ||
"e": "{}", | ||
"alg": "RS256" | ||
}} | ||
] | ||
}}"#, | ||
n_base64, e_base64 | ||
); | ||
|
||
Mock::given(method("GET")) | ||
.and(path("/.well-known/jwks.json")) | ||
.respond_with(ResponseTemplate::new(200).set_body_string(jwks_response)) | ||
.mount(&mock_server) | ||
.await; | ||
|
||
let now = Utc::now(); | ||
let my_claims = Claims { | ||
sub: "test_sub".to_string(), | ||
exp: (now + Duration::hours(1)).timestamp() as usize, | ||
iss: "https://issuer.com".to_string(), | ||
aud: vec!["my_audience".to_string()], | ||
iat: now.timestamp() as usize, | ||
custom_claims: CustomClaims { | ||
name: Some("test_user".to_string()), | ||
nickname: Some("test_nickname".to_string()), | ||
email: Some("[email protected]".to_string()), | ||
email_verified: Some(true), | ||
}, | ||
}; | ||
let encoding_key = | ||
EncodingKey::from_rsa_pem(private_key.to_pkcs8_pem(LineEnding::LF).unwrap().as_bytes()) | ||
.expect("Invalid private key"); | ||
let mut header = Header::new(Algorithm::RS256); | ||
header.kid = Some("m3eP7rPCsf3hgtWU1seMZ".to_string()); | ||
let token = encode(&header, &my_claims, &encoding_key).expect("Failed to encode token"); | ||
|
||
let provider = JWTProvider::new( | ||
"https://issuer.com".to_string(), | ||
vec!["my_audience".to_string()], | ||
Algorithm::RS256, | ||
format!("{}/.well-known/jwks.json", &mock_server.uri()), | ||
3600, | ||
); | ||
|
||
provider.fetch_jwks().await.unwrap(); // Fetch keys first | ||
|
||
let validator = provider.get_validator(); // Get the validator instance | ||
|
||
let token_data = validator | ||
.await | ||
.validate_token(&token) | ||
.await | ||
.expect("Token validation failed"); | ||
|
||
assert_eq!( | ||
token_data.claims.custom_claims.name, | ||
Some("test_user".to_string()) | ||
); | ||
assert_eq!( | ||
token_data.claims.custom_claims.nickname, | ||
Some("test_nickname".to_string()) | ||
); | ||
assert_eq!( | ||
token_data.claims.custom_claims.email, | ||
Some("[email protected]".to_string()) | ||
); | ||
assert_eq!(token_data.claims.custom_claims.email_verified, Some(true)); | ||
} | ||
} |
Oops, something went wrong.