Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Dynamic CORS Configuration (Fixes #2) #10

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ chrono = "0.4.38"
sea-orm = { version = "1.0.0-rc.5", features = [ "sqlx-postgres", "runtime-tokio-rustls", "macros" ] }
uuid = { version = "1.10.0", features = ["v4"] }
bcrypt = "0.15.1"
tower-http = { version = "0.5.2", features = ["cors"] }
tower-http = { version = "0.6.1", features = ["cors"] }
HarshitShukla-dev marked this conversation as resolved.
Show resolved Hide resolved
dotenv = "0.15.0"
lazy_static = "1.5.0"
cookie = "0.18.1"
Expand Down
90 changes: 65 additions & 25 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use axum::{
routing::{get, post, put},
Extension, Router,
};
use dotenv::dotenv;
use handlers::{
auth_handlers::{
login_handler, new_password_handler, otp_handler, send_pass_reset_handler, signup_handler,
Expand All @@ -15,8 +16,11 @@ use handlers::{
update_contest_score_handler, update_score_handler, update_user_character_handler,
},
};
use http::{header::HeaderValue, uri::Uri};
use sea_orm::Database;
use std::env;
use tower_http::cors::{AllowOrigin, CorsLayer};

mod bcrypts;
mod configs;
mod handlers;
Expand All @@ -25,25 +29,47 @@ mod utils;

#[tokio::main]
async fn main() {
let db_string = (*utils::constants::DATABASE_URL).clone();
let cors = CorsLayer::new()
// Load environment variables from .env
dotenv().ok();

// Get the database URL from the environment, report an error and stop if not found
let db_string = env::var("DATABASE_URL").unwrap_or_else(|_| {
println!("Error: DATABASE_URL not found in environment.");
std::process::exit(1); // Terminate the program if DATABASE_URL is missing
});

// Get allowed origins from .env, if not found, use an empty list
let allowed_origins_env = env::var("ALLOWED_ORIGINS").unwrap_or_else(|_| {
println!("Warning: ALLOWED_ORIGINS not set in .env file, defaulting to an empty list.");
String::new()
});

// Parse the allowed origins from the environment
let allowed_origins = allowed_origins_env
.split(',')
.filter(|s| !s.trim().is_empty()) // Filter out any empty values
.filter_map(|origin| {
// Attempt to parse each origin into a Uri, print an error if it fails
match origin.parse::<Uri>() {
Ok(valid_origin) => Some(valid_origin),
Err(_) => {
println!("Warning: Invalid origin URL: {}", origin);
None
}
}
})
.collect::<Vec<_>>();

// If no origins are found, print a message
if allowed_origins.is_empty() {
println!(
"Warning: No valid origins found in ALLOWED_ORIGINS. CORS will not allow any origins."
);
}

// Configure CORS layer with dynamic origins
let mut cors = CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
.allow_origin(AllowOrigin::exact(
"http://ec2-13-232-176-18.ap-south-1.compute.amazonaws.com:5173"
.parse()
.unwrap(),
))
.allow_origin(AllowOrigin::exact(
"http://ec2-13-232-176-18.ap-south-1.compute.amazonaws.com"
.parse()
.unwrap(),
))
.allow_origin(AllowOrigin::exact(
"http://ec2-13-126-149-80.ap-south-1.compute.amazonaws.com:5173"
.parse()
.unwrap(),
))
.allow_origin(AllowOrigin::exact("http://localhost:5173".parse().unwrap()))
.allow_headers([
http::header::ACCEPT,
http::header::CONTENT_TYPE,
Expand All @@ -53,17 +79,31 @@ async fn main() {
])
.allow_credentials(true);

let db = Database::connect(db_string)
.await
.expect("could not connect");
// Add the valid origins from the environment
for origin in &allowed_origins {
println!("Allowing origin: {}", origin);

// Convert Uri to HeaderValue
if let Ok(header_value) = HeaderValue::from_str(&origin.to_string()) {
cors = cors.allow_origin(AllowOrigin::exact(header_value));
} else {
println!("Warning: Failed to convert Uri to HeaderValue: {}", origin);
}
}

// Connect to the database
let db = Database::connect(&db_string).await.unwrap_or_else(|err| {
println!("Error: Failed to connect to the database: {}", err);
std::process::exit(1); // Terminate if the database connection fails
});

// Set up the app routes and layers
let app: Router<()> = Router::new()
.route("/sendpassreset", get(send_pass_reset_handler))
.route("/newpassword", get(new_password_handler))
.route("/otp", get(otp_handler))
.route("/login", post(login_handler))
// .route("/decode", get(decode_jwt))
.route("/signup", post(signup_handler))
// .route("/runcode", post(code_handler))
.route("/getuser", post(get_user_handler))
.route("/getboys", get(get_boys_handler))
.route("/getgirls", get(get_girls_handler))
Expand All @@ -82,8 +122,8 @@ async fn main() {
.layer(cors)
.layer(Extension(db));

let listner = tokio::net::TcpListener::bind("0.0.0.0:3001").await.unwrap();
let listener = tokio::net::TcpListener::bind("0.0.0.0:3001").await.unwrap();
println!("Listening");

axum::serve(listner, app).await.unwrap();
axum::serve(listener, app).await.unwrap();
}