diff --git a/.env.example b/.env.example index f90ec84..e89173c 100644 --- a/.env.example +++ b/.env.example @@ -4,5 +4,6 @@ SMTP_PASSWORD="example" SMTP_HOST="smtp.example.com" OTP_SECRET="lorememsum" JWT_SECRET="yoursecert" +HMAC_SECRET="your-HMAC-secret-key" PASS_RESET_LINK="http://localhost:5173/reset-password" ALLOWED_ORIGINS="http://localhost:3000,http://yourdomain.com" \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 7f327a4..d752cde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2751,12 +2751,14 @@ version = "0.1.0" dependencies = [ "axum", "axum-extra", + "base64 0.21.7", "bcrypt", "chrono", "cookie", "dotenv", "entity", "handlebars", + "hmac", "json", "jsonwebtoken", "lazy_static", @@ -2767,6 +2769,7 @@ dependencies = [ "sea-orm", "serde", "serde_json", + "sha2", "tokio", "totp-rs", "tower-http", diff --git a/Cargo.toml b/Cargo.toml index 59c0250..534232c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,4 +38,7 @@ handlebars = "6.1.0" totp-rs = "5.6.0" rand = "0.8.5" reqwest = "0.12.8" +hmac = "0.12.1" +sha2 = "0.10.6" +base64 = "0.21.0" diff --git a/src/handlers/auth_handlers.rs b/src/handlers/auth_handlers.rs index f737239..6b7ac0b 100644 --- a/src/handlers/auth_handlers.rs +++ b/src/handlers/auth_handlers.rs @@ -1,4 +1,5 @@ -use std::collections::HashMap; +// auth_handler.rs +use std::{collections::HashMap, string}; use axum::{ extract::Query, @@ -23,6 +24,7 @@ use crate::{ model::{Claims, LoginInfo, SignUpInfo}, utils::{ constants::{JWT_SECRET, PASS_RESET_LINK}, + hash_token::{generate_secure_token, verify_token}, pass_reset::PassReset, verify_email::EmailOTP, }, @@ -188,6 +190,7 @@ pub async fn login_handler( }) } +// Modified send_pass_reset_handler pub async fn send_pass_reset_handler( Extension(db): Extension, Query(params): Query>, @@ -206,22 +209,10 @@ pub async fn send_pass_reset_handler( .await .unwrap() { - let token: String = - rand::Rng::sample_iter(rand::thread_rng(), &rand::distributions::Alphanumeric) - .take(64) - .map(char::from) - .collect(); - - // let hashed_token = match hash_password(token.as_str()) { - // Ok(hash) => hash, - // Err(e) => { - // eprintln!("Password could not be hashed -> {}", e); - // return Err(StatusCode::INTERNAL_SERVER_ERROR); - // } - // }; - - let token_expiry = Utc::now() + chrono::Duration::hours(1); // Adds 1 hour to the current time - let token_expiry_timestamp = token_expiry.timestamp(); // Converts to i64 (seconds since Unix epoch) + let (token, hmac) = generate_secure_token(); + + let token_expiry = Utc::now() + chrono::Duration::hours(1); + let token_expiry_timestamp = token_expiry.timestamp(); let username = user.user_name; let reset_link = format!("{}?token={}", PASS_RESET_LINK.to_string(), token); @@ -231,7 +222,7 @@ pub async fn send_pass_reset_handler( let pass_reset_model = pass_reset::ActiveModel { user_id: Set(user.id), - token: Set(token), + token: Set(hmac), // Store the HMAC instead of the plain token token_expiry: Set(token_expiry_timestamp), ..Default::default() }; @@ -245,7 +236,6 @@ pub async fn send_pass_reset_handler( } } - // Temporarily return a success message Ok(Json(format!("Password reset link sent to {}", email))) } @@ -253,58 +243,54 @@ pub async fn new_password_handler( Extension(db): Extension, Query(params): Query>, ) -> Result, StatusCode> { - if let (Some(reset_token), Some(hashed_password)) = - (params.get("token"), params.get("password")) - { - // let hashed_reset_token = match hash_password(reset_token) { - // Ok(hashed) => hashed, // Successfully hashed - // Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR), // Handle error - // }; - - // println!("{}", hashed_reset_token); - - let user = pass_reset::Entity::find() - .filter(pass_reset::Column::Token.contains(reset_token)) - .one(&db) + if let (Some(reset_token), Some(new_password)) = (params.get("token"), params.get("password")) { + let txn = db + .begin() .await - .unwrap(); - - let user_id = match user { - Some(entity) => entity.user_id, - None => return Err(StatusCode::NOT_FOUND), - }; - - let txn = db.begin().await.unwrap(); + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let tokens = pass_reset::Entity::find() - .filter(pass_reset::Column::UserId.eq(user_id)) + // Fetch all password reset entries + let all_resets = pass_reset::Entity::find() .all(&txn) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let matched_token = tokens.into_iter().find(|row| row.token == *reset_token); + // Find the matching token + let matched_reset = all_resets + .into_iter() + .find(|reset| verify_token(&reset_token, &reset.token)); - if let Some(matched_token) = matched_token { + if let Some(reset) = matched_reset { // Check token expiry let current_time = Utc::now().timestamp(); - if matched_token.token_expiry < current_time { - txn.rollback().await.unwrap(); + if reset.token_expiry < current_time { + txn.rollback() + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; return Err(StatusCode::BAD_REQUEST); // Token has expired } // Delete all tokens for the user pass_reset::Entity::delete_many() - .filter(pass_reset::Column::UserId.eq(user_id)) + .filter(pass_reset::Column::UserId.eq(reset.user_id)) .exec(&txn) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; // Update the user's password - let user_model = user::Entity::find_by_id(user_id).one(&txn).await.unwrap(); + let user_model = user::Entity::find_by_id(reset.user_id) + .one(&txn) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::NOT_FOUND)?; - let mut user: user::ActiveModel = user_model.unwrap().into(); + let mut user: user::ActiveModel = user_model.into(); - user.password = Set(hashed_password.to_owned()); + // Hash the new password before storing + let hashed_password = + hash_password(new_password).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + user.password = Set(hashed_password); user.update(&txn) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; @@ -316,12 +302,14 @@ pub async fn new_password_handler( Ok(Json("Password updated successfully".to_string())) } else { - // Token not found - txn.rollback().await.unwrap(); + // Token not found or invalid + txn.rollback() + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Err(StatusCode::BAD_REQUEST) } } else { - // One or both are missing + // One or both parameters are missing Err(StatusCode::BAD_REQUEST) } } diff --git a/src/utils/constants.rs b/src/utils/constants.rs index 304bc24..c6b3e08 100644 --- a/src/utils/constants.rs +++ b/src/utils/constants.rs @@ -28,6 +28,10 @@ lazy_static! { env::var("JWT_SECRET").expect("JWT_SECRET must be set") }; + pub static ref HMAC_SECRET: String = { + env::var("HMAC_SECRET").expect("HMAC_SECRET must be set") + }; + pub static ref PASS_RESET_LINK: String = { env::var("PASS_RESET_LINK").expect("PASS_RESET_LINK must be set") }; diff --git a/src/utils/hash_token.rs b/src/utils/hash_token.rs new file mode 100644 index 0000000..c964f11 --- /dev/null +++ b/src/utils/hash_token.rs @@ -0,0 +1,31 @@ +use crate::utils::constants::HMAC_SECRET; +use base64::{engine::general_purpose, Engine as _}; +use hmac::{Hmac, Mac}; +use rand::Rng; +use sha2::Sha256; + +fn create_hmac(token: &str) -> String { + let hmac_secret_bytes = HMAC_SECRET.as_bytes(); + + let mut mac = + Hmac::::new_from_slice(hmac_secret_bytes).expect("HMAC can take key of any size"); + mac.update(token.as_bytes()); + let result = mac.finalize(); + general_purpose::STANDARD_NO_PAD.encode(result.into_bytes()) +} + +pub fn generate_secure_token() -> (String, String) { + let token: String = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(64) + .map(char::from) + .collect(); + + let hmac = create_hmac(&token); + (token, hmac) +} + +pub fn verify_token(token: &str, stored_hmac: &str) -> bool { + let calculated_hmac = create_hmac(token); + calculated_hmac == stored_hmac +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index b917ab6..2fff85d 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,4 +1,5 @@ pub mod constants; +pub mod hash_token; pub mod pass_reset; pub mod scripts; pub mod verify_email;