From 1d6d2fdbccd83aa534006929c6fc384f1f247eb4 Mon Sep 17 00:00:00 2001 From: brianheineman Date: Mon, 1 Jul 2024 09:37:07 -0600 Subject: [PATCH] feat: utilize sqlx for database management to support PostgreSQL installations that do not bundle psql --- Cargo.lock | 76 +++++++++++++-- Cargo.toml | 1 + examples/zonky/Cargo.toml | 11 +++ examples/zonky/src/main.rs | 35 +++++++ postgresql_embedded/Cargo.toml | 14 ++- postgresql_embedded/src/error.rs | 3 + postgresql_embedded/src/postgresql.rs | 127 ++++++++++++-------------- postgresql_embedded/src/settings.rs | 2 + postgresql_embedded/tests/zonky.rs | 11 +-- 9 files changed, 193 insertions(+), 87 deletions(-) create mode 100644 examples/zonky/Cargo.toml create mode 100644 examples/zonky/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 3a3cfcf..96ae1dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1045,7 +1045,7 @@ dependencies = [ "http", "hyper", "hyper-util", - "rustls", + "rustls 0.23.10", "rustls-native-certs", "rustls-pki-types", "tokio", @@ -1840,6 +1840,7 @@ dependencies = [ "postgresql_commands", "rand", "semver", + "sqlx", "target-triple", "tempfile", "test-log", @@ -1891,7 +1892,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls", + "rustls 0.23.10", "thiserror", "tokio", "tracing", @@ -1907,7 +1908,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls", + "rustls 0.23.10", "slab", "thiserror", "tinyvec", @@ -2083,9 +2084,9 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls", + "rustls 0.23.10", "rustls-native-certs", - "rustls-pemfile", + "rustls-pemfile 2.1.2", "rustls-pki-types", "serde", "serde_json", @@ -2224,6 +2225,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "ring", + "rustls-webpki 0.101.7", + "sct", +] + [[package]] name = "rustls" version = "0.23.10" @@ -2233,7 +2245,7 @@ dependencies = [ "once_cell", "ring", "rustls-pki-types", - "rustls-webpki", + "rustls-webpki 0.102.4", "subtle", "zeroize", ] @@ -2245,12 +2257,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" dependencies = [ "openssl-probe", - "rustls-pemfile", + "rustls-pemfile 2.1.2", "rustls-pki-types", "schannel", "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.7", +] + [[package]] name = "rustls-pemfile" version = "2.1.2" @@ -2267,6 +2288,16 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustls-webpki" version = "0.102.4" @@ -2308,6 +2339,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.11.0" @@ -2543,9 +2584,12 @@ dependencies = [ "indexmap", "log", "memchr", + "native-tls", "once_cell", "paste", "percent-encoding", + "rustls 0.21.12", + "rustls-pemfile 1.0.4", "serde", "serde_json", "sha2", @@ -2556,6 +2600,7 @@ dependencies = [ "tokio-stream", "tracing", "url", + "webpki-roots", ] [[package]] @@ -2952,7 +2997,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls", + "rustls 0.23.10", "rustls-pki-types", "tokio", ] @@ -3289,6 +3334,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.25.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" + [[package]] name = "whoami" version = "1.5.1" @@ -3578,6 +3629,15 @@ dependencies = [ "zstd", ] +[[package]] +name = "zonky" +version = "0.12.0" +dependencies = [ + "postgresql_archive", + "postgresql_embedded", + "tokio", +] + [[package]] name = "zopfli" version = "0.8.1" diff --git a/Cargo.toml b/Cargo.toml index a5b0e09..5e08298 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ serde_json = "1.0.118" sha1 = "0.10.6" sha2 = "0.10.8" sha3 = "0.10.8" +sqlx = { version = "0.7.4", default-features = false, features = ["postgres"] } tar = "0.4.41" target-triple = "0.1.3" test-log = "0.2.16" diff --git a/examples/zonky/Cargo.toml b/examples/zonky/Cargo.toml new file mode 100644 index 0000000..469b823 --- /dev/null +++ b/examples/zonky/Cargo.toml @@ -0,0 +1,11 @@ +[package] +edition.workspace = true +name = "zonky" +publish = false +license.workspace = true +version.workspace = true + +[dependencies] +postgresql_archive = { path = "../../postgresql_archive" } +postgresql_embedded = { path = "../../postgresql_embedded" } +tokio = { workspace = true, features = ["full"] } diff --git a/examples/zonky/src/main.rs b/examples/zonky/src/main.rs new file mode 100644 index 0000000..b9ef79f --- /dev/null +++ b/examples/zonky/src/main.rs @@ -0,0 +1,35 @@ +#![forbid(unsafe_code)] +#![deny(clippy::pedantic)] + +use postgresql_archive::configuration::zonky; +use postgresql_archive::VersionReq; +use postgresql_embedded::{PostgreSQL, Result, Settings}; + +#[tokio::main] +async fn main() -> Result<()> { + let settings = Settings { + releases_url: zonky::URL.to_string(), + version: VersionReq::parse("=16.2.0")?, + ..Default::default() + }; + let mut postgresql = PostgreSQL::new(settings); + postgresql.setup().await?; + postgresql.start().await?; + + let database_name = "test"; + postgresql.create_database(database_name).await?; + postgresql.database_exists(database_name).await?; + postgresql.drop_database(database_name).await?; + + postgresql.stop().await +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_main() -> Result<()> { + main() + } +} diff --git a/postgresql_embedded/Cargo.toml b/postgresql_embedded/Cargo.toml index ea831d0..07680dc 100644 --- a/postgresql_embedded/Cargo.toml +++ b/postgresql_embedded/Cargo.toml @@ -24,6 +24,7 @@ postgresql_archive = { path = "../postgresql_archive", version = "0.12.0", defau postgresql_commands = { path = "../postgresql_commands", version = "0.12.0" } rand = { workspace = true } semver = { workspace = true } +sqlx = { workspace = true, features = ["runtime-tokio"] } tempfile = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["full"], optional = true } @@ -39,11 +40,18 @@ tokio = { workspace = true, features = ["full"] } default = ["rustls-tls"] blocking = ["tokio"] bundled = [] -native-tls = ["postgresql_archive/native-tls"] -rustls-tls = ["postgresql_archive/rustls-tls"] +native-tls = [ + "postgresql_archive/native-tls", + "sqlx/tls-native-tls", +] +rustls-tls = [ + "postgresql_archive/rustls-tls", + "sqlx/tls-rustls", +] tokio = [ "dep:tokio", - "postgresql_commands/tokio" + "postgresql_commands/tokio", + "sqlx/runtime-tokio", ] [package.metadata.release] diff --git a/postgresql_embedded/src/error.rs b/postgresql_embedded/src/error.rs index 1d1dfc8..4835449 100644 --- a/postgresql_embedded/src/error.rs +++ b/postgresql_embedded/src/error.rs @@ -15,6 +15,9 @@ pub enum Error { /// Error when the database could not be created #[error(transparent)] CreateDatabaseError(anyhow::Error), + /// Error when accessing the database + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), /// Error when determining if the database exists #[error(transparent)] DatabaseExistsError(anyhow::Error), diff --git a/postgresql_embedded/src/postgresql.rs b/postgresql_embedded/src/postgresql.rs index 0a1f97a..83e5ccd 100644 --- a/postgresql_embedded/src/postgresql.rs +++ b/postgresql_embedded/src/postgresql.rs @@ -1,6 +1,6 @@ use crate::error::Error::{DatabaseInitializationError, DatabaseStartError, DatabaseStopError}; use crate::error::Result; -use crate::settings::{Settings, BOOTSTRAP_SUPERUSER}; +use crate::settings::{Settings, BOOTSTRAP_DATABASE, BOOTSTRAP_SUPERUSER}; use postgresql_archive::get_version; use postgresql_archive::{extract, get_archive}; use postgresql_archive::{ExactVersion, ExactVersionReq}; @@ -8,12 +8,12 @@ use postgresql_commands::initdb::InitDbBuilder; use postgresql_commands::pg_ctl::Mode::{Start, Stop}; use postgresql_commands::pg_ctl::PgCtlBuilder; use postgresql_commands::pg_ctl::ShutdownMode::Fast; -use postgresql_commands::psql::PsqlBuilder; #[cfg(feature = "tokio")] use postgresql_commands::AsyncCommandExecutor; use postgresql_commands::CommandBuilder; #[cfg(not(feature = "tokio"))] use postgresql_commands::CommandExecutor; +use sqlx::{PgPool, Row}; use std::fs::{remove_dir_all, remove_file}; use std::io::prelude::*; use std::net::TcpListener; @@ -283,36 +283,39 @@ impl PostgreSQL { } } + /// Get a connection pool to the bootstrap database. + async fn get_pool(&self) -> Result { + let mut settings = self.settings.clone(); + settings.username = BOOTSTRAP_SUPERUSER.to_string(); + let database_url = settings.url(BOOTSTRAP_DATABASE); + let pool = PgPool::connect(database_url.as_str()).await?; + Ok(pool) + } + /// Create a new database with the given name. #[instrument(skip(self))] pub async fn create_database(&self, database_name: S) -> Result<()> where S: AsRef + std::fmt::Debug, { + let database_name = database_name.as_ref(); debug!( - "Creating database {} for {}:{}", - database_name.as_ref(), - self.settings.host, - self.settings.port + "Creating database {database_name} for {host}:{port}", + host = self.settings.host, + port = self.settings.port ); - let psql = PsqlBuilder::from(&self.settings) - .env(PGDATABASE, "") - .command(format!("CREATE DATABASE \"{}\"", database_name.as_ref())) - .username(BOOTSTRAP_SUPERUSER) - .no_psqlrc(); - - match self.execute_command(psql).await { - Ok((_stdout, _stderr)) => { - debug!( - "Created database {} for {}:{}", - database_name.as_ref(), - self.settings.host, - self.settings.port - ); - Ok(()) - } - Err(error) => Err(CreateDatabaseError(error.into())), - } + let pool = self.get_pool().await?; + sqlx::query(format!("CREATE DATABASE \"{database_name}\"").as_str()) + .execute(&pool) + .await + .map_err(|error| CreateDatabaseError(error.into()))?; + pool.close().await; + debug!( + "Created database {database_name} for {host}:{port}", + host = self.settings.host, + port = self.settings.port + ); + Ok(()) } /// Check if a database with the given name exists. @@ -321,29 +324,22 @@ impl PostgreSQL { where S: AsRef + std::fmt::Debug, { + let database_name = database_name.as_ref(); debug!( - "Checking if database {} exists for {}:{}", - database_name.as_ref(), - self.settings.host, - self.settings.port + "Checking if database {database_name} exists for {host}:{port}", + host = self.settings.host, + port = self.settings.port ); - let psql = PsqlBuilder::from(&self.settings) - .env(PGDATABASE, "") - .command(format!( - "SELECT 1 FROM pg_database WHERE datname='{}'", - database_name.as_ref() - )) - .username(BOOTSTRAP_SUPERUSER) - .no_psqlrc() - .tuples_only(); - - match self.execute_command(psql).await { - Ok((stdout, _stderr)) => match stdout.trim() { - "1" => Ok(true), - _ => Ok(false), - }, - Err(error) => Err(DatabaseExistsError(error.into())), - } + let pool = self.get_pool().await?; + let row = sqlx::query("SELECT COUNT(*) FROM pg_database WHERE datname = $1") + .bind(database_name.to_string()) + .fetch_one(&pool) + .await + .map_err(|error| DatabaseExistsError(error.into()))?; + let count: i64 = row.get(0); + pool.close().await; + + Ok(count == 1) } /// Drop a database with the given name. @@ -352,33 +348,24 @@ impl PostgreSQL { where S: AsRef + std::fmt::Debug, { + let database_name = database_name.as_ref(); debug!( - "Dropping database {} for {}:{}", - database_name.as_ref(), - self.settings.host, - self.settings.port + "Dropping database {database_name} for {host}:{port}", + host = self.settings.host, + port = self.settings.port ); - let psql = PsqlBuilder::from(&self.settings) - .env(PGDATABASE, "") - .command(format!( - "DROP DATABASE IF EXISTS \"{}\"", - database_name.as_ref() - )) - .username(BOOTSTRAP_SUPERUSER) - .no_psqlrc(); - - match self.execute_command(psql).await { - Ok((_stdout, _stderr)) => { - debug!( - "Dropped database {} for {}:{}", - database_name.as_ref(), - self.settings.host, - self.settings.port - ); - Ok(()) - } - Err(error) => Err(DropDatabaseError(error.into())), - } + let pool = self.get_pool().await?; + sqlx::query(format!("DROP DATABASE IF EXISTS \"{database_name}\"").as_str()) + .execute(&pool) + .await + .map_err(|error| DropDatabaseError(error.into()))?; + pool.close().await; + debug!( + "Dropped database {database_name} for {host}:{port}", + host = self.settings.host, + port = self.settings.port + ); + Ok(()) } #[cfg(not(feature = "tokio"))] diff --git a/postgresql_embedded/src/settings.rs b/postgresql_embedded/src/settings.rs index 3677d80..5501044 100644 --- a/postgresql_embedded/src/settings.rs +++ b/postgresql_embedded/src/settings.rs @@ -31,6 +31,8 @@ pub(crate) const ARCHIVE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/post /// `PostgreSQL` superuser pub const BOOTSTRAP_SUPERUSER: &str = "postgres"; +/// `PostgreSQL` database +pub const BOOTSTRAP_DATABASE: &str = "postgres"; /// Database settings #[derive(Clone, Debug, PartialEq)] diff --git a/postgresql_embedded/tests/zonky.rs b/postgresql_embedded/tests/zonky.rs index 8073fdc..1206f4f 100644 --- a/postgresql_embedded/tests/zonky.rs +++ b/postgresql_embedded/tests/zonky.rs @@ -23,12 +23,11 @@ async fn test_zonky() -> Result<()> { postgresql.start().await?; assert_eq!(Status::Started, postgresql.status()); - // TODO: consider updating following methods to use a Rust driver instead of the psql CLI - // let database_name = "test"; - // assert!(!postgresql.database_exists(database_name).await?); - // postgresql.create_database(database_name).await?; - // assert!(postgresql.database_exists(database_name).await?); - // postgresql.drop_database(database_name).await?; + let database_name = "test"; + assert!(!postgresql.database_exists(database_name).await?); + postgresql.create_database(database_name).await?; + assert!(postgresql.database_exists(database_name).await?); + postgresql.drop_database(database_name).await?; postgresql.stop().await?; assert_eq!(Status::Stopped, postgresql.status());