Skip to content

Commit

Permalink
feat(websocket): Add jwt module to ws connection
Browse files Browse the repository at this point in the history
  • Loading branch information
kasugamirai committed Sep 8, 2024
1 parent 305433b commit df1d031
Show file tree
Hide file tree
Showing 11 changed files with 1,036 additions and 106 deletions.
714 changes: 627 additions & 87 deletions websocket/Cargo.lock

Large diffs are not rendered by default.

37 changes: 21 additions & 16 deletions websocket/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
[workspace]
members = [
"crates/*",
]
members = ["crates/*"]

resolver = "2"

Expand Down Expand Up @@ -35,27 +33,34 @@ panic = "abort"
strip = true

[workspace.dependencies]
flow-websocket-domain = {path = "crates/domain"}
flow-websocket-domain = { path = "crates/domain" }

async-trait = "0.1.80"
axum = {version = "0.7", features = ["ws"]}
axum-extra = {version = "0.9", features = ["typed-header"]}
axum = { version = "0.7", features = ["ws"] }
axum-extra = { version = "0.9", features = ["typed-header"] }
axum-macros = "0.4"
chrono = {version = "0.4", features = ["serde"]}
chrono = { version = "0.4", features = ["serde"] }
google-cloud-storage = "0.18"
redis = {version = "0.25.4", features = ["aio", "tokio-comp"]}
redis = { version = "0.25.4", features = ["aio", "tokio-comp"] }
rslock = "0.3.0"
serde = {version = "1.0", features = ["derive"]}
serde_json = {version = "1.0.117", features = ["arbitrary_precision"]}
tokio = {version = "1.38.0", features = ["full", "time"]}
tower = {version = "0.4", features = ["timeout"]}
tower-http = {version = "0.5", features = ["fs", "trace"]}
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0.117", features = ["arbitrary_precision"] }
tokio = { version = "1.38.0", features = ["full", "time"] }
tower = { version = "0.4", features = ["timeout"] }
tower-http = { version = "0.5", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = {version = "0.3", features = ["env-filter"]}
uuid = {version = "1.8.0", features = [
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.8.0", features = [
"v4",
"fast-rng",
"macro-diagnostics",
"serde",
]}
] }
yrs = "0.18"
reqwest = { version = "0.12.7", features = ["json"] }
thiserror = "1.0.63"
jsonwebtoken = "9.3.0"
wiremock = "0.5"
rsa = { version = "0.10.0-pre.2", features = ["pem"] }
rand = "0.8"
base64 = "0.13"
4 changes: 2 additions & 2 deletions websocket/crates/domain/src/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl ProjectEditingSession {
let session_id = generate_id(14, "editor-session");
self.session_id = Some(session_id.clone());
if !self.session_setup_complete {
let _latest_snapshot_state = snapshot_repo
let latest_snapshot_state = snapshot_repo
.get_latest_snapshot_state(&self.project_id)
.await?;
// Initialize Redis with latest snapshot state
Expand Down Expand Up @@ -110,7 +110,7 @@ impl ProjectEditingSession {
snapshot_repo: &impl ProjectSnapshotRepository,
data: SnapshotData,
) -> Result<(), Box<dyn Error>> {
self.merge_updates().await?;
let merged_state = self.merge_updates().await?;
let snapshot = ProjectSnapshot {
id: generate_id(14, "snap"),
project_id: self.project_id.clone(),
Expand Down
2 changes: 1 addition & 1 deletion websocket/crates/domain/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ use uuid::Uuid;
pub fn generate_id(length: usize, prefix: &str) -> String {
let _ = length;
format!("{}{}", prefix, Uuid::new_v4().to_string())
}
}
7 changes: 7 additions & 0 deletions websocket/crates/infra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,10 @@ tracing-subscriber.workspace = true
tracing.workspace = true
uuid.workspace = true
yrs.workspace = true
thiserror.workspace = true
reqwest.workspace = true
jsonwebtoken.workspace = true
wiremock.workspace = true
rsa.workspace = true
rand.workspace = true
base64.workspace = true
31 changes: 31 additions & 0 deletions websocket/crates/infra/src/auth/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use reqwest::StatusCode;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum AuthError {
#[error("Invalid token header: {0}")]
InvalidTokenHeader(String),

#[error("No `kid` found in token header")]
NoKidInTokenHeader,

#[error("No matching JWK found for kid: {0}")]
NoMatchingJwk(String),

#[error("Invalid JWK key: {0}")]
InvalidJwkKey(String),

#[error("Token validation failed: {0}")]
TokenValidationFailed(String),

#[error("Failed to fetch JWKS: {0}")]
FailedToFetchJwks(String),

#[error("HTTP request failed with status: {0}")]
HttpRequestFailed(StatusCode),

#[error("Unexpected error: {0}")]
Unexpected(String),
}

pub type Result<T> = std::result::Result<T, AuthError>;
243 changes: 243 additions & 0 deletions websocket/crates/infra/src/auth/jwt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
use super::error::{AuthError, Result};
use super::jwt_validator::JWTValidator;
use super::types::{Jwk, Jwks};
use jsonwebtoken::Algorithm;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

pub struct JWTProvider {
iss: String,
aud: Vec<String>,
alg: Algorithm,
jwks_uri: String,
ttl: Duration,
jwks_cache: Arc<RwLock<JwksCache>>,
}

struct JwksCache {
keys: HashMap<String, Jwk>,
last_updated: Instant, // Last time the keys were updated
}

impl JWTProvider {
// Create a new JWTProvider
pub fn new(iss: String, aud: Vec<String>, alg: Algorithm, jwks_uri: String, ttl: u64) -> Self {
JWTProvider {
iss,
aud,
alg,
jwks_uri,
ttl: Duration::from_secs(ttl),
jwks_cache: Arc::new(RwLock::new(JwksCache {
keys: HashMap::new(),
last_updated: Instant::now() - Duration::from_secs(ttl),
})),
}
}

// Fetch the JWKS and cache it, respecting TTL
pub async fn fetch_jwks(&self) -> Result<()> {
let mut cache = self.jwks_cache.write().await;

// Check if cache is still valid
if !cache.is_expired(self.ttl) {
return Ok(());
}

let response = Client::new()
.get(&self.jwks_uri)
.send()
.await
.map_err(|e| AuthError::FailedToFetchJwks(e.to_string()))?;

if !response.status().is_success() {
return Err(AuthError::HttpRequestFailed(response.status()));
}

let jwks: Jwks = response
.json()
.await
.map_err(|e| AuthError::FailedToFetchJwks(e.to_string()))?;

cache.keys.clear(); // Clear old cache
for jwk in jwks.keys {
cache.keys.insert(jwk.kid.clone(), jwk);
}
cache.last_updated = Instant::now(); // Update last_updated time

Ok(())
}

pub async fn get_validator(&self) -> JWTValidator {
let keys = {
let cache = self.jwks_cache.read().await;
Arc::new(RwLock::new(cache.keys.clone())) // Clone only the keys
};

JWTValidator::new(self.iss.clone(), self.aud.clone(), self.alg, keys)
}
}

impl JwksCache {
#[inline]
pub fn is_expired(&self, ttl: Duration) -> bool {
self.last_updated.elapsed() >= ttl
}
}

#[cfg(test)]
mod tests {
use crate::auth::jwt_validator::JWTValidatorTrait;
use crate::auth::types::Claims;

use super::super::types::CustomClaims;
use super::*;
use base64::{encode_config, URL_SAFE_NO_PAD};
use chrono::{Duration, Utc};
use jsonwebtoken::{encode, EncodingKey, Header};
use rand::rngs::OsRng;
use rsa::traits::PublicKeyParts;
use rsa::{pkcs1::LineEnding, pkcs8::EncodePrivateKey, RsaPrivateKey, RsaPublicKey};
use serde::{Deserialize, Serialize};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

#[derive(Debug, Serialize, Deserialize)]
pub struct TestClaims {
pub sub: String,
pub exp: usize,
pub iss: String,
pub aud: Vec<String>,
pub iat: usize,
pub custom_claims: CustomClaims,
}

#[tokio::test]
async fn test_fetch_jwks() {
let mock_server = MockServer::start().await;

let jwks_response = r#"
{
"keys": [
{
"kty": "RSA",
"kid": "m3eP7rPCsf3hgtWU1seMZ",
"use": "sig",
"n": "xokzqjjolOLqzFYs7bNVyTafNmfWT-9_i8dgV4OcnBaKP8Es5O5u4dHxx6pZoZYxUyZS0-IOc5w3Em3g2PYUS4TmwDD65nrNZEJwr-CbxwbUwEGkYtTNcQML1LHhkRnGvxdwl_3st5HFBPRHnU5y8hMgke_nMTqrUTLaS7Px7v5lpt_nNH60FnfPJqcgD6pG4TG510HhnLV0ELDCqD8F79omuuqwqgntQLr-XR7mw_2PfMV8QdMx-kcwtVVhBeM5hr-KdKAQ-56MbU5GAke7kZJt94_2DHv8wpmQtlmKuIOEBFJNoS3prisdXmlmP6qKDSGufRNg3x5wJ-di0IlIeQ",
"e": "AQAB",
"alg": "RS256"
}
]
}"#;

Mock::given(method("GET"))
.and(path("/.well-known/jwks.json"))
.respond_with(ResponseTemplate::new(200).set_body_string(jwks_response))
.mount(&mock_server)
.await;

let provider = JWTProvider::new(
"https://issuer.com".to_string(),
vec!["my_audience".to_string()],
Algorithm::RS256,
format!("{}/.well-known/jwks.json", &mock_server.uri()),
3600,
);

provider.fetch_jwks().await.unwrap(); // Now we call fetch_jwks() directly

let cache = provider.jwks_cache.read().await;
assert_eq!(cache.keys.len(), 1);
assert!(cache.keys.contains_key("m3eP7rPCsf3hgtWU1seMZ"));
}

#[tokio::test]
async fn test_validate_jwt() {
let mut rng = OsRng;
let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("Failed to generate a key");
let public_key = RsaPublicKey::from(&private_key);

let n_base64 = encode_config(public_key.n().to_bytes_be(), URL_SAFE_NO_PAD);
let e_base64 = encode_config(public_key.e().to_bytes_be(), URL_SAFE_NO_PAD);

let mock_server = MockServer::start().await;
let jwks_response = format!(
r#"
{{
"keys": [
{{
"kty": "RSA",
"kid": "m3eP7rPCsf3hgtWU1seMZ",
"use": "sig",
"n": "{}",
"e": "{}",
"alg": "RS256"
}}
]
}}"#,
n_base64, e_base64
);

Mock::given(method("GET"))
.and(path("/.well-known/jwks.json"))
.respond_with(ResponseTemplate::new(200).set_body_string(jwks_response))
.mount(&mock_server)
.await;

let now = Utc::now();
let my_claims = Claims {
sub: "test_sub".to_string(),
exp: (now + Duration::hours(1)).timestamp() as usize,
iss: "https://issuer.com".to_string(),
aud: vec!["my_audience".to_string()],
iat: now.timestamp() as usize,
custom_claims: CustomClaims {
name: Some("test_user".to_string()),
nickname: Some("test_nickname".to_string()),
email: Some("[email protected]".to_string()),
email_verified: Some(true),
},
};
let encoding_key =
EncodingKey::from_rsa_pem(private_key.to_pkcs8_pem(LineEnding::LF).unwrap().as_bytes())
.expect("Invalid private key");
let mut header = Header::new(Algorithm::RS256);
header.kid = Some("m3eP7rPCsf3hgtWU1seMZ".to_string());
let token = encode(&header, &my_claims, &encoding_key).expect("Failed to encode token");

let provider = JWTProvider::new(
"https://issuer.com".to_string(),
vec!["my_audience".to_string()],
Algorithm::RS256,
format!("{}/.well-known/jwks.json", &mock_server.uri()),
3600,
);

provider.fetch_jwks().await.unwrap(); // Fetch keys first

let validator = provider.get_validator(); // Get the validator instance

let token_data = validator
.await
.validate_token(&token)
.await
.expect("Token validation failed");

assert_eq!(
token_data.claims.custom_claims.name,
Some("test_user".to_string())
);
assert_eq!(
token_data.claims.custom_claims.nickname,
Some("test_nickname".to_string())
);
assert_eq!(
token_data.claims.custom_claims.email,
Some("[email protected]".to_string())
);
assert_eq!(token_data.claims.custom_claims.email_verified, Some(true));
}
}
Loading

0 comments on commit df1d031

Please sign in to comment.