diff --git a/Cargo.lock b/Cargo.lock index fc2090ed..fe506944 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -220,7 +220,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -374,9 +374,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.2" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dbe3c979c178231552ecba20214a8272df4e09f232a87aef4320cf06539aded" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "block-buffer" @@ -458,7 +458,7 @@ dependencies = [ "proc-macro2", "quote", "serde_json", - "syn 2.0.90", + "syn 2.0.96", "zstd", ] @@ -523,11 +523,13 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.79" +version = "1.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" dependencies = [ "jobserver", + "libc", + "shlex", ] [[package]] @@ -590,7 +592,7 @@ dependencies = [ "heck 0.4.1", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -656,9 +658,9 @@ dependencies = [ [[package]] name = "const-oid" -version = "0.9.5" +version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" [[package]] name = "const-random" @@ -882,10 +884,12 @@ name = "corro-pg" version = "0.1.0" dependencies = [ "bytes", + "camino", "chrono", "compact_str 0.7.0", "corro-tests", "corro-types", + "eyre", "fallible-iterator 0.3.0", "futures", "hex", @@ -893,7 +897,10 @@ dependencies = [ "metrics", "pgwire", "postgres-types", + "rcgen", "rusqlite", + "rustls", + "rustls-pemfile", "socket2 0.5.5", "spawn", "sqlite3-parser", @@ -902,6 +909,8 @@ dependencies = [ "thiserror", "tokio", "tokio-postgres", + "tokio-postgres-rustls", + "tokio-rustls", "tokio-util", "tracing", "tracing-subscriber", @@ -1181,7 +1190,7 @@ dependencies = [ "proc-macro2", "quote", "scratch", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1198,7 +1207,7 @@ checksum = "2345488264226bf682893e25de0769f3360aac9957980ec49361b083ddaa5bc5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1388,7 +1397,7 @@ checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1439,7 +1448,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1644,7 +1653,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1695,9 +1704,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.9" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", @@ -1728,7 +1737,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b3ba52851e73b46a4c3df1d89343741112003f0f6f13beb0dfac9e457c3fdcd" dependencies = [ - "bitflags 2.3.2", + "bitflags 2.8.0", "libc", "libgit2-sys", "log", @@ -2173,7 +2182,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -2338,9 +2347,9 @@ checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" [[package]] name = "jobserver" -version = "0.1.28" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" dependencies = [ "libc", ] @@ -2382,9 +2391,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.150" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libgit2-sys" @@ -2487,7 +2496,7 @@ dependencies = [ "proc-macro2", "quote", "regex-syntax", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -3405,7 +3414,7 @@ version = "11.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d86a7c4638d42c44551f4791a20e687dbb4c3de1f33c43dd71e355cd429def1" dependencies = [ - "bitflags 2.3.2", + "bitflags 2.8.0", ] [[package]] @@ -3551,7 +3560,7 @@ version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a78046161564f5e7cd9008aff3b2990b3850dc8e0349119b98e8f251e099f24d" dependencies = [ - "bitflags 2.3.2", + "bitflags 2.8.0", "chrono", "fallible-iterator 0.3.0", "fallible-streaming-iterator", @@ -3757,7 +3766,7 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -3867,6 +3876,12 @@ dependencies = [ "regex", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -3967,7 +3982,7 @@ checksum = "7d395866cb6778625150f77a430cc0af764ce0300f6a3d3413477785fa34b6c7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -3993,9 +4008,9 @@ dependencies = [ [[package]] name = "spki" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d1e996ef02c474957d681f1b05213dfb0abab947b446a62d37770b23500184a" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" dependencies = [ "base64ct", "der", @@ -4024,7 +4039,7 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9add252f9b70a7d493b03127524ed06cdf7480b3dc8c1b2ccfda89384d90a8b7" dependencies = [ - "bitflags 2.3.2", + "bitflags 2.8.0", "cc", "fallible-iterator 0.3.0", "indexmap 2.1.0", @@ -4155,9 +4170,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.90" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -4190,7 +4205,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -4259,7 +4274,7 @@ checksum = "090198534930841fab3a5d1bb637cde49e339654e606195f8d9c76eeb081dc96" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -4389,7 +4404,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -4430,11 +4445,25 @@ dependencies = [ "whoami", ] +[[package]] +name = "tokio-postgres-rustls" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd5831152cb0d3f79ef5523b357319ba154795d64c7078b2daa95a803b54057f" +dependencies = [ + "futures", + "ring", + "rustls", + "tokio", + "tokio-postgres", + "tokio-rustls", +] + [[package]] name = "tokio-rustls" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ "rustls", "tokio", @@ -4603,7 +4632,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -4934,7 +4963,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", "wasm-bindgen-shared", ] @@ -4956,7 +4985,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -5355,7 +5384,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", "synstructure 0.13.1", ] @@ -5376,7 +5405,7 @@ checksum = "dd7e48ccf166952882ca8bd778a43502c64f33bf94c12ebe2a7f08e5a0f6689f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -5396,15 +5425,15 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", "synstructure 0.13.1", ] [[package]] name = "zeroize" -version = "1.6.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" [[package]] name = "zerovec" @@ -5425,7 +5454,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] diff --git a/crates/corro-pg/Cargo.toml b/crates/corro-pg/Cargo.toml index 24edddee..87b5818c 100644 --- a/crates/corro-pg/Cargo.toml +++ b/crates/corro-pg/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" bytes = { workspace = true } compact_str = { workspace = true } corro-types = { path = "../corro-types" } +eyre = { workspace = true } fallible-iterator = { workspace = true } futures = { workspace = true } hex = { workspace = true } @@ -15,6 +16,8 @@ metrics = { workspace = true } pgwire = { version = "0.16.1" } postgres-types = { version = "0.2", features = ["with-time-0_3"] } rusqlite = { workspace = true } +rustls = { workspace = true } +rustls-pemfile = "*" spawn = { path = "../spawn" } sqlite3-parser = { workspace = true } tempfile = { workspace = true } @@ -26,8 +29,12 @@ tripwire = { path = "../tripwire" } sqlparser = { version = "0.39.0" } chrono = { version = "0.4.31" } socket2 = { version = "0.5" } +tokio-rustls = "0.24.1" [dev-dependencies] corro-tests = { path = "../corro-tests" } tokio-postgres = { version = "0.7.10" } -tracing-subscriber = { workspace = true } \ No newline at end of file +tracing-subscriber = { workspace = true } +camino = { workspace = true } +tokio-postgres-rustls = "0.10.0" +rcgen = { workspace = true } diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs old mode 100644 new mode 100755 index 01d54350..0465403a --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -47,6 +47,7 @@ use rusqlite::{ ffi::SQLITE_CONSTRAINT_UNIQUE, functions::FunctionFlags, types::ValueRef, vtab::eponymous_only_module, Connection, Statement, }; +use rustls::ServerConfig; use socket2::{SockRef, TcpKeepalive}; use spawn::spawn_counted; use sqlite3_parser::ast::{ @@ -62,7 +63,8 @@ use tokio::{ AcquireError, OwnedSemaphorePermit, }, }; -use tokio_util::{codec::Framed, sync::CancellationToken}; +use tokio_rustls::TlsAcceptor; +use tokio_util::{codec::Framed, either::Either, sync::CancellationToken}; use tracing::{debug, error, info, trace, warn}; use tripwire::{Outcome, PreemptibleFutureExt, Tripwire}; @@ -423,11 +425,8 @@ enum OpenTxKind { Explicit, } -async fn peek_for_sslrequest( - tcp_socket: &mut TcpStream, - ssl_supported: bool, -) -> std::io::Result { - let mut ssl = false; +async fn peek_for_sslrequest(tcp_socket: &mut TcpStream) -> std::io::Result { + let mut want_ssl = false; let mut buf = [0u8; SslRequest::BODY_SIZE]; let mut buf = ReadBuf::new(&mut buf); loop { @@ -447,15 +446,10 @@ async fn peek_for_sslrequest( .read_exact(&mut [0u8; SslRequest::BODY_SIZE]) .await?; // ssl configured - if ssl_supported { - ssl = true; - tcp_socket.write_all(b"S").await?; - } else { - tcp_socket.write_all(b"N").await?; - } + want_ssl = true; } - return Ok(ssl); + return Ok(want_ssl); } } } @@ -466,6 +460,86 @@ pub enum PgStartError { Io(#[from] std::io::Error), #[error(transparent)] Rusqlite(#[from] rusqlite::Error), + #[error(transparent)] + PgTlsError(#[from] eyre::Error), +} + +async fn setup_tls(pg: PgConfig) -> eyre::Result<(Option, bool)> { + let tls = match pg.tls { + Some(tls) => tls, + None => { + return Ok((None, false)); + } + }; + + let ssl_required = tls.verify_client; + + let key = tokio::fs::read(&tls.key_file).await?; + let key = if tls.key_file.extension().map_or(false, |x| x == "der") { + rustls::PrivateKey(key) + } else { + let pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut &*key)?; + match pkcs8.into_iter().next() { + Some(x) => rustls::PrivateKey(x), + None => { + let rsa = rustls_pemfile::rsa_private_keys(&mut &*key)?; + match rsa.into_iter().next() { + Some(x) => rustls::PrivateKey(x), + None => { + eyre::bail!("no private keys found"); + } + } + } + } + }; + + let certs = tokio::fs::read(&tls.cert_file).await?; + let certs = if tls.cert_file.extension().map_or(false, |x| x == "der") { + vec![rustls::Certificate(certs)] + } else { + rustls_pemfile::certs(&mut &*certs)? + .into_iter() + .map(rustls::Certificate) + .collect() + }; + + let server_crypto = ServerConfig::builder().with_safe_defaults(); + + let server_crypto = if ssl_required { + let ca_file = match &tls.ca_file { + None => { + eyre::bail!( + "ca_file required in tls config for server client cert auth verification" + ); + } + Some(ca_file) => ca_file, + }; + + let ca_certs = tokio::fs::read(&ca_file).await?; + let ca_certs = if ca_file.extension().map_or(false, |x| x == "der") { + vec![rustls::Certificate(ca_certs)] + } else { + rustls_pemfile::certs(&mut &*ca_certs)? + .into_iter() + .map(rustls::Certificate) + .collect() + }; + + let mut root_store = rustls::RootCertStore::empty(); + + for cert in ca_certs { + root_store.add(&cert)?; + } + + server_crypto.with_client_cert_verifier(Arc::new( + rustls::server::AllowAnyAuthenticatedClient::new(root_store), + )) + } else { + server_crypto.with_no_client_auth() + }; + + let config = server_crypto.with_single_cert(certs, key)?; + Ok((Some(TlsAcceptor::from(Arc::new(config))), ssl_required)) } pub async fn start( @@ -474,6 +548,7 @@ pub async fn start( mut tripwire: Tripwire, ) -> Result { let server = TcpListener::bind(pg.bind_addr).await?; + let (tls_acceptor, ssl_required) = setup_tls(pg).await?; let local_addr = server.local_addr()?; tokio::spawn(async move { @@ -482,6 +557,7 @@ pub async fn start( Outcome::Completed(res) => res?, Outcome::Preempted(_) => break, }; + let tls_acceptor = tls_acceptor.clone(); debug!("Accepted a PostgreSQL connection (from: {remote_addr})"); let agent = agent.clone(); @@ -492,13 +568,45 @@ pub async fn start( let ka = TcpKeepalive::new().with_time(Duration::from_secs(10)).with_interval(Duration::from_secs(10)).with_retries(4); sock.set_tcp_keepalive(&ka)?; } - let ssl = peek_for_sslrequest(&mut conn, false).await?; - trace!("SSL? {ssl}"); + let is_sslrequest = peek_for_sslrequest(&mut conn).await?; - let mut framed = Framed::new( - conn, - PgWireMessageServerCodec::new(ClientInfoHolder::new(remote_addr, false)), - ); + // reject non-ssl connections if ssl is required (client cert auth) + if ssl_required && !is_sslrequest { + debug!("rejecting non-ssl connection"); + return Ok(()); + } + + let (mut framed, secured) = match (tls_acceptor, is_sslrequest) { + (Some(tls_acceptor), true) => { + conn.write_all(b"S").await?; + let tls_conn = tls_acceptor.accept(conn).await?; + ( + Framed::new( + Either::Left(tls_conn), + PgWireMessageServerCodec::new(ClientInfoHolder::new( + local_addr, true, + )), + ), + true, + ) + } + (_, is_sslreq) => { + if is_sslreq { + conn.write_all(b"N").await?; + } + ( + Framed::new( + Either::Right(conn), + PgWireMessageServerCodec::new(ClientInfoHolder::new( + local_addr, false, + )), + ), + false, + ) + } + }; + + trace!("SSL ? {secured}"); let msg = match framed.next().await { Some(msg) => msg?, @@ -3271,20 +3379,32 @@ fn field_types( #[cfg(test)] mod tests { - use std::time::{Duration, Instant}; + use std::{ + io::BufReader, + time::{Duration, Instant}, + }; + use camino::Utf8PathBuf; use chrono::{DateTime, Utc}; - use corro_tests::launch_test_agent; + use corro_tests::{launch_test_agent, TestAgent}; + use corro_types::{ + config::PgTlsConfig, + tls::{generate_ca, generate_client_cert, generate_server_cert}, + }; + use rcgen::Certificate; use spawn::wait_for_all_pending_handles; + use tempfile::TempDir; use tokio_postgres::NoTls; + use tokio_postgres_rustls::MakeRustlsConnect; use tripwire::Tripwire; use super::*; - #[tokio::test(flavor = "multi_thread")] - async fn test_pg() -> Result<(), BoxError> { + async fn setup_pg_test_server( + tripwire: Tripwire, + tls_config: Option, + ) -> Result<(TestAgent, PgServer), BoxError> { _ = tracing_subscriber::fmt::try_init(); - let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple(); let tmpdir = tempfile::tempdir()?; @@ -3310,17 +3430,27 @@ mod tests { ) .await?; - let sema = ta.agent.write_sema().clone(); - let server = start( ta.agent.clone(), PgConfig { bind_addr: "127.0.0.1:0".parse()?, + tls: tls_config, }, tripwire, ) .await?; + Ok((ta, server)) + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_pg() -> Result<(), BoxError> { + let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple(); + + let (ta, server) = setup_pg_test_server(tripwire, None).await?; + + let sema = ta.agent.write_sema().clone(); + let conn_str = format!( "host={} port={} user=testuser", server.local_addr.ip(), @@ -3491,6 +3621,193 @@ mod tests { Ok(()) } + struct TestCertificates { + ca_cert: Certificate, + client_cert_signed: String, + client_key: Vec, + ca_file: Utf8PathBuf, + server_cert_file: Utf8PathBuf, + server_key_file: Utf8PathBuf, + } + + async fn generate_and_write_certs(tmpdir: &TempDir) -> Result { + let ca_cert = generate_ca()?; + let (server_cert, server_cert_signed) = generate_server_cert( + &ca_cert.serialize_pem()?, + &ca_cert.serialize_private_key_pem(), + "127.0.0.1".parse()?, + )?; + + let (client_cert, client_cert_signed) = generate_client_cert( + &ca_cert.serialize_pem()?, + &ca_cert.serialize_private_key_pem(), + )?; + + let base_path = Utf8PathBuf::from(tmpdir.path().display().to_string()); + + let cert_file = base_path.join("cert.pem"); + let key_file = base_path.join("cert.key"); + let ca_file = base_path.join("ca.pem"); + + let client_cert_file = base_path.join("client-cert.pem"); + let client_key_file = base_path.join("client-cert.key"); + + tokio::fs::write(&cert_file, &server_cert_signed).await?; + tokio::fs::write(&key_file, server_cert.serialize_private_key_pem()).await?; + + tokio::fs::write(&ca_file, ca_cert.serialize_pem()?).await?; + + tokio::fs::write(&client_cert_file, &client_cert_signed).await?; + tokio::fs::write(&client_key_file, client_cert.serialize_private_key_pem()).await?; + + Ok(TestCertificates { + server_cert_file: cert_file, + server_key_file: key_file, + ca_cert, + client_cert_signed: client_cert_signed, + client_key: client_cert.serialize_private_key_der(), + ca_file, + }) + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_pg_ssl() -> Result<(), BoxError> { + let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple(); + + let tmpdir = TempDir::new()?; + let certs = generate_and_write_certs(&tmpdir).await?; + + let (ta, server) = setup_pg_test_server( + tripwire, + Some(PgTlsConfig { + cert_file: certs.server_cert_file, + key_file: certs.server_key_file, + ca_file: None, + verify_client: false, + }), + ) + .await?; + + let sema = ta.agent.write_sema().clone(); + + let conn_str = format!( + "host={} port={} user=testuser", + server.local_addr.ip(), + server.local_addr.port() + ); + + { + let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); + root_cert_store.add(&rustls::Certificate(certs.ca_cert.serialize_der()?))?; + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + + let connector = MakeRustlsConnect::new(config); + + println!("connecting to: {conn_str}"); + + let (client, client_conn) = tokio_postgres::connect(&conn_str, connector).await?; + + tokio::spawn(client_conn); + + let _permit = sema.acquire().await; + + println!("before query"); + + client.simple_query("SELECT 1").await?; + } + + tripwire_tx.send(()).await.ok(); + tripwire_worker.await; + wait_for_all_pending_handles().await; + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_pg_mtls() -> Result<(), BoxError> { + let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple(); + + let tmpdir = TempDir::new()?; + + let certs = generate_and_write_certs(&tmpdir).await?; + + let (ta, server) = setup_pg_test_server( + tripwire, + Some(PgTlsConfig { + cert_file: certs.server_cert_file, + key_file: certs.server_key_file, + ca_file: Some(certs.ca_file), + verify_client: true, + }), + ) + .await?; + + let sema = ta.agent.write_sema().clone(); + + let conn_str = format!( + "host={} port={} user=testuser", + server.local_addr.ip(), + server.local_addr.port() + ); + + { + let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); + root_cert_store.add(&rustls::Certificate(certs.ca_cert.serialize_der()?))?; + + let client_cert = + rustls_pemfile::certs(&mut BufReader::new(certs.client_cert_signed.as_bytes())) + .map_err(|e| format!("failed to read client cert: {e}"))?; + + let client_cert: Vec = client_cert + .iter() + .map(|cert| rustls::Certificate(cert.clone())) + .collect(); + + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store.clone()) + .with_client_auth_cert(client_cert, rustls::PrivateKey(certs.client_key))?; + + let connector = MakeRustlsConnect::new(config); + + println!("connecting to: {conn_str} with client auth cert"); + let (client, client_conn) = tokio_postgres::connect(&conn_str, connector).await?; + + tokio::spawn(client_conn); + + println!("successfully connected!"); + + let _permit = sema.acquire().await; + + client.simple_query("SELECT 1").await?; + + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + + let connector = MakeRustlsConnect::new(config); + + println!("connecting to: {conn_str} without client auth cert"); + let result = tokio_postgres::connect(&conn_str, connector).await; + assert!( + result.is_err(), + "expected connect to fail without client auth cert" + ); + + println!("successfully failed to connect without client auth cert"); + } + + tripwire_tx.send(()).await.ok(); + tripwire_worker.await; + wait_for_all_pending_handles().await; + + Ok(()) + } + // #[tokio::test(flavor = "multi_thread")] // async fn test_write_permit_released_on_error() -> Result<(), BoxError> { // _ = tracing_subscriber::fmt::try_init(); diff --git a/crates/corro-types/src/config.rs b/crates/corro-types/src/config.rs old mode 100644 new mode 100755 index e647b068..c4a5fa7f --- a/crates/corro-types/src/config.rs +++ b/crates/corro-types/src/config.rs @@ -135,6 +135,17 @@ pub struct ApiConfig { pub struct PgConfig { #[serde(alias = "addr")] pub bind_addr: SocketAddr, + pub tls: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PgTlsConfig { + pub cert_file: Utf8PathBuf, + pub key_file: Utf8PathBuf, + #[serde(default)] + pub ca_file: Option, + #[serde(default)] + pub verify_client: bool, } #[derive(Debug, Clone, Serialize, Deserialize)]