Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Dynamic CORS Configuration (Fixes #2) #10

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ SMTP_PASSWORD="example"
SMTP_HOST="smtp.example.com"
OTP_SECRET="lorememsum"
JWT_SECRET="yoursecert"
PASS_RESET_LINK="http://localhost:5173/reset-password"
PASS_RESET_LINK="http://localhost:5173/reset-password"
ALOOWED_ORIGIN="http://localhost:3000,http://yourdomain.com"
68 changes: 11 additions & 57 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@ use axum::{
routing::{get, post, put},
Extension, Router,
};
use dotenv::dotenv;
use handlers::{
auth_handlers::{
login_handler, new_password_handler, otp_handler, send_pass_reset_handler, signup_handler,
},
auth_handlers::{login_handler, new_password_handler, otp_handler, send_pass_reset_handler, signup_handler},
cp_handler::code_handler,
crud_handlers::{
add_friend_handler, change_flag_handler, create_matched_handler, get_accepted_boys_handler,
Expand All @@ -16,11 +13,8 @@ use handlers::{
update_contest_score_handler, update_score_handler, update_user_character_handler,
},
};
use http::{header::HeaderValue, uri::Uri};
use sea_orm::Database;
use std::env;
use tower_http::cors::{AllowOrigin, CorsLayer};

mod bcrypts;
mod configs;
mod handlers;
Expand All @@ -29,45 +23,11 @@ mod utils;

#[tokio::main]
async fn main() {
// Load environment variables from .env
dotenv().ok();

// Get the database URL from the environment, report an error and stop if not found
let db_string = env::var("DATABASE_URL").unwrap_or_else(|_| {
println!("Error: DATABASE_URL not found in environment.");
std::process::exit(1); // Terminate the program if DATABASE_URL is missing
});

// Get allowed origins from .env, if not found, use an empty list
let allowed_origins_env = env::var("ALLOWED_ORIGINS").unwrap_or_else(|_| {
println!("Warning: ALLOWED_ORIGINS not set in .env file, defaulting to an empty list.");
String::new()
});

// Parse the allowed origins from the environment
let allowed_origins = allowed_origins_env
.split(',')
.filter(|s| !s.trim().is_empty()) // Filter out any empty values
.filter_map(|origin| {
// Attempt to parse each origin into a Uri, print an error if it fails
match origin.parse::<Uri>() {
Ok(valid_origin) => Some(valid_origin),
Err(_) => {
println!("Warning: Invalid origin URL: {}", origin);
None
}
}
})
.collect::<Vec<_>>();
let db_string = (*utils::constants::DATABASE_URL).clone();

// Use ALLOWED_ORIGINS from constants.rs
let allowed_origins = (*utils::constants::ALLOWED_ORIGINS).clone();

// If no origins are found, print a message
if allowed_origins.is_empty() {
println!(
"Warning: No valid origins found in ALLOWED_ORIGINS. CORS will not allow any origins."
);
}

// Configure CORS layer with dynamic origins
let mut cors = CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
.allow_headers([
Expand All @@ -79,25 +39,19 @@ async fn main() {
])
.allow_credentials(true);

// Add the valid origins from the environment
// Configure CORS for each allowed origin
for origin in &allowed_origins {
println!("Allowing origin: {}", origin);

// Convert Uri to HeaderValue
if let Ok(header_value) = HeaderValue::from_str(&origin.to_string()) {
if let Ok(header_value) = http::header::HeaderValue::from_str(origin) {
cors = cors.allow_origin(AllowOrigin::exact(header_value));
} else {
println!("Warning: Failed to convert Uri to HeaderValue: {}", origin);
println!("Warning: Failed to convert origin to HeaderValue: {}", origin);
}
}

// Connect to the database
let db = Database::connect(&db_string).await.unwrap_or_else(|err| {
println!("Error: Failed to connect to the database: {}", err);
std::process::exit(1); // Terminate if the database connection fails
});

// Set up the app routes and layers
let db = Database::connect(db_string)
.await
.expect("could not connect");
let app: Router<()> = Router::new()
.route("/sendpassreset", get(send_pass_reset_handler))
.route("/newpassword", get(new_password_handler))
Expand Down
60 changes: 45 additions & 15 deletions src/utils/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,53 @@ lazy_static! {
pub static ref DATABASE_URL: String = {
dotenv().ok(); // loading env vars from the file
env::var("DATABASE_URL").expect("DATABASE_URL must be set")
};
pub static ref SMTP_USERNAME: String = {
};

pub static ref SMTP_USERNAME: String = {
env::var("SMTP_USERNAME").expect("SMTP_USERNAME must be set")
};
pub static ref SMTP_PASSWORD: String = {
};

pub static ref SMTP_PASSWORD: String = {
env::var("SMTP_PASSWORD").expect("SMTP_PASSWORD must be set")
};
pub static ref SMTP_HOST: String = {
env::var("SMTP_HOST").expect("SMTP_HOST must be set")
};
pub static ref OTP_SECRET: String = {
};

pub static ref SMTP_HOST: String = {
env::var("SMTP_HOST").expect("SMTP_HOST must be set")
};

pub static ref OTP_SECRET: String = {
env::var("OTP_SECRET").expect("OTP_SECRET must be set")
};
pub static ref JWT_SECRET: String ={
env::var("JWT_SECRET").expect("JWT_SECRET must be set")
};
pub static ref PASS_RESET_LINK: String ={
env::var("PASS_RESET_LINK").expect("PASS_RESET_LINK must be set")
};

pub static ref JWT_SECRET: String = {
env::var("JWT_SECRET").expect("JWT_SECRET must be set")
};

pub static ref PASS_RESET_LINK: String = {
env::var("PASS_RESET_LINK").expect("PASS_RESET_LINK must be set")
};

pub static ref ALLOWED_ORIGINS: Vec<String> = {
dotenv().ok();
HarshitShukla-dev marked this conversation as resolved.
Show resolved Hide resolved
match env::var("ALLOWED_ORIGINS") {
Ok(allowed_origins_env) => {
allowed_origins_env
.split(',')
.filter_map(|origin| {
let trimmed_origin = origin.trim();
if !trimmed_origin.is_empty() {
Some(trimmed_origin.to_string())
} else {
None
}
})
.collect()
}
Err(_) => {
// If ALLOWED_ORIGINS is not set, it defaults to an empty list
println!("Warning: ALLOWED_ORIGINS variable not set. Defaulting to an empty list.");
Vec::new()
}
}
};
}
Loading