From 964a420bca4d34acfcdd5f54035a230b24726d3c Mon Sep 17 00:00:00 2001 From: brianheineman Date: Fri, 28 Jun 2024 10:11:03 -0600 Subject: [PATCH] feat!: add semantic versioing support and configurable repositories --- Cargo.lock | 31 +- Cargo.toml | 16 +- examples/archive_async/src/main.rs | 5 +- examples/archive_sync/src/main.rs | 5 +- postgresql_archive/Cargo.toml | 2 + postgresql_archive/README.md | 6 - postgresql_archive/benches/archive.rs | 6 +- postgresql_archive/src/archive.rs | 366 +++--------- postgresql_archive/src/blocking/archive.rs | 66 +-- postgresql_archive/src/blocking/mod.rs | 2 +- postgresql_archive/src/error.rs | 48 +- postgresql_archive/src/lib.rs | 31 +- .../src/repository/github/mod.rs | 2 + .../github/models.rs} | 0 .../src/repository/github/repository.rs | 554 ++++++++++++++++++ postgresql_archive/src/repository/mod.rs | 5 + postgresql_archive/src/repository/model.rs | 114 ++++ postgresql_archive/src/repository/registry.rs | 135 +++++ postgresql_archive/src/version.rs | 300 +++------- postgresql_archive/tests/archive.rs | 83 +-- postgresql_archive/tests/blocking.rs | 37 +- postgresql_embedded/Cargo.toml | 1 + postgresql_embedded/build/bundle.rs | 14 +- .../src/blocking/postgresql.rs | 7 +- postgresql_embedded/src/error.rs | 5 +- postgresql_embedded/src/lib.rs | 39 +- postgresql_embedded/src/postgresql.rs | 49 +- postgresql_embedded/src/settings.rs | 32 +- 28 files changed, 1207 insertions(+), 754 deletions(-) create mode 100644 postgresql_archive/src/repository/github/mod.rs rename postgresql_archive/src/{github.rs => repository/github/models.rs} (100%) create mode 100644 postgresql_archive/src/repository/github/repository.rs create mode 100644 postgresql_archive/src/repository/mod.rs create mode 100644 postgresql_archive/src/repository/model.rs create mode 100644 postgresql_archive/src/repository/registry.rs diff --git a/Cargo.lock b/Cargo.lock index 2820edf..d228cca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1133,9 +1133,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "matchers" @@ -1620,6 +1620,7 @@ dependencies = [ "reqwest-middleware", "reqwest-retry", "reqwest-tracing", + "semver", "serde", "serde_json", "sha2", @@ -1630,6 +1631,7 @@ dependencies = [ "thiserror", "tokio", "tracing", + "url", ] [[package]] @@ -1655,6 +1657,7 @@ dependencies = [ "postgresql_archive", "postgresql_commands", "rand", + "semver", "tempfile", "test-log", "thiserror", @@ -1902,9 +1905,9 @@ dependencies = [ [[package]] name = "reqwest-middleware" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a45d100244a467870f6cb763c4484d010a6bed6bd610b3676e3825d93fb4cfbd" +checksum = "39346a33ddfe6be00cbc17a34ce996818b97b230b87229f10114693becca1268" dependencies = [ "anyhow", "async-trait", @@ -1917,9 +1920,9 @@ dependencies = [ [[package]] name = "reqwest-retry" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40f342894422862af74c50e1e9601cf0931accc9c6981e5eb413c46603b616b5" +checksum = "cf2a94ba69ceb30c42079a137e2793d6d0f62e581a24c06cd4e9bb32e973c7da" dependencies = [ "anyhow", "async-trait", @@ -1939,9 +1942,9 @@ dependencies = [ [[package]] name = "reqwest-tracing" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b253954a1979e02eabccd7e9c3d61d8f86576108baa160775e7f160bb4e800a3" +checksum = "71a37668dccbd75e045f26811891dd939f28c38d3b7ca572a4fce4bc462b83ec" dependencies = [ "anyhow", "async-trait", @@ -1955,12 +1958,10 @@ dependencies = [ [[package]] name = "retry-policies" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "493b4243e32d6eedd29f9a398896e35c6943a123b55eec97dcaee98310d25810" +checksum = "5875471e6cab2871bc150ecb8c727db5113c9338cc3354dc5ee3425b6aa40a1c" dependencies = [ - "anyhow", - "chrono", "rand", ] @@ -2131,6 +2132,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + [[package]] name = "serde" version = "1.0.203" diff --git a/Cargo.toml b/Cargo.toml index c49eb5c..9c6d458 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,15 @@ [workspace] -default-members = ["postgresql_archive", "postgresql_commands", "postgresql_embedded"] -members = ["examples/*", "postgresql_archive", "postgresql_commands", "postgresql_embedded"] +default-members = [ + "postgresql_archive", + "postgresql_commands", + "postgresql_embedded", +] +members = [ + "examples/*", + "postgresql_archive", + "postgresql_commands", + "postgresql_embedded", +] resolver = "2" [workspace.package] @@ -28,8 +37,9 @@ rand = "0.8.5" regex = "1.10.5" reqwest = { version = "0.12.5", default-features = false } reqwest-middleware = "0.3.1" -reqwest-retry = "0.5.0" +reqwest-retry = "0.6.0" reqwest-tracing = "0.5.0" +semver = "1.0.23" serde = "1.0.203" serde_json = "1.0.118" sha2 = "0.10.8" diff --git a/examples/archive_async/src/main.rs b/examples/archive_async/src/main.rs index 278b1a2..33d12c9 100644 --- a/examples/archive_async/src/main.rs +++ b/examples/archive_async/src/main.rs @@ -1,11 +1,12 @@ #![forbid(unsafe_code)] #![deny(clippy::pedantic)] -use postgresql_archive::{extract, get_archive, Result, DEFAULT_RELEASES_URL, LATEST}; +use postgresql_archive::{extract, get_archive, Result, VersionReq, DEFAULT_POSTGRESQL_URL}; #[tokio::main] async fn main() -> Result<()> { - let (archive_version, archive) = get_archive(DEFAULT_RELEASES_URL, &LATEST).await?; + let version_req = VersionReq::STAR; + let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req).await?; let out_dir = tempfile::tempdir()?.into_path(); extract(&archive, &out_dir).await?; println!( diff --git a/examples/archive_sync/src/main.rs b/examples/archive_sync/src/main.rs index db3d1d9..4214e0b 100644 --- a/examples/archive_sync/src/main.rs +++ b/examples/archive_sync/src/main.rs @@ -2,10 +2,11 @@ #![deny(clippy::pedantic)] use postgresql_archive::blocking::{extract, get_archive}; -use postgresql_archive::{Result, DEFAULT_RELEASES_URL, LATEST}; +use postgresql_archive::{Result, VersionReq, DEFAULT_POSTGRESQL_URL}; fn main() -> Result<()> { - let (archive_version, archive) = get_archive(DEFAULT_RELEASES_URL, &LATEST)?; + let version_req = VersionReq::STAR; + let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req)?; let out_dir = tempfile::tempdir()?.into_path(); extract(&archive, &out_dir)?; println!( diff --git a/postgresql_archive/Cargo.toml b/postgresql_archive/Cargo.toml index 5651d66..47211b7 100644 --- a/postgresql_archive/Cargo.toml +++ b/postgresql_archive/Cargo.toml @@ -24,6 +24,7 @@ reqwest = { workspace = true, default-features = false, features = ["json"] } reqwest-middleware = { workspace = true } reqwest-retry = { workspace = true } reqwest-tracing = { workspace = true } +semver = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } sha2 = { workspace = true } @@ -33,6 +34,7 @@ tempfile = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["full"], optional = true } tracing = { workspace = true, features = ["log"] } +url = "2.5.2" [dev-dependencies] criterion = { workspace = true } diff --git a/postgresql_archive/README.md b/postgresql_archive/README.md index 5b2f67d..b0b0737 100644 --- a/postgresql_archive/README.md +++ b/postgresql_archive/README.md @@ -98,12 +98,6 @@ Licensed under either of at your option. -PostgreSQL is covered under [The PostgreSQL License](https://opensource.org/licenses/postgresql). - -## Notes - -Uses PostgreSQL binaries from [theseus-rs/postgresql-binaries](https://github.com/theseus-rs/postgresql-binaries). - ## Contribution Unless you explicitly state otherwise, any contribution intentionally submitted diff --git a/postgresql_archive/benches/archive.rs b/postgresql_archive/benches/archive.rs index 9234591..906374f 100644 --- a/postgresql_archive/benches/archive.rs +++ b/postgresql_archive/benches/archive.rs @@ -1,7 +1,7 @@ use bytes::Bytes; use criterion::{criterion_group, criterion_main, Criterion}; use postgresql_archive::blocking::{extract, get_archive}; -use postgresql_archive::{Result, DEFAULT_RELEASES_URL, LATEST}; +use postgresql_archive::{Result, VersionReq, DEFAULT_POSTGRESQL_URL}; use std::fs::{create_dir_all, remove_dir_all}; use std::time::Duration; @@ -10,8 +10,8 @@ fn benchmarks(criterion: &mut Criterion) { } fn bench_extract(criterion: &mut Criterion) -> Result<()> { - let version = &LATEST; - let (_archive_version, archive) = get_archive(DEFAULT_RELEASES_URL, version)?; + let version_req = VersionReq::STAR; + let (_archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req)?; criterion.bench_function("extract", |bencher| { bencher.iter(|| { diff --git a/postgresql_archive/src/archive.rs b/postgresql_archive/src/archive.rs index fc54cee..2ae3c6a 100644 --- a/postgresql_archive/src/archive.rs +++ b/postgresql_archive/src/archive.rs @@ -1,284 +1,78 @@ -//! Manage PostgreSQL archive +//! Manage PostgreSQL archives #![allow(dead_code)] -use crate::error::Error::{AssetHashNotFound, AssetNotFound, ReleaseNotFound, Unexpected}; +use crate::error::Error::Unexpected; use crate::error::Result; -use crate::github::{Asset, Release}; -use crate::version::Version; -use crate::Error::ArchiveHashMismatch; +use crate::repository; use bytes::Bytes; use flate2::bufread::GzDecoder; -use http::Extensions; use human_bytes::human_bytes; use num_format::{Locale, ToFormattedString}; -use regex::Regex; -use reqwest::{header, Request, Response}; -use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next}; -use reqwest_retry::policies::ExponentialBackoff; -use reqwest_retry::RetryTransientMiddleware; -use reqwest_tracing::TracingMiddleware; -use sha2::{Digest, Sha256}; +use semver::{Version, VersionReq}; use std::fs::{create_dir_all, remove_dir_all, remove_file, rename, File}; use std::io::{copy, BufReader, Cursor}; use std::path::{Path, PathBuf}; -use std::str::FromStr; use std::thread::sleep; use std::time::Duration; use tar::Archive; use tracing::{debug, instrument, warn}; -const GITHUB_API_VERSION_HEADER: &str = "X-GitHub-Api-Version"; -const GITHUB_API_VERSION: &str = "2022-11-28"; -pub const DEFAULT_RELEASES_URL: &str = - "https://api.github.com/repos/theseus-rs/postgresql-binaries/releases"; +pub const DEFAULT_POSTGRESQL_URL: &str = "https://github.com/theseus-rs/postgresql-binaries"; -lazy_static! { - static ref GITHUB_TOKEN: Option = match std::env::var("GITHUB_TOKEN") { - Ok(token) => { - debug!("GITHUB_TOKEN environment variable found"); - Some(token) - } - Err(_) => None, - }; -} - -lazy_static! { - static ref USER_AGENT: String = format!( - "{PACKAGE}/{VERSION}", - PACKAGE = env!("CARGO_PKG_NAME"), - VERSION = env!("CARGO_PKG_VERSION") - ); -} - -/// Middleware to add GitHub headers to the request. If a GitHub token is set, then it is added as a -/// bearer token. This is used to authenticate with the GitHub API to increase the rate limit. -#[derive(Debug)] -struct GithubMiddleware; - -impl GithubMiddleware { - #[allow(clippy::unnecessary_wraps)] - fn add_github_headers(request: &mut Request) -> Result<()> { - let headers = request.headers_mut(); - - headers.append( - GITHUB_API_VERSION_HEADER, - GITHUB_API_VERSION.parse().unwrap(), - ); - headers.append(header::USER_AGENT, USER_AGENT.parse().unwrap()); - - if let Some(token) = &*GITHUB_TOKEN { - headers.append( - header::AUTHORIZATION, - format!("Bearer {token}").parse().unwrap(), - ); - } - - Ok(()) - } -} - -#[async_trait::async_trait] -impl Middleware for GithubMiddleware { - async fn handle( - &self, - mut request: Request, - extensions: &mut Extensions, - next: Next<'_>, - ) -> reqwest_middleware::Result { - match GithubMiddleware::add_github_headers(&mut request) { - Ok(()) => next.run(request, extensions).await, - Err(error) => Err(reqwest_middleware::Error::Middleware(error.into())), - } - } -} - -/// Creates a new reqwest client with middleware for tracing, GitHub, and retrying transient errors. -fn reqwest_client() -> ClientWithMiddleware { - let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); - ClientBuilder::new(reqwest::Client::new()) - .with(TracingMiddleware::default()) - .with(GithubMiddleware) - .with(RetryTransientMiddleware::new_with_policy(retry_policy)) - .build() -} - -/// Gets a release from GitHub for a given [version](Version) of PostgreSQL. If a release for the -/// [version](Version) is not found, then a [ReleaseNotFound] error is returned. -#[instrument(level = "debug")] -async fn get_release(releases_url: &str, version: &Version) -> Result { - let client = reqwest_client(); - - debug!("Attempting to locate release for version {version}"); - - if version.minor.is_some() && version.release.is_some() { - let request = client.get(format!("{releases_url}/tags/{version}")); - let response = request.send().await?.error_for_status()?; - let release = response.json::().await?; - - debug!("Release found for version {version}"); - return Ok(release); - } - - let mut result: Option = None; - let mut page = 1; - - loop { - let request = client - .get(releases_url) - .query(&[("page", page.to_string().as_str()), ("per_page", "100")]); - let response = request.send().await?.error_for_status()?; - let response_releases = response.json::>().await?; - if response_releases.is_empty() { - break; - } - - for release in response_releases { - let Ok(release_version) = Version::from_str(&release.tag_name) else { - warn!("Failed to parse release version {}", release.tag_name); - continue; - }; - - if version.matches(&release_version) { - match &result { - Some(result_release) => { - let result_version = Version::from_str(&result_release.tag_name)?; - if release_version > result_version { - result = Some(release); - } - } - None => { - result = Some(release); - } - } - } - } - - page += 1; - } - - match result { - Some(release) => { - let release_version = Version::from_str(&release.tag_name)?; - debug!("Release {release_version} found for version {version}"); - Ok(release) - } - None => Err(ReleaseNotFound(version.to_string())), - } -} - -/// Gets the version of PostgreSQL for the specified [version](Version). If the version minor or release is not -/// specified, then the latest version is returned. If a release for the [version](Version) is not found, then a -/// [ReleaseNotFound] error is returned. +/// Gets the version for the specified [version requirement](VersionReq). If a version for the +/// [version requirement](VersionReq) is not found, then a [ReleaseNotFound] error is returned. +/// +/// # Arguments +/// * `url` - The URL to released archives. +/// * `version_req` - The version requirement. +/// +/// # Returns +/// * The version matching the requirement. +/// +/// # Errors +/// * If the version is not found. #[instrument(level = "debug")] -pub async fn get_version(releases_url: &str, version: &Version) -> Result { - let release = get_release(releases_url, version).await?; - Version::from_str(&release.tag_name) +pub async fn get_version(url: &str, version_req: &VersionReq) -> Result { + let repository = repository::registry::get(url)?; + let version = repository.get_version(version_req).await?; + Ok(version) } -/// Gets the assets for a given [version](Version) of PostgreSQL and -/// [target](https://doc.rust-lang.org/nightly/rustc/platform-support.html). -/// If the [version](Version) or [target](https://doc.rust-lang.org/nightly/rustc/platform-support.html) -/// is not found, then an [error](crate::error::Error) is returned. +/// Gets the archive for a given [version requirement](VersionReq) that passes the default +/// matcher. If no archive is found for the [version requirement](VersionReq) and matcher then +/// an [error](crate::error::Error) is returned. /// -/// Two assets are returned. The first [asset](Asset) is the archive, and the second [asset](Asset) is the archive hash. -#[instrument(level = "debug", skip(target))] -async fn get_asset>( - releases_url: &str, - version: &Version, - target: S, -) -> Result<(Version, Asset, Asset)> { - let release = get_release(releases_url, version).await?; - let asset_version = Version::from_str(&release.tag_name)?; - let mut asset: Option = None; - let mut asset_hash: Option = None; - let asset_name = format!("postgresql-{}-{}.tar.gz", asset_version, target.as_ref()); - let asset_hash_name = format!("{asset_name}.sha256"); - - for release_asset in release.assets { - if release_asset.name == asset_name { - asset = Some(release_asset); - } else if release_asset.name == asset_hash_name { - asset_hash = Some(release_asset); - } - - if asset.is_some() && asset_hash.is_some() { - break; - } - } - - match (asset, asset_hash) { - (Some(asset), Some(asset_hash)) => Ok((asset_version, asset, asset_hash)), - (_, None) | (None, _) => Err(AssetNotFound(asset_name.to_string())), - } -} - -/// Gets the archive for a given [version](Version) of PostgreSQL for the current target. -/// If the [version](Version) is not found for this target, then an -/// [error](crate::error::Error) is returned. +/// # Arguments +/// * `url` - The URL to the archive resources. +/// * `version_req` - The version requirement. /// -/// Returns the archive version and bytes. -#[instrument] -pub async fn get_archive(releases_url: &str, version: &Version) -> Result<(Version, Bytes)> { - get_archive_for_target(releases_url, version, target_triple::TARGET).await -} - -/// Gets the archive for a given [version](Version) of PostgreSQL and -/// [target](https://doc.rust-lang.org/nightly/rustc/platform-support.html). -/// If the [version](Version) or [target](https://doc.rust-lang.org/nightly/rustc/platform-support.html) -/// is not found, then an [error](crate::error::Error) is returned. +/// # Returns +/// * The archive version and bytes. /// -/// Returns the archive version and bytes. -#[allow(clippy::cast_precision_loss)] -#[instrument(level = "debug", skip(target))] -pub async fn get_archive_for_target>( - releases_url: &str, - version: &Version, - target: S, -) -> Result<(Version, Bytes)> { - let (asset_version, asset, asset_hash) = get_asset(releases_url, version, target).await?; - - debug!( - "Downloading archive hash {}", - asset_hash.browser_download_url - ); - let client = reqwest_client(); - let request = client.get(&asset_hash.browser_download_url); - let response = request.send().await?.error_for_status()?; - let text = response.text().await?; - let re = Regex::new(r"[0-9a-f]{64}")?; - let hash = match re.find(&text) { - Some(hash) => hash.as_str().to_string(), - None => return Err(AssetHashNotFound(asset.name)), - }; - debug!( - "Archive hash {} downloaded: {}", - asset_hash.browser_download_url, - human_bytes(text.len() as f64) - ); - - debug!("Downloading archive {}", asset.browser_download_url); - let request = client.get(&asset.browser_download_url); - let response = request.send().await?.error_for_status()?; - let archive: Bytes = response.bytes().await?; - debug!( - "Archive {} downloaded: {}", - asset.browser_download_url, - human_bytes(archive.len() as f64) - ); - - let mut hasher = Sha256::new(); - hasher.update(&archive); - let archive_hash = hex::encode(hasher.finalize()); - - if archive_hash != hash { - return Err(ArchiveHashMismatch { archive_hash, hash }); - } - - Ok((asset_version, archive)) +/// # Errors +/// * If the archive is not found. +/// * If the archive cannot be downloaded. +#[instrument] +pub async fn get_archive(url: &str, version_req: &VersionReq) -> Result<(Version, Bytes)> { + let repository = repository::registry::get(url)?; + let archive = repository.get_archive(version_req).await?; + let version = archive.version().clone(); + let archive_bytes = archive.bytes().to_vec(); + let bytes = Bytes::from(archive_bytes.clone()); + Ok((version, bytes)) } /// Acquires a lock file in the [out_dir](Path) to prevent multiple processes from extracting the /// archive at the same time. +/// +/// # Arguments +/// * `out_dir` - The directory to extract the archive to. +/// +/// # Returns +/// * The lock file. +/// +/// # Errors +/// * If the lock file cannot be acquired. #[instrument(level = "debug")] fn acquire_lock(out_dir: &Path) -> Result { let lock_file = out_dir.join("postgresql-archive.lock"); @@ -324,6 +118,16 @@ fn acquire_lock(out_dir: &Path) -> Result { } /// Extracts the compressed tar [bytes](Bytes) to the [out_dir](Path). +/// +/// # Arguments +/// * `bytes` - The compressed tar bytes. +/// * `out_dir` - The directory to extract the tar to. +/// +/// # Returns +/// * The extracted files. +/// +/// # Errors +/// Returns an error if the extraction fails. #[allow(clippy::cast_precision_loss)] #[instrument(skip(bytes))] pub async fn extract(bytes: &Bytes, out_dir: &Path) -> Result<()> { @@ -434,51 +238,21 @@ pub async fn extract(bytes: &Bytes, out_dir: &Path) -> Result<()> { #[cfg(test)] mod tests { use super::*; - use test_log::test; - - /// Use a known, fully defined version to speed up test execution - const VERSION: Version = Version::new(16, Some(1), Some(0)); - const INVALID_VERSION: Version = Version::new(1, Some(0), Some(0)); - - #[test(tokio::test)] - async fn test_get_release() -> Result<()> { - let _ = get_release(DEFAULT_RELEASES_URL, &VERSION).await?; - Ok(()) - } - - #[test(tokio::test)] - async fn test_get_release_version_not_found() -> Result<()> { - let release = get_release(DEFAULT_RELEASES_URL, &INVALID_VERSION).await; - assert!(release.is_err()); - Ok(()) - } - - #[test(tokio::test)] - async fn test_get_asset() -> Result<()> { - let target_triple = "x86_64-unknown-linux-musl".to_string(); - let (asset_version, asset, asset_hash) = - get_asset(DEFAULT_RELEASES_URL, &VERSION, &target_triple).await?; - assert!(asset_version.matches(&VERSION)); - assert!(asset.name.contains(&target_triple)); - assert!(asset_hash.name.contains(&target_triple)); - assert!(asset_hash.name.starts_with(asset.name.as_str())); - assert!(asset_hash.name.ends_with(".sha256")); - Ok(()) - } - #[test(tokio::test)] - async fn test_get_asset_version_not_found() -> Result<()> { - let target_triple = "x86_64-unknown-linux-musl".to_string(); - let result = get_asset(DEFAULT_RELEASES_URL, &INVALID_VERSION, &target_triple).await; - assert!(result.is_err()); + #[tokio::test] + async fn test_get_version() -> Result<()> { + let version_req = VersionReq::parse("=16.3.0")?; + let version = get_version(DEFAULT_POSTGRESQL_URL, &version_req).await?; + assert_eq!(Version::new(16, 3, 0), version); Ok(()) } - #[test(tokio::test)] - async fn test_get_asset_target_not_found() -> Result<()> { - let target_triple = "wasm64-unknown-unknown".to_string(); - let result = get_asset(DEFAULT_RELEASES_URL, &VERSION, &target_triple).await; - assert!(result.is_err()); + #[tokio::test] + async fn test_get_archive() -> Result<()> { + let version_req = VersionReq::parse("=16.3.0")?; + let (version, bytes) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req).await?; + assert_eq!(Version::new(16, 3, 0), version); + assert!(!bytes.is_empty()); Ok(()) } } diff --git a/postgresql_archive/src/blocking/archive.rs b/postgresql_archive/src/blocking/archive.rs index c46c49c..df80477 100644 --- a/postgresql_archive/src/blocking/archive.rs +++ b/postgresql_archive/src/blocking/archive.rs @@ -1,4 +1,4 @@ -use crate::Version; +use crate::{Version, VersionReq}; use bytes::Bytes; use std::path::Path; use tokio::runtime::Runtime; @@ -7,58 +7,54 @@ lazy_static! { static ref RUNTIME: Runtime = Runtime::new().unwrap(); } -/// Gets the version of PostgreSQL for the specified [version](Version). If the version minor or release is not -/// specified, then the latest version is returned. If a release for the [version](Version) is not found, then a -/// [ReleaseNotFound](crate::Error::ReleaseNotFound) error is returned. +/// Gets the version for the specified [version requirement](VersionReq). If a version for the +/// [version requirement](VersionReq) is not found, then a [ReleaseNotFound] error is returned. /// -/// # Errors -/// -/// Returns an error if the version is not found. -pub fn get_version(releases_url: &str, version: &Version) -> crate::Result { - RUNTIME - .handle() - .block_on(async move { crate::get_version(releases_url, version).await }) -} - -/// Gets the archive for a given [version](Version) of PostgreSQL for the current target. -/// If the [version](Version) is not found for this target, then an -/// [error](crate::Error) is returned. +/// # Arguments +/// * `url` - The URL to released archives. +/// * `version_req` - The version requirement. /// -/// Returns the archive version and bytes. +/// # Returns +/// * The version matching the requirement. /// /// # Errors -/// -/// Returns an error if the version is not found. -pub fn get_archive(releases_url: &str, version: &Version) -> crate::Result<(Version, Bytes)> { +/// * If the version is not found. +pub fn get_version(url: &str, version_req: &VersionReq) -> crate::Result { RUNTIME .handle() - .block_on(async move { crate::get_archive(releases_url, version).await }) + .block_on(async move { crate::get_version(url, version_req).await }) } -/// Gets the archive for a given [version](Version) of PostgreSQL and -/// [target](https://doc.rust-lang.org/nightly/rustc/platform-support.html). -/// If the [version](Version) or [target](https://doc.rust-lang.org/nightly/rustc/platform-support.html) -/// is not found, then an [error](crate::error::Error) is returned. +/// Gets the archive for a given [version requirement](VersionReq) that passes the default +/// matcher. If no archive is found for the [version requirement](VersionReq) and matcher then +/// an [error](crate::error::Error) is returned. /// -/// Returns the archive version and bytes. +/// # Arguments +/// * `url` - The URL to the archive resources. +/// * `version_req` - The version requirement. /// -/// # Errors +/// # Returns +/// * The archive version and bytes. /// -/// Returns an error if the version or target is not found. -pub fn get_archive_for_target>( - releases_url: &str, - version: &Version, - target: S, -) -> crate::Result<(Version, Bytes)> { +/// # Errors +/// * If the archive is not found. +/// * If the archive cannot be downloaded. +pub fn get_archive(url: &str, version_req: &VersionReq) -> crate::Result<(Version, Bytes)> { RUNTIME .handle() - .block_on(async move { crate::get_archive_for_target(releases_url, version, target).await }) + .block_on(async move { crate::get_archive(url, version_req).await }) } /// Extracts the compressed tar [bytes](Bytes) to the [out_dir](Path). /// -/// # Errors +/// # Arguments +/// * `bytes` - The compressed tar bytes. +/// * `out_dir` - The directory to extract the tar to. +/// +/// # Returns +/// * The extracted files. /// +/// # Errors /// Returns an error if the extraction fails. pub fn extract(bytes: &Bytes, out_dir: &Path) -> crate::Result<()> { RUNTIME diff --git a/postgresql_archive/src/blocking/mod.rs b/postgresql_archive/src/blocking/mod.rs index 21664a4..2264a73 100644 --- a/postgresql_archive/src/blocking/mod.rs +++ b/postgresql_archive/src/blocking/mod.rs @@ -1,3 +1,3 @@ mod archive; -pub use archive::{extract, get_archive, get_archive_for_target, get_version}; +pub use archive::{extract, get_archive, get_version}; diff --git a/postgresql_archive/src/error.rs b/postgresql_archive/src/error.rs index 30da755..b383889 100644 --- a/postgresql_archive/src/error.rs +++ b/postgresql_archive/src/error.rs @@ -5,16 +5,16 @@ pub type Result = core::result::Result; #[derive(Debug, thiserror::Error)] pub enum Error { /// Asset not found - #[error("asset [{0}] not found")] - AssetNotFound(String), + #[error("asset not found")] + AssetNotFound, /// Asset hash not found - #[error("asset hash not found for asset [{0}]")] + #[error("asset hash not found for asset '{0}'")] AssetHashNotFound(String), /// Error when the hash of the archive does not match the expected hash #[error("Archive hash [{archive_hash}] does not match expected hash [{hash}]")] ArchiveHashMismatch { archive_hash: String, hash: String }, /// Invalid version - #[error("version [{0}] is invalid")] + #[error("version '{0}' is invalid")] InvalidVersion(String), /// IO error #[error(transparent)] @@ -23,11 +23,17 @@ pub enum Error { #[error(transparent)] ParseError(anyhow::Error), /// Release not found - #[error("release not found for version [{0}]")] + #[error("release not found for '{0}'")] ReleaseNotFound(String), + /// Repository failure + #[error("{0}")] + RepositoryFailure(String), /// Unexpected error #[error("{0}")] Unexpected(String), + /// Unsupported repository + #[error("unsupported repository for '{0}'")] + UnsupportedRepository(String), } /// Converts a [`regex::Error`] into an [`ParseError`](Error::ParseError) @@ -72,6 +78,13 @@ impl From for Error { } } +/// Converts a [`semver::Error`] into an [`ParseError`](Error::ParseError) +impl From for Error { + fn from(error: semver::Error) -> Self { + Error::IoError(error.into()) + } +} + /// Converts a [`std::path::StripPrefixError`] into an [`ParseError`](Error::ParseError) impl From for Error { fn from(error: std::path::StripPrefixError) -> Self { @@ -86,12 +99,20 @@ impl From for Error { } } +/// Converts a [`url::ParseError`] into an [`ParseError`](Error::ParseError) +impl From for Error { + fn from(error: url::ParseError) -> Self { + Error::ParseError(error.into()) + } +} + /// These are relatively low value tests; they are here to reduce the coverage gap and /// ensure that the error conversions are working as expected. #[cfg(test)] mod test { use super::*; use anyhow::anyhow; + use semver::VersionReq; use std::ops::Add; use std::path::PathBuf; use std::str::FromStr; @@ -136,6 +157,16 @@ mod test { assert_eq!(error.to_string(), "invalid digit found in string"); } + #[test] + fn test_from_semver_error() { + let semver_error = VersionReq::parse("foo").expect_err("semver error"); + let error = Error::from(semver_error); + assert_eq!( + error.to_string(), + "unexpected character 'f' while parsing major version number" + ); + } + #[test] fn test_from_strip_prefix_error() { let path = PathBuf::from("test"); @@ -163,4 +194,11 @@ mod test { let error = Error::from(anyhow_error); assert_eq!(error.to_string(), "test"); } + + #[test] + fn test_from_url_parse_error() { + let parse_error = url::ParseError::EmptyHost; + let error = Error::from(parse_error); + assert_eq!(error.to_string(), "empty host"); + } } diff --git a/postgresql_archive/src/lib.rs b/postgresql_archive/src/lib.rs index 8d8170d..51fc33b 100644 --- a/postgresql_archive/src/lib.rs +++ b/postgresql_archive/src/lib.rs @@ -21,11 +21,11 @@ //! ### Asynchronous API //! //! ```no_run -//! use postgresql_archive::{extract, get_archive, Result, DEFAULT_RELEASES_URL, LATEST}; +//! use postgresql_archive::{extract, get_archive, Result, VersionReq, DEFAULT_POSTGRESQL_URL}; //! //! #[tokio::main] //! async fn main() -> Result<()> { -//! let (archive_version, archive) = get_archive(DEFAULT_RELEASES_URL, &LATEST).await?; +//! let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &VersionReq::STAR).await?; //! let out_dir = std::env::temp_dir(); //! extract(&archive, &out_dir).await //! } @@ -34,10 +34,10 @@ //! ### Synchronous API //! ```no_run //! #[cfg(feature = "blocking")] { -//! use postgresql_archive::{DEFAULT_RELEASES_URL, LATEST}; +//! use postgresql_archive::{VersionReq, DEFAULT_POSTGRESQL_URL}; //! use postgresql_archive::blocking::{extract, get_archive}; //! -//! let (archive_version, archive) = get_archive(DEFAULT_RELEASES_URL, &LATEST).unwrap(); +//! let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &VersionReq::STAR).unwrap(); //! let out_dir = std::env::temp_dir(); //! let result = extract(&archive, &out_dir).unwrap(); //! } @@ -50,9 +50,11 @@ //! //! The following features are available: //! -//! | Name | Description | Default? | -//! |--------------|--------------------------|----------| -//! | `blocking` | Enables the blocking API | No | +//! | Name | Description | Default? | +//! |--------------|----------------------------|----------| +//! | `blocking` | Enables the blocking API | No | +//! | `native-tls` | Enables native-tls support | No | +//! | `rustls-tls` | Enables rustls-tls support | Yes | //! //! ## Supported platforms //! @@ -98,10 +100,6 @@ //! at your option. //! //! PostgreSQL is covered under [The PostgreSQL License](https://opensource.org/licenses/postgresql). -//! -//! ## Notes -//! -//! Uses PostgreSQL binaries from [theseus-rs/postgresql-binaries](https://github.com/theseus-rs/postgresql-binaries). #![forbid(unsafe_code)] #![deny(clippy::pedantic)] @@ -115,11 +113,12 @@ mod archive; #[cfg(feature = "blocking")] pub mod blocking; mod error; -mod github; +mod repository; mod version; -pub use archive::DEFAULT_RELEASES_URL; -pub use archive::{extract, get_archive, get_archive_for_target, get_version}; +pub use archive::DEFAULT_POSTGRESQL_URL; +pub use archive::{extract, get_archive, get_version}; pub use error::{Error, Result}; -#[allow(deprecated)] -pub use version::{Version, LATEST, V12, V13, V14, V15, V16}; +pub use repository::{Archive, Repository}; +pub use semver::{Version, VersionReq}; +pub use version::{ExactVersion, ExactVersionReq}; diff --git a/postgresql_archive/src/repository/github/mod.rs b/postgresql_archive/src/repository/github/mod.rs new file mode 100644 index 0000000..ab588ca --- /dev/null +++ b/postgresql_archive/src/repository/github/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod models; +pub(crate) mod repository; diff --git a/postgresql_archive/src/github.rs b/postgresql_archive/src/repository/github/models.rs similarity index 100% rename from postgresql_archive/src/github.rs rename to postgresql_archive/src/repository/github/models.rs diff --git a/postgresql_archive/src/repository/github/repository.rs b/postgresql_archive/src/repository/github/repository.rs new file mode 100644 index 0000000..bb8426f --- /dev/null +++ b/postgresql_archive/src/repository/github/repository.rs @@ -0,0 +1,554 @@ +use crate::repository::github::models::{Asset, Release}; +use crate::repository::model::Repository; +use crate::Error::{ + ArchiveHashMismatch, AssetHashNotFound, AssetNotFound, ReleaseNotFound, RepositoryFailure, +}; +use crate::{Archive, Result}; +use async_trait::async_trait; +use bytes::Bytes; +use http::{header, Extensions}; +use human_bytes::human_bytes; +use regex::Regex; +use reqwest::{Request, Response}; +use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next}; +use reqwest_retry::policies::ExponentialBackoff; +use reqwest_retry::RetryTransientMiddleware; +use reqwest_tracing::TracingMiddleware; +use semver::{Version, VersionReq}; +use sha2::{Digest, Sha256}; +use std::env; +use std::str::FromStr; +use tracing::{debug, instrument, warn}; +use url::Url; + +const GITHUB_API_VERSION_HEADER: &str = "X-GitHub-Api-Version"; +const GITHUB_API_VERSION: &str = "2022-11-28"; + +lazy_static! { + static ref GITHUB_TOKEN: Option = match std::env::var("GITHUB_TOKEN") { + Ok(token) => { + debug!("GITHUB_TOKEN environment variable found"); + Some(token) + } + Err(_) => None, + }; +} + +lazy_static! { + static ref USER_AGENT: String = format!( + "{PACKAGE}/{VERSION}", + PACKAGE = env!("CARGO_PKG_NAME"), + VERSION = env!("CARGO_PKG_VERSION") + ); +} + +/// GitHub repository. +/// +/// This repository is used to interact with GitHub. The configuration url should be +/// in the format "https://github.com//" +/// (e.g. https://github.com/theseus-rs/postgresql-binaries). +#[derive(Debug)] +pub(crate) struct GitHub { + url: String, +} + +impl GitHub { + /// Creates a new GitHub repository from the specified URL. + /// + /// # Arguments + /// * `url` - The URL to the GitHub repository in the format "https://github.com//" + /// + /// # Returns + /// * The GitHub repository. + /// + /// # Errors + /// * If the URL is invalid. + #[allow(clippy::new_ret_no_self)] + pub fn new(url: &str) -> Result> { + let parsed_url = Url::parse(url)?; + let path = parsed_url.path().trim_start_matches('/'); + let path_parts = path.split('/').collect::>(); + let owner = (*path_parts + .first() + .ok_or_else(|| RepositoryFailure(format!("No owner in URL {url}")))?) + .to_string(); + let repo = (*path_parts + .get(1) + .ok_or_else(|| RepositoryFailure(format!("No repo in URL {url}")))?) + .to_string(); + let url = format!("https://api.github.com/repos/{owner}/{repo}/releases"); + + Ok(Box::new(Self { + url: url.to_string(), + })) + } + + /// Determines if the specified URL is supported by the GitHub repository. + /// + /// # Arguments + /// * `url` - The URL to check for support. + /// + /// # Returns + /// * Whether the URL is supported. + /// + /// # Errors + /// * If the URL cannot be parsed. + pub fn supports(url: &str) -> bool { + let Ok(parsed_url) = Url::parse(url) else { + return false; + }; + let host = parsed_url.host_str().unwrap_or_default(); + host.contains("github.com") + } + + /// Gets the version from the specified tag name. + /// + /// # Arguments + /// * `tag_name` - The tag name. + /// + /// # Returns + /// * The version. + /// + /// # Errors + /// * If the version cannot be parsed. + fn get_version_from_tag_name(tag_name: &str) -> Result { + // Trim and prefix characters from the tag name (e.g., "v16.3.0" -> "16.3.0"). + let tag_name = tag_name.trim_start_matches(|c: char| !c.is_numeric()); + match Version::from_str(tag_name) { + Ok(version) => Ok(version), + Err(error) => { + warn!("Failed to parse version {tag_name}"); + Err(error.into()) + } + } + } + + /// Gets the release for the specified [version requirement](VersionReq). If a release for the + /// [version requirement](VersionReq) is not found, then a [ReleaseNotFound] error is returned. + /// + /// # Arguments + /// * `version_req` - The version requirement. + /// + /// # Returns + /// * The release matching the requirement. + /// + /// # Errors + /// * If the release is not found. + #[instrument(level = "debug")] + async fn get_release(&self, version_req: &VersionReq) -> Result { + debug!("Attempting to locate release for version requirement {version_req}"); + let client = reqwest_client(); + let mut result: Option = None; + let mut page = 1; + + loop { + let request = client + .get(&self.url) + .query(&[("page", page.to_string().as_str()), ("per_page", "100")]); + let response = request.send().await?.error_for_status()?; + let response_releases = response.json::>().await?; + if response_releases.is_empty() { + break; + } + + for release in response_releases { + let tag_name = release.tag_name.clone(); + let Ok(release_version) = Self::get_version_from_tag_name(tag_name.as_str()) else { + warn!("Failed to parse release version {tag_name}"); + continue; + }; + + if version_req.matches(&release_version) { + if let Some(result_release) = &result { + let result_version = + Self::get_version_from_tag_name(result_release.tag_name.as_str())?; + if release_version > result_version { + result = Some(release); + } + } else { + result = Some(release); + } + } + } + + page += 1; + } + + match result { + Some(release) => { + let release_version = Self::get_version_from_tag_name(&release.tag_name)?; + debug!("Release {release_version} found for version requirement {version_req}"); + Ok(release) + } + None => Err(ReleaseNotFound(version_req.to_string())), + } + } + + /// Gets the asset for the specified release that passes the supplied matcher. If an asset for + /// that passes the matcher is not found, then an [AssetNotFound] error is returned. + /// + /// # Arguments + /// * `release` - The release. + /// * `matcher` - The matcher function. + /// + /// # Returns + /// * The asset and hash asset. + /// + /// # Errors + /// * If the asset is not found. + #[instrument(level = "debug", skip(release, matcher))] + fn get_asset( + release: Release, + matcher: impl Fn(&str) -> Result, + ) -> Result<(Asset, Option)> { + let mut release_asset: Option = None; + for asset in &release.assets { + if matcher(asset.name.as_str())? { + release_asset = Some(asset.clone()); + break; + } + } + + let Some(asset) = release_asset else { + return Err(AssetNotFound); + }; + + let mut asset_hash: Option = None; + for asset in release.assets { + if asset.name.ends_with(".sha256") { + asset_hash = Some(asset.clone()); + break; + } + } + + Ok((asset, asset_hash)) + } +} + +#[async_trait] +impl Repository for GitHub { + #[instrument(level = "debug")] + fn name(&self) -> &str { + "GitHub" + } + + #[instrument(level = "debug")] + async fn get_version(&self, version_req: &VersionReq) -> Result { + let release = self.get_release(version_req).await?; + let version = Self::get_version_from_tag_name(release.tag_name.as_str())?; + Ok(version) + } + + #[instrument] + #[allow(clippy::cast_precision_loss)] + async fn get_archive(&self, version_req: &VersionReq) -> Result { + let release = self.get_release(version_req).await?; + let version = Self::get_version_from_tag_name(release.tag_name.as_str())?; + let (asset, asset_hash) = Self::get_asset(release, asset_matcher)?; + let name = asset.name.clone(); + + let client = reqwest_client(); + debug!("Downloading archive {}", asset.browser_download_url); + let request = client.get(&asset.browser_download_url); + let response = request.send().await?.error_for_status()?; + let archive: Bytes = response.bytes().await?; + let bytes = archive.to_vec(); + debug!( + "Archive {} downloaded: {}", + asset.browser_download_url, + human_bytes(archive.len() as f64) + ); + + if let Some(asset_hash) = asset_hash { + debug!( + "Downloading archive hash {}", + asset_hash.browser_download_url + ); + let request = client.get(&asset_hash.browser_download_url); + let response = request.send().await?.error_for_status()?; + let text = response.text().await?; + let re = Regex::new(r"[0-9a-f]{64}")?; + let hash = match re.find(&text) { + Some(hash) => hash.as_str().to_string(), + None => return Err(AssetHashNotFound(asset.name)), + }; + debug!( + "Archive hash {} downloaded: {}", + asset_hash.browser_download_url, + human_bytes(text.len() as f64) + ); + + let mut hasher = Sha256::new(); + hasher.update(&archive); + let archive_hash = hex::encode(hasher.finalize()); + + if archive_hash != hash { + return Err(ArchiveHashMismatch { archive_hash, hash }); + } + } + + let archive = Archive::new(name, version, bytes); + Ok(archive) + } +} + +/// Middleware to add GitHub headers to the request. If a GitHub token is set, then it is added as a +/// bearer token. This is used to authenticate with the GitHub API to increase the rate limit. +#[derive(Debug)] +struct GithubMiddleware; + +impl GithubMiddleware { + #[allow(clippy::unnecessary_wraps)] + fn add_github_headers(request: &mut Request) -> Result<()> { + let headers = request.headers_mut(); + + headers.append( + GITHUB_API_VERSION_HEADER, + GITHUB_API_VERSION.parse().unwrap(), + ); + headers.append(header::USER_AGENT, USER_AGENT.parse().unwrap()); + + if let Some(token) = &*GITHUB_TOKEN { + headers.append( + header::AUTHORIZATION, + format!("Bearer {token}").parse().unwrap(), + ); + } + + Ok(()) + } +} + +#[async_trait::async_trait] +impl Middleware for GithubMiddleware { + async fn handle( + &self, + mut request: Request, + extensions: &mut Extensions, + next: Next<'_>, + ) -> reqwest_middleware::Result { + match GithubMiddleware::add_github_headers(&mut request) { + Ok(()) => next.run(request, extensions).await, + Err(error) => Err(reqwest_middleware::Error::Middleware(error.into())), + } + } +} + +/// Creates a new reqwest client with middleware for tracing, GitHub, and retrying transient errors. +fn reqwest_client() -> ClientWithMiddleware { + let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); + ClientBuilder::new(reqwest::Client::new()) + .with(TracingMiddleware::default()) + .with(GithubMiddleware) + .with(RetryTransientMiddleware::new_with_policy(retry_policy)) + .build() +} + +/// Matcher assets. +/// +/// # Arguments +/// * `name` - The name of the asset. +/// +/// # Returns +/// * Whether the asset matches. +/// +/// # Errors +/// * If the asset matcher fails. +fn asset_matcher(name: &str) -> Result { + if !name.ends_with(".tar.gz") { + return Ok(false); + } + let target_re = regex(target_triple::TARGET)?; + if target_re.is_match(name) { + return Ok(true); + } + let os = env::consts::OS; + let os_re = regex(os)?; + let matches_os = match os { + "macos" => { + let darwin_re = regex("darwin")?; + os_re.is_match(name) || darwin_re.is_match(name) + } + _ => os_re.is_match(name), + }; + let arch = env::consts::ARCH; + let arch_re = regex(arch)?; + let matches_arch = match arch { + "x86_64" => { + let amd64_re = regex("amd64")?; + arch_re.is_match(name) || amd64_re.is_match(name) + } + "aarch64" => { + let arm64_re = regex("arm64")?; + arch_re.is_match(name) || arm64_re.is_match(name) + } + _ => arch_re.is_match(name), + }; + if matches_os && matches_arch { + return Ok(true); + } + Ok(false) +} + +/// Creates a new regex for the specified key. +/// +/// # Arguments +/// * `key` - The key to create the regex for. +/// +/// # Returns +/// * The regex. +/// +/// # Errors +/// * If the regex cannot be created. +fn regex(key: &str) -> Result { + let regex = Regex::new(format!(r"[\W_]{key}[\W_]").as_str())?; + Ok(regex) +} + +#[cfg(test)] +mod tests { + use super::*; + + const URL: &str = "https://github.com/theseus-rs/postgresql-binaries"; + + #[test] + fn test_supports() { + assert!(GitHub::supports(URL)); + } + + #[test] + fn test_supports_error() { + assert!(!GitHub::supports("https://foo.com")); + } + + #[test] + fn test_name() { + let github = GitHub::new(URL).unwrap(); + assert_eq!("GitHub", github.name()); + } + + // + // get_version tests + // + + #[tokio::test] + async fn test_get_version() -> Result<()> { + let github = GitHub::new(URL)?; + let version_req = VersionReq::STAR; + let version = github.get_version(&version_req).await?; + assert!(version > Version::new(0, 0, 0)); + Ok(()) + } + + #[tokio::test] + async fn test_get_specific_version() -> Result<()> { + let github = GitHub::new(URL)?; + let version_req = VersionReq::parse("=16.3.0")?; + let version = github.get_version(&version_req).await?; + assert_eq!(Version::new(16, 3, 0), version); + Ok(()) + } + + #[tokio::test] + async fn test_get_specific_not_found() -> Result<()> { + let github = GitHub::new(URL)?; + let version_req = VersionReq::parse("=0.0.0")?; + let error = github.get_version(&version_req).await.unwrap_err(); + assert_eq!("release not found for '=0.0.0'", error.to_string()); + Ok(()) + } + + // + // get_archive tests + // + + #[tokio::test] + async fn test_get_archive() -> Result<()> { + let github = GitHub::new(URL)?; + let version_req = VersionReq::parse("=16.3.0")?; + let archive = github.get_archive(&version_req).await?; + assert_eq!( + format!("postgresql-16.3.0-{}.tar.gz", target_triple::TARGET), + archive.name() + ); + assert_eq!(&Version::new(16, 3, 0), archive.version()); + assert!(!archive.bytes().is_empty()); + Ok(()) + } + + // + // Plugin Support + // + + /// Test that a version with a 'v' prefix is correctly parsed; this is a common convention + /// for GitHub releases. Use a known PostgreSQL plugin repository for the test. + #[tokio::test] + async fn test_get_version_with_v_prefix() -> Result<()> { + let github = GitHub::new("https://github.com/turbot/steampipe-plugin-csv")?; + let version_req = VersionReq::parse("=0.12.0")?; + let version = github.get_version(&version_req).await?; + assert_eq!(Version::new(0, 12, 0), version); + Ok(()) + } + + /// Test that a version with a 'v' prefix is correctly parsed; this is a common convention + /// for GitHub releases. Use a known PostgreSQL plugin repository for the test. + #[tokio::test] + async fn test_get_archive_with_v_prefix() -> Result<()> { + let github = GitHub::new("https://github.com/turbot/steampipe-plugin-csv")?; + let version_req = VersionReq::parse("=0.12.0")?; + let archive = github.get_archive(&version_req).await?; + let name = archive.name(); + // Note: this plugin repository has 3 artifacts that can match: + // steampipe_export... + // steampipe_postgres... + // steampipe_sqlite... + // custom matchers will be needed to disambiguate plugins + assert!(name.starts_with("steampipe_")); + assert!(name.contains("csv")); + assert!(name.ends_with(".tar.gz")); + assert_eq!(&Version::new(0, 12, 0), archive.version()); + assert!(!archive.bytes().is_empty()); + Ok(()) + } + + // + // asset matcher tests + // + + #[test] + fn test_asset_match_success() -> Result<()> { + let target = target_triple::TARGET; + let os = env::consts::OS; + let arch = env::consts::ARCH; + let names = vec![ + format!("postgresql-16.3.0-{target}.tar.gz"), + format!("postgresql-16.3.0-{os}-{arch}.tar.gz"), + format!("foo.{target}.tar.gz"), + format!("foo.{os}.{arch}.tar.gz"), + format!("foo-{arch}-{os}.tar.gz"), + ]; + for name in names { + assert!(asset_matcher(name.as_str())?, "{}", name); + } + Ok(()) + } + + #[test] + fn test_asset_match_errors() -> Result<()> { + let target = target_triple::TARGET; + let os = env::consts::OS; + let arch = env::consts::ARCH; + let names = vec![ + format!("foo{target}.tar.gz"), + format!("foo{os}-{arch}.tar.gz"), + format!("foo-{target}.tar"), + format!("foo-{os}-{arch}.tar"), + format!("foo-{os}{arch}.tar.gz"), + ]; + for name in names { + assert!(!asset_matcher(name.as_str())?, "{}", name); + } + Ok(()) + } +} diff --git a/postgresql_archive/src/repository/mod.rs b/postgresql_archive/src/repository/mod.rs new file mode 100644 index 0000000..cacd7e4 --- /dev/null +++ b/postgresql_archive/src/repository/mod.rs @@ -0,0 +1,5 @@ +mod github; +pub mod model; +pub mod registry; + +pub use model::{Archive, Repository}; diff --git a/postgresql_archive/src/repository/model.rs b/postgresql_archive/src/repository/model.rs new file mode 100644 index 0000000..9922542 --- /dev/null +++ b/postgresql_archive/src/repository/model.rs @@ -0,0 +1,114 @@ +use async_trait::async_trait; +use semver::{Version, VersionReq}; +use std::fmt::Debug; + +/// A trait for archive repository implementations. +#[async_trait] +pub trait Repository: Debug + Send + Sync { + /// Gets the name of the repository. + /// + /// # Returns + /// * The name of the repository. + fn name(&self) -> &str; + + /// Gets the version for the specified [version requirement](VersionReq). If a + /// [version](Version) for the [version requirement](VersionReq) is not found, + /// then a [ReleaseNotFound] error is returned. + /// + /// # Arguments + /// * `version_req` - The version requirement. + /// + /// # Returns + /// * The version matching the requirement. + /// + /// # Errors + /// * If the version is not found. + async fn get_version(&self, version_req: &VersionReq) -> crate::Result; + + /// Gets the archive for a given [version requirement](VersionReq) that passes the default + /// matcher. If no archive is found for the [version requirement](VersionReq) and matcher then + /// an [error](crate::error::Error) is returned. + /// + /// # Arguments + /// * `version_req` - The version requirement. + /// + /// # Returns + /// * The archive version and bytes. + /// + /// # Errors + /// * If the archive is not found. + /// * If the archive cannot be downloaded. + async fn get_archive(&self, version_req: &VersionReq) -> crate::Result; +} + +/// A struct representing an archive. +#[derive(Clone, Debug)] +pub struct Archive { + name: String, + version: Version, + bytes: Vec, +} + +impl Archive { + /// Creates a new archive. + /// + /// # Arguments + /// * `name` - The name of the archive. + /// * `version` - The version of the archive. + /// * `bytes` - The bytes of the archive. + /// + /// # Returns + /// * The archive. + #[must_use] + pub fn new(name: String, version: Version, bytes: Vec) -> Self { + Self { + name, + version, + bytes, + } + } + + /// Gets the name of the archive. + /// + /// # Returns + /// * The name of the archive. + #[must_use] + pub fn name(&self) -> &str { + &self.name + } + + /// Gets the version of the archive. + /// + /// # Returns + /// * The version of the archive. + #[must_use] + pub fn version(&self) -> &Version { + &self.version + } + + /// Gets the bytes of the archive. + /// + /// # Returns + /// * The bytes of the archive. + #[must_use] + pub fn bytes(&self) -> &[u8] { + &self.bytes + } +} + +#[cfg(test)] +mod tests { + use super::*; + use semver::Version; + + #[test] + fn test_archive() { + let name = "test".to_string(); + let version = Version::parse("1.0.0").unwrap(); + let bytes = vec![0, 1, 2, 3]; + let archive = Archive::new(name.clone(), version.clone(), bytes.clone()); + assert_eq!(archive.name(), name); + assert_eq!(archive.version(), &version); + assert_eq!(archive.bytes(), bytes.as_slice()); + } +} diff --git a/postgresql_archive/src/repository/registry.rs b/postgresql_archive/src/repository/registry.rs new file mode 100644 index 0000000..b4f1b47 --- /dev/null +++ b/postgresql_archive/src/repository/registry.rs @@ -0,0 +1,135 @@ +use crate::repository::github::repository::GitHub; +use crate::repository::model::Repository; +use crate::Error::UnsupportedRepository; +use crate::Result; +use lazy_static::lazy_static; +use std::sync::{Arc, Mutex, RwLock}; + +lazy_static! { + static ref REGISTRY: Arc> = + Arc::new(Mutex::new(RepositoryRegistry::default())); +} + +type RepoSupportsFn = Arc bool + Send + Sync>>; +type SupportsFn = Box bool + Send + Sync>; +type RepoNewFn = Arc Result> + Send + Sync>>; +type NewFn = Box Result> + Send + Sync>; + +/// Singleton struct to store repositories +struct RepositoryRegistry { + repositories: Vec<(RepoSupportsFn, RepoNewFn)>, +} + +impl RepositoryRegistry { + /// Creates a new repository registry. + /// + /// # Returns + /// * The repository registry. + fn new() -> Self { + Self { + repositories: Vec::new(), + } + } + + /// Registers a repository. Newly registered repositories can override existing ones. + /// + /// # Arguments + /// * `supports_fn` - The function to check if the repository supports the URL. + /// * `new_fn` - The repository constructor function to register. + fn register(&mut self, supports_fn: SupportsFn, new_fn: NewFn) { + self.repositories.insert( + 0, + ( + Arc::new(RwLock::new(supports_fn)), + Arc::new(RwLock::new(new_fn)), + ), + ); + } + + /// Gets a repository that supports the specified URL + /// + /// # Arguments + /// * `url` - The URL to check for support. + /// + /// # Returns + /// * The repository that supports the URL. + fn get(&self, url: &str) -> Result> { + for (supports_fn, new_fn) in &self.repositories { + let supports_function = supports_fn.read().unwrap(); + if supports_function(url) { + let new_function = new_fn.read().unwrap(); + return new_function(url); + } + } + + Err(UnsupportedRepository(url.to_string())) + } +} + +impl Default for RepositoryRegistry { + fn default() -> Self { + let mut registry = Self::new(); + registry.register(Box::new(GitHub::supports), Box::new(GitHub::new)); + registry + } +} + +/// Registers a repository. Newly registered repositories can override existing ones. +/// +/// # Arguments +/// * `supports_fn` - The function to check if the repository supports the URL. +/// * `new_fn` - The repository constructor function to register. +#[allow(dead_code)] +pub fn register(supports_fn: SupportsFn, new_fn: NewFn) { + let mut registry = REGISTRY.lock().unwrap(); + registry.register(supports_fn, new_fn); +} + +/// Gets a repository that supports the specified URL +/// +/// # Arguments +/// * `url` - The URL to check for support. +/// +/// # Returns +/// * The repository that supports the URL. +/// +/// # Errors +/// * If the URL is not supported. +pub fn get(url: &str) -> Result> { + let registry = REGISTRY.lock().unwrap(); + registry.get(url) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_register() -> Result<()> { + assert!(!REGISTRY.lock().unwrap().repositories.is_empty()); + REGISTRY.lock().unwrap().repositories.truncate(0); + assert!(REGISTRY.lock().unwrap().repositories.is_empty()); + + register(Box::new(GitHub::supports), Box::new(GitHub::new)); + let url = "https://github.com/theseus-rs/postgresql-binaries"; + let result = get(url); + assert!(result.is_ok()); + Ok(()) + } + + #[tokio::test] + async fn test_get_no_host() -> Result<()> { + let url = "https://"; + let error = get(url).err().unwrap(); + assert_eq!("unsupported repository for 'https://'", error.to_string()); + Ok(()) + } + + #[tokio::test] + async fn test_get_github() -> Result<()> { + let url = "https://github.com/theseus-rs/postgresql-binaries"; + let result = get(url); + assert!(result.is_ok()); + Ok(()) + } +} diff --git a/postgresql_archive/src/version.rs b/postgresql_archive/src/version.rs index 0d16e86..8ed48f8 100644 --- a/postgresql_archive/src/version.rs +++ b/postgresql_archive/src/version.rs @@ -1,285 +1,125 @@ -//! PostgreSQL version -#![allow(dead_code)] +use crate::Result; +use semver::{Version, VersionReq}; -use crate::error::Error::InvalidVersion; -use crate::error::{Error, Result}; -use serde::{Deserialize, Serialize}; -use std::fmt; -use std::str::FromStr; - -/// PostgreSQL version struct. The version is a simple wrapper around a string. -/// [Actively supported](https://www.postgresql.org/developer/roadmap/) major versions of -/// PostgreSQL are defined as constants. The oldest supported version is will be marked -/// as deprecated. Deprecated versions will be removed in a future release following semver -/// conventions for this crate. -#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] -pub struct Version { - pub major: u64, - pub minor: Option, - pub release: Option, -} - -/// The latest PostgreSQL version -pub const LATEST: Version = V16; - -/// The latest PostgreSQL version 16 -pub const V16: Version = Version::new(16, None, None); - -/// The latest PostgreSQL version 15 -pub const V15: Version = Version::new(15, None, None); - -/// The latest PostgreSQL version 14 -pub const V14: Version = Version::new(14, None, None); - -/// The latest PostgreSQL version 13 -pub const V13: Version = Version::new(13, None, None); - -/// The latest PostgreSQL version 12 -#[allow(deprecated)] -#[deprecated( - since = "0.1.0", - note = "See https://www.postgresql.org/developer/roadmap/" -)] -pub const V12: Version = Version::new(12, None, None); - -impl Version { - #[must_use] - pub const fn new(major: u64, minor: Option, release: Option) -> Self { - Self { - major, - minor, - release, - } - } - - /// Matches the version against another version. Provides a simple way to match - /// against a major, major/minor, or major/minor/release version. Returns `true` - /// if the major version matches and the minor version matches or is not specified. - /// Returns `true` if the major and minor versions match and the release matches or - /// is not specified. Returns `false` otherwise. - /// - /// # Examples - /// The methods of this trait must be consistent with each other and with those of [`PartialEq`]. - /// The following conditions must hold: +/// A trait for getting the exact version from a [version requirement](VersionReq). +pub trait ExactVersion { + /// Gets the exact version from a [version requirement](VersionReq). /// - /// 1. `16` matches `16.1.0`, `16.1.1`, `16.2.0`, etc. - /// 2. `16.1` matches `16.1.0`, `16.1.1`, etc. - /// 3. `16.1.0` matches only `16.1.0` - /// 4. `15` does not match `16.1.0` - /// 5. `16.0` does not match `16.1.0` - /// 6. `16.1.0` does not match `16.1.1` - #[must_use] - pub fn matches(&self, version: &Version) -> bool { - if self.major != version.major { - return false; - } else if self.minor.is_none() || version.minor.is_none() { - return true; - } else if self.minor != version.minor { - return false; - } else if self.release.is_none() || version.release.is_none() { - return true; - } - - self.release == version.release - } + /// # Returns + /// * The exact version or `None` if the [version requirement](VersionReq) is not an exact + /// version. + fn exact_version(&self) -> Option; } -impl fmt::Display for Version { - fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - let major = self.major.to_string(); - let minor = self - .minor - .map_or(String::new(), |minor| format!(".{minor}")); - let release = self - .release - .map_or(String::new(), |release| format!(".{release}")); - write!(formatter, "{major}{minor}{release}") - } -} - -impl FromStr for Version { - type Err = Error; - - fn from_str(version: &str) -> Result { - let parts: Vec<&str> = version.split('.').collect(); - let major_str = parts.first().unwrap_or(&"not specified"); - let major: u64 = major_str.parse()?; - - let minor: Option = match parts.get(1) { - Some(minor) => Some(minor.parse()?), - None => None, - }; +impl ExactVersion for VersionReq { + fn exact_version(&self) -> Option { + if self.comparators.len() != 1 { + return None; + } - let release: Option = match parts.get(2) { - Some(release) => Some(release.parse()?), - None => None, - }; + if let Some(comparator) = self.comparators.first() { + if comparator.op != semver::Op::Exact { + return None; + } + let minor = comparator.minor?; + let patch = comparator.patch?; - if parts.len() > 3 { - return Err(InvalidVersion(version.to_string())); + let version = Version::new(comparator.major, minor, patch); + return Some(version); } - Ok(Version::new(major, minor, release)) + None } } -impl<'de> Deserialize<'de> for Version { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let version = String::deserialize(deserializer)?; - Version::from_str(&version).map_err(serde::de::Error::custom) - } +/// A trait for getting the exact version requirement from a [version](Version). +pub trait ExactVersionReq { + /// Gets the exact version requirement from a [version](Version). + /// + /// # Returns + /// * The exact version requirement. + /// + /// # Errors + /// * If the version requirement cannot be parsed. + fn exact_version_req(&self) -> Result; } -impl Serialize for Version { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.to_string().serialize(serializer) +impl ExactVersionReq for Version { + fn exact_version_req(&self) -> Result { + let version = format!("={self}"); + let version_req = VersionReq::parse(&version)?; + Ok(version_req) } } #[cfg(test)] mod tests { use super::*; - use test_log::test; - - // - // Impl tests - // + use crate::Result; #[test] - fn test_matches_all() { - assert!(&Version::new(1, None, None).matches(&Version::new(1, None, None))); - assert!(!&Version::new(1, None, None).matches(&Version::new(2, None, None))); - - assert!(&Version::new(1, Some(2), None).matches(&Version::new(1, Some(2), None))); - assert!(!&Version::new(1, Some(2), None).matches(&Version::new(1, Some(3), None))); - - assert!(&Version::new(1, Some(2), Some(3)).matches(&Version::new(1, Some(2), Some(3)))); - assert!(!&Version::new(1, Some(2), Some(3)).matches(&Version::new(1, Some(2), Some(4)))); - - assert!(&Version::new(1, None, None).matches(&Version::new(1, Some(2), None))); - assert!(&Version::new(1, Some(2), None).matches(&Version::new(1, None, None))); - - assert!(&Version::new(1, Some(2), None).matches(&Version::new(1, Some(2), Some(3)))); - assert!(&Version::new(1, Some(2), Some(3)).matches(&Version::new(1, Some(2), None))); + fn test_exact_version_star() { + let version_req = VersionReq::STAR; + assert_eq!(None, version_req.exact_version()); } - // - // Display tests - // - #[test] - fn test_version_display() -> Result<()> { - let version_str = "1.2.3"; - let version = Version::from_str(version_str)?; - assert_eq!(version_str, version.to_string()); + fn test_exact_version_greater_than() -> Result<()> { + let version_req = VersionReq::parse(">16")?; + assert_eq!(None, version_req.exact_version()); Ok(()) } #[test] - fn test_version_display_major() -> Result<()> { - let version_str = "1"; - let version = Version::from_str(version_str)?; - assert_eq!(version_str, version.to_string()); + fn test_exact_version_full_no_equals() -> Result<()> { + let version_req = VersionReq::parse("16.3.0")?; + assert_eq!(None, version_req.exact_version()); Ok(()) } #[test] - fn test_version_display_major_minor() -> Result<()> { - let version_str = "1.2"; - let version = Version::from_str(version_str)?; - assert_eq!(version_str, version.to_string()); + fn test_exact_version_full_equals() -> Result<()> { + let version_req = VersionReq::parse("=16.3.0")?; + let version = Version::new(16, 3, 0); + assert_eq!(Some(version), version_req.exact_version()); Ok(()) } - // - // FromStr tests - // - #[test] - fn test_version_from_str() -> Result<()> { - let version = Version::from_str("1.2.3")?; - assert_eq!(version.major, 1u64); - assert_eq!(version.minor, Some(2)); - assert_eq!(version.release, Some(3)); + fn test_exact_version_major_minor() -> Result<()> { + let version_req = VersionReq::parse("=16.3")?; + assert_eq!(None, version_req.exact_version()); Ok(()) } #[test] - fn test_version_from_str_major() -> Result<()> { - let version = Version::from_str("1")?; - assert_eq!(version.major, 1); - assert_eq!(version.minor, None); - assert_eq!(version.release, None); + fn test_exact_version_major() -> Result<()> { + let version_req = VersionReq::parse("=16")?; + assert_eq!(None, version_req.exact_version()); Ok(()) } #[test] - fn test_version_from_str_major_minor() -> Result<()> { - let version = Version::from_str("1.2")?; - assert_eq!(version.major, 1); - assert_eq!(version.minor, Some(2)); - assert_eq!(version.release, None); + fn test_exact_version_req_not_equal() -> Result<()> { + let version = Version::new(1, 2, 3); + assert_ne!(VersionReq::parse("=1.0.0")?, version.exact_version_req()?); Ok(()) } #[test] - fn test_version_from_str_error_missing_major() { - assert!(Version::from_str("").is_err()); - } - - #[test] - fn test_version_from_str_error_invalid_major() { - assert!(Version::from_str("a").is_err()); - } - - #[test] - fn test_version_from_str_error_invalid_minor() { - assert!(Version::from_str("1.a").is_err()); - } - - #[test] - fn test_version_from_str_error_invalid_release() { - assert!(Version::from_str("1.2.a").is_err()); - } - - #[test] - fn test_version_from_str_error_too_many_parts() { - assert!(Version::from_str("1.2.3.4").is_err()); - } - - // - // Deserialize tests - // - - #[test] - fn test_version_deserialize() -> anyhow::Result<()> { - let version = serde_json::from_str::("\"1.2.3\"")?; - assert_eq!(version.major, 1u64); - assert_eq!(version.minor, Some(2)); - assert_eq!(version.release, Some(3)); + fn test_exact_version_req_major_minor_patch() -> Result<()> { + let version = Version::new(16, 3, 0); + assert_eq!(VersionReq::parse("=16.3.0")?, version.exact_version_req()?); Ok(()) } #[test] - fn test_version_deserialize_parse_error() { - assert!(serde_json::from_str::("\"foo\"").is_err()); - } - - // - // Serialize tests - // - - #[test] - fn test_version_serialize() -> anyhow::Result<()> { - let version = Version::new(1, Some(2), Some(3)); - let version_str = serde_json::to_string(&version)?; - assert_eq!(version_str, "\"1.2.3\""); + fn test_exact_version_prerelease() -> Result<()> { + let version = Version::parse("1.2.3-alpha")?; + assert_eq!( + VersionReq::parse("=1.2.3-alpha")?, + version.exact_version_req()? + ); Ok(()) } } diff --git a/postgresql_archive/tests/archive.rs b/postgresql_archive/tests/archive.rs index 50bac3b..916d616 100644 --- a/postgresql_archive/tests/archive.rs +++ b/postgresql_archive/tests/archive.rs @@ -1,115 +1,82 @@ #[allow(deprecated)] -use postgresql_archive::{extract, Version, LATEST, V12, V13, V14, V15, V16}; -use postgresql_archive::{get_archive, get_archive_for_target, get_version, DEFAULT_RELEASES_URL}; +use postgresql_archive::extract; +use postgresql_archive::{get_archive, get_version, DEFAULT_POSTGRESQL_URL}; +use semver::VersionReq; use std::fs::{create_dir_all, remove_dir_all}; use test_log::test; -async fn test_get_archive_for_version_constant(version: Version) -> anyhow::Result<()> { - let (_archive_version, _archive) = get_archive(DEFAULT_RELEASES_URL, &version).await?; +async fn test_get_archive_for_version_constant(major: u64) -> anyhow::Result<()> { + let version_req = VersionReq::parse(&format!("={major}"))?; + let (archive_version, _archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req).await?; + + assert!(version_req.matches(&archive_version)); + assert_eq!(major, archive_version.major); Ok(()) } #[test(tokio::test)] async fn test_get_archive_for_version_constant_v16() -> anyhow::Result<()> { - test_get_archive_for_version_constant(V16).await + test_get_archive_for_version_constant(16).await } #[test(tokio::test)] async fn test_get_archive_for_version_constant_v15() -> anyhow::Result<()> { - test_get_archive_for_version_constant(V15).await + test_get_archive_for_version_constant(15).await } #[test(tokio::test)] async fn test_get_archive_for_version_constant_v14() -> anyhow::Result<()> { - test_get_archive_for_version_constant(V14).await + test_get_archive_for_version_constant(14).await } #[test(tokio::test)] async fn test_get_archive_for_version_constant_v13() -> anyhow::Result<()> { - test_get_archive_for_version_constant(V13).await + test_get_archive_for_version_constant(13).await } #[test(tokio::test)] #[allow(deprecated)] async fn test_get_archive_for_version_constant_v12() -> anyhow::Result<()> { - test_get_archive_for_version_constant(V12).await + test_get_archive_for_version_constant(12).await } #[test(tokio::test)] async fn test_get_version_not_found() -> postgresql_archive::Result<()> { - let invalid_version = Version::new(1, Some(0), Some(0)); - let result = get_version(DEFAULT_RELEASES_URL, &invalid_version).await; + let invalid_version_req = VersionReq::parse("=1.0.0")?; + let result = get_version(DEFAULT_POSTGRESQL_URL, &invalid_version_req).await; + assert!(result.is_err()); Ok(()) } #[test(tokio::test)] async fn test_get_version() -> anyhow::Result<()> { - let version = &LATEST; - - assert!(version.minor.is_none()); - assert!(version.release.is_none()); - - let latest_version = get_version(DEFAULT_RELEASES_URL, version).await?; - - assert_eq!(version.major, latest_version.major); - assert!(latest_version.minor.is_some()); - assert!(latest_version.release.is_some()); + let version_req = VersionReq::STAR; + let latest_version = get_version(DEFAULT_POSTGRESQL_URL, &version_req).await?; + assert!(version_req.matches(&latest_version)); Ok(()) } #[test(tokio::test)] async fn test_get_archive_and_extract() -> anyhow::Result<()> { - let version = &LATEST; - let (archive_version, archive) = get_archive(DEFAULT_RELEASES_URL, version).await?; + let version_req = VersionReq::STAR; + let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req).await?; - assert!(archive_version.matches(version)); + assert!(version_req.matches(&archive_version)); let out_dir = tempfile::tempdir()?.path().to_path_buf(); create_dir_all(&out_dir)?; extract(&archive, &out_dir).await?; remove_dir_all(&out_dir)?; - Ok(()) } #[test(tokio::test)] async fn test_get_archive_version_not_found() -> postgresql_archive::Result<()> { - let invalid_version = Version::new(1, Some(0), Some(0)); - let result = get_archive(DEFAULT_RELEASES_URL, &invalid_version).await; - assert!(result.is_err()); - Ok(()) -} + let invalid_version_req = VersionReq::parse("=1.0.0")?; + let result = get_archive(DEFAULT_POSTGRESQL_URL, &invalid_version_req).await; -#[test(tokio::test)] -async fn test_get_archive_for_target_version_not_found() -> postgresql_archive::Result<()> { - let invalid_version = Version::new(1, Some(0), Some(0)); - let result = get_archive_for_target( - DEFAULT_RELEASES_URL, - &invalid_version, - target_triple::TARGET, - ) - .await; assert!(result.is_err()); Ok(()) } - -#[test(tokio::test)] -async fn test_get_archive_for_target_target_not_found() -> postgresql_archive::Result<()> { - let result = - get_archive_for_target(DEFAULT_RELEASES_URL, &LATEST, "wasm64-unknown-unknown").await; - assert!(result.is_err()); - Ok(()) -} - -#[test(tokio::test)] -async fn test_get_archive_for_target() -> anyhow::Result<()> { - let version = &LATEST; - let (archive_version, _archive) = - get_archive_for_target(DEFAULT_RELEASES_URL, version, target_triple::TARGET).await?; - - assert!(archive_version.matches(version)); - - Ok(()) -} diff --git a/postgresql_archive/tests/blocking.rs b/postgresql_archive/tests/blocking.rs index 6e690fe..cfeff47 100644 --- a/postgresql_archive/tests/blocking.rs +++ b/postgresql_archive/tests/blocking.rs @@ -1,7 +1,7 @@ #[cfg(feature = "blocking")] -use postgresql_archive::blocking::{extract, get_archive, get_archive_for_target, get_version}; +use postgresql_archive::blocking::{extract, get_archive, get_version}; #[cfg(feature = "blocking")] -use postgresql_archive::{DEFAULT_RELEASES_URL, LATEST}; +use postgresql_archive::{VersionReq, DEFAULT_POSTGRESQL_URL}; #[cfg(feature = "blocking")] use std::fs::{create_dir_all, remove_dir_all}; #[cfg(feature = "blocking")] @@ -10,17 +10,10 @@ use test_log::test; #[cfg(feature = "blocking")] #[test] fn test_get_version() -> anyhow::Result<()> { - let version = &LATEST; - - assert!(version.minor.is_none()); - assert!(version.release.is_none()); - - let latest_version = get_version(DEFAULT_RELEASES_URL, version)?; - - assert_eq!(version.major, latest_version.major); - assert!(latest_version.minor.is_some()); - assert!(latest_version.release.is_some()); + let version_req = VersionReq::STAR; + let latest_version = get_version(DEFAULT_POSTGRESQL_URL, &version_req)?; + assert!(version_req.matches(&latest_version)); Ok(()) } @@ -28,28 +21,14 @@ fn test_get_version() -> anyhow::Result<()> { #[test] #[allow(deprecated)] fn test_get_archive_and_extract() -> anyhow::Result<()> { - let version = &LATEST; - let (archive_version, archive) = get_archive(DEFAULT_RELEASES_URL, version)?; + let version_req = &VersionReq::STAR; + let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, version_req)?; - assert!(archive_version.matches(version)); + assert!(version_req.matches(&archive_version)); let out_dir = tempfile::tempdir()?.path().to_path_buf(); create_dir_all(&out_dir)?; extract(&archive, &out_dir)?; remove_dir_all(&out_dir)?; - - Ok(()) -} - -#[cfg(feature = "blocking")] -#[test] -#[allow(deprecated)] -fn test_get_archive_for_target() -> anyhow::Result<()> { - let version = &LATEST; - let (archive_version, _archive) = - get_archive_for_target(DEFAULT_RELEASES_URL, version, target_triple::TARGET)?; - - assert!(archive_version.matches(version)); - Ok(()) } diff --git a/postgresql_embedded/Cargo.toml b/postgresql_embedded/Cargo.toml index 1e741b6..58aaa5a 100644 --- a/postgresql_embedded/Cargo.toml +++ b/postgresql_embedded/Cargo.toml @@ -23,6 +23,7 @@ lazy_static = { workspace = true } postgresql_archive = { path = "../postgresql_archive", version = "0.12.0", default-features = false } postgresql_commands = { path = "../postgresql_commands", version = "0.12.0" } rand = { workspace = true } +semver = { workspace = true } tempfile = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["full"], optional = true } diff --git a/postgresql_embedded/build/bundle.rs b/postgresql_embedded/build/bundle.rs index 37572e3..b81cf3f 100644 --- a/postgresql_embedded/build/bundle.rs +++ b/postgresql_embedded/build/bundle.rs @@ -1,8 +1,8 @@ #![allow(dead_code)] use anyhow::Result; -use postgresql_archive::{get_archive, DEFAULT_RELEASES_URL}; -use postgresql_archive::{Version, LATEST}; +use postgresql_archive::VersionReq; +use postgresql_archive::{get_archive, DEFAULT_POSTGRESQL_URL}; use std::fs::File; use std::io::Write; use std::path::PathBuf; @@ -15,11 +15,11 @@ use std::{env, fs}; /// downloaded at runtime. pub(crate) async fn stage_postgresql_archive() -> Result<()> { let releases_url = - env::var("POSTGRESQL_RELEASES_URL").unwrap_or(DEFAULT_RELEASES_URL.to_string()); + env::var("POSTGRESQL_RELEASES_URL").unwrap_or(DEFAULT_POSTGRESQL_URL.to_string()); println!("PostgreSQL releases URL: {releases_url}"); - let postgres_version = env::var("POSTGRESQL_VERSION").unwrap_or(LATEST.to_string()); - let version = Version::from_str(postgres_version.as_str())?; - println!("PostgreSQL version: {postgres_version}"); + let postgres_version_req = env::var("POSTGRESQL_VERSION").unwrap_or("*".to_string()); + let version_req = VersionReq::from_str(postgres_version_req.as_str())?; + println!("PostgreSQL version: {postgres_version_req}"); let out_dir = PathBuf::from(env::var("OUT_DIR")?); println!("OUT_DIR: {:?}", out_dir); @@ -34,7 +34,7 @@ pub(crate) async fn stage_postgresql_archive() -> Result<()> { return Ok(()); } - let (asset_version, archive) = get_archive(&releases_url, &version).await?; + let (asset_version, archive) = get_archive(&releases_url, &version_req).await?; fs::write(archive_version_file.clone(), asset_version.to_string())?; let mut file = File::create(archive_file.clone())?; diff --git a/postgresql_embedded/src/blocking/postgresql.rs b/postgresql_embedded/src/blocking/postgresql.rs index 5bbd392..07b8815 100644 --- a/postgresql_embedded/src/blocking/postgresql.rs +++ b/postgresql_embedded/src/blocking/postgresql.rs @@ -116,11 +116,11 @@ impl PostgreSQL { #[cfg(test)] mod test { use super::*; - use postgresql_archive::Version; + use crate::VersionReq; #[test] - fn test_postgresql() { - let version = Version::new(16, Some(3), Some(0)); + fn test_postgresql() -> Result<()> { + let version = VersionReq::parse("=16.3.0")?; let settings = Settings { version, ..Settings::default() @@ -128,5 +128,6 @@ mod test { let postgresql = PostgreSQL::new(settings); let initial_statuses = [Status::NotInstalled, Status::Installed, Status::Stopped]; assert!(initial_statuses.contains(&postgresql.status())); + Ok(()) } } diff --git a/postgresql_embedded/src/error.rs b/postgresql_embedded/src/error.rs index 1c2bd7a..36250d4 100644 --- a/postgresql_embedded/src/error.rs +++ b/postgresql_embedded/src/error.rs @@ -36,6 +36,9 @@ pub enum Error { /// Error when IO operations fail #[error(transparent)] IoError(anyhow::Error), + /// Parse error + #[error(transparent)] + ParseError(#[from] semver::Error), } /// Convert `PostgreSQL` [archive errors](postgresql_archive::Error) to an [embedded errors](Error::ArchiveError) @@ -69,7 +72,7 @@ mod test { fn test_from_archive_error() { let archive_error = postgresql_archive::Error::ReleaseNotFound("test".to_string()); let error = Error::from(archive_error); - assert_eq!(error.to_string(), "release not found for version [test]"); + assert_eq!(error.to_string(), "release not found for 'test'"); } #[test] diff --git a/postgresql_embedded/src/lib.rs b/postgresql_embedded/src/lib.rs index c690ecb..3df71f4 100644 --- a/postgresql_embedded/src/lib.rs +++ b/postgresql_embedded/src/lib.rs @@ -81,11 +81,13 @@ //! //! The following features are available: //! -//! | Name | Description | Default? | -//! |------------|-----------------------------------------------------------|----------| -//! | `bundled` | Bundles the PostgreSQL archive into the resulting binary | No | -//! | `blocking` | Enables the blocking API; requires `tokio` | No | -//! | `tokio` | Enables using tokio for async | No | +//! | Name | Description | Default? | +//! |--------------|----------------------------------------------------------|----------| +//! | `bundled` | Bundles the PostgreSQL archive into the resulting binary | No | +//! | `blocking` | Enables the blocking API; requires `tokio` | No | +//! | `native-tls` | Enables native-tls support | No | +//! | `rustls-tls` | Enables rustls-tls support | Yes | +//! | `tokio` | Enables using tokio for async | No | //! //! ## Safety //! @@ -110,6 +112,7 @@ #![deny(clippy::pedantic)] #![allow(dead_code)] #![allow(clippy::doc_markdown)] +#![allow(deprecated)] #[cfg(feature = "blocking")] pub mod blocking; @@ -119,5 +122,29 @@ mod settings; pub use error::{Error, Result}; pub use postgresql::{PostgreSQL, Status}; -pub use postgresql_archive::Version; +pub use postgresql_archive::{Version, VersionReq}; pub use settings::Settings; + +lazy_static::lazy_static! { + /// The latest PostgreSQL version requirement + pub static ref LATEST: VersionReq = VersionReq::STAR; + + /// The latest PostgreSQL version 16 + pub static ref V16: VersionReq = VersionReq::parse("=16").unwrap(); + + /// The latest PostgreSQL version 15 + pub static ref V15: VersionReq = VersionReq::parse("=15").unwrap(); + + /// The latest PostgreSQL version 14 + pub static ref V14: VersionReq = VersionReq::parse("=14").unwrap(); + + /// The latest PostgreSQL version 13 + pub static ref V13: VersionReq = VersionReq::parse("=13").unwrap(); + + /// The latest PostgreSQL version 12 + #[deprecated( + since = "0.1.0", + note = "See https://www.postgresql.org/developer/roadmap/" + )] + pub static ref V12: VersionReq = VersionReq::parse("=12").unwrap(); +} diff --git a/postgresql_embedded/src/postgresql.rs b/postgresql_embedded/src/postgresql.rs index 8006761..de04d76 100644 --- a/postgresql_embedded/src/postgresql.rs +++ b/postgresql_embedded/src/postgresql.rs @@ -3,6 +3,7 @@ use crate::error::Result; use crate::settings::{Settings, BOOTSTRAP_SUPERUSER}; use postgresql_archive::get_version; use postgresql_archive::{extract, get_archive}; +use postgresql_archive::{ExactVersion, ExactVersionReq}; use postgresql_commands::initdb::InitDbBuilder; use postgresql_commands::pg_ctl::Mode::{Start, Stop}; use postgresql_commands::pg_ctl::PgCtlBuilder; @@ -48,12 +49,11 @@ impl PostgreSQL { pub fn new(settings: Settings) -> Self { let mut postgresql = PostgreSQL { settings }; - // If the minor and release version are set, append the version to the installation directory - // to avoid conflicts with other versions. This will also facilitate setting the status - // of the server to the correct initial value. If the minor and release version are not set, - // the installation directory will be determined dynamically during the installation process. - let version = postgresql.settings.version; - if version.minor.is_some() && version.release.is_some() { + // If an exact version is set, append the version to the installation directory to avoid + // conflicts with other versions. This will also facilitate setting the status of the + // server to the correct initial value. If the minor and release version are not set, the + // installation directory will be determined dynamically during the installation process. + if let Some(version) = postgresql.settings.version.exact_version() { let path = &postgresql.settings.installation_dir; let version_string = version.to_string(); @@ -88,11 +88,9 @@ impl PostgreSQL { /// Check if the `PostgreSQL` server is installed fn is_installed(&self) -> bool { - let version = self.settings.version; - if version.minor.is_none() || version.release.is_none() { + let Some(version) = self.settings.version.exact_version() else { return false; - } - + }; let path = &self.settings.installation_dir; path.ends_with(version.to_string()) && path.exists() } @@ -136,16 +134,14 @@ impl PostgreSQL { self.settings.version ); - // If the minor and release version are not set, determine the latest version and update the - // version and installation directory accordingly. This is an optimization to avoid downloading - // the archive if the latest version is already installed. - if self.settings.version.minor.is_none() || self.settings.version.release.is_none() { - self.settings.version = - get_version(&self.settings.releases_url, &self.settings.version).await?; - self.settings.installation_dir = self - .settings - .installation_dir - .join(self.settings.version.to_string()); + // If the exact version is not set, determine the latest version and update the version and + // installation directory accordingly. This is an optimization to avoid downloading the + // archive if the latest version is already installed. + if self.settings.version.exact_version().is_none() { + let version = get_version(&self.settings.releases_url, &self.settings.version).await?; + self.settings.version = version.exact_version_req()?; + self.settings.installation_dir = + self.settings.installation_dir.join(version.to_string()); } if self.settings.installation_dir.exists() { @@ -160,16 +156,21 @@ impl PostgreSQL { let (version, bytes) = if *crate::settings::ARCHIVE_VERSION == self.settings.version { debug!("Using bundled installation archive"); ( - self.settings.version, + self.settings.version.clone(), bytes::Bytes::copy_from_slice(crate::settings::ARCHIVE), ) } else { - get_archive(&self.settings.releases_url, &self.settings.version).await? + let (version, bytes) = + get_archive(&self.settings.releases_url, &self.settings.version).await?; + (version.exact_version_req()?, bytes) }; #[cfg(not(feature = "bundled"))] - let (version, bytes) = - { get_archive(&self.settings.releases_url, &self.settings.version).await? }; + let (version, bytes) = { + let (version, bytes) = + get_archive(&self.settings.releases_url, &self.settings.version).await?; + (version.exact_version_req()?, bytes) + }; self.settings.version = version; extract(&bytes, &self.settings.installation_dir).await?; diff --git a/postgresql_embedded/src/settings.rs b/postgresql_embedded/src/settings.rs index 2c1dbff..42bb3a3 100644 --- a/postgresql_embedded/src/settings.rs +++ b/postgresql_embedded/src/settings.rs @@ -1,6 +1,6 @@ use crate::error::{Error, Result}; use home::home_dir; -use postgresql_archive::{Version, DEFAULT_RELEASES_URL}; +use postgresql_archive::{VersionReq, DEFAULT_POSTGRESQL_URL}; use rand::distributions::Alphanumeric; use rand::Rng; use std::collections::HashMap; @@ -15,11 +15,13 @@ use url::Url; #[cfg(feature = "bundled")] lazy_static::lazy_static! { #[allow(clippy::unwrap_used)] - pub(crate) static ref ARCHIVE_VERSION: Version = { - let version_string = include_str!(concat!(std::env!("OUT_DIR"), "/postgresql.version")); - let version = Version::from_str(version_string).unwrap(); - tracing::debug!("Bundled installation archive version {version}"); - version + pub(crate) static ref ARCHIVE_VERSION: VersionReq = { + let version_string = include_str!( + concat!(std::env!("OUT_DIR"), "/postgresql.version") + ); + let version_req = VersionReq::from_str(&format!("={version_string}")).unwrap(); + tracing::debug!("Bundled installation archive version {version_string}"); + version_req }; } @@ -34,8 +36,8 @@ pub const BOOTSTRAP_SUPERUSER: &str = "postgres"; pub struct Settings { /// URL for the releases location of the `PostgreSQL` installation archives pub releases_url: String, - /// Version of `PostgreSQL` to install - pub version: Version, + /// Version requirement of `PostgreSQL` to install + pub version: VersionReq, /// `PostgreSQL` installation directory pub installation_dir: PathBuf, /// `PostgreSQL` password file @@ -90,7 +92,7 @@ impl Settings { .collect(); Self { - releases_url: DEFAULT_RELEASES_URL.to_string(), + releases_url: DEFAULT_POSTGRESQL_URL.to_string(), version: default_version(), installation_dir: home_dir.join(".theseus").join("postgresql"), password_file, @@ -146,7 +148,7 @@ impl Settings { settings.releases_url = releases_url.to_string(); } if let Some(version) = query_parameters.get("version") { - settings.version = Version::from_str(version)?; + settings.version = VersionReq::parse(version)?; } if let Some(installation_dir) = query_parameters.get("installation_dir") { if let Ok(path) = PathBuf::from_str(installation_dir) { @@ -236,15 +238,15 @@ impl Default for Settings { /// Get the default version used if not otherwise specified #[must_use] -fn default_version() -> Version { +fn default_version() -> VersionReq { #[cfg(feature = "bundled")] { - *ARCHIVE_VERSION + ARCHIVE_VERSION.clone() } #[cfg(not(feature = "bundled"))] { - postgresql_archive::LATEST + VersionReq::STAR } } @@ -288,7 +290,7 @@ mod tests { fn test_settings_from_url() -> Result<()> { let base_url = "postgresql://postgres:password@localhost:5432/test"; let releases_url = "releases_url=https%3A%2F%2Fgithub.com"; - let version = "version=16.3.0"; + let version = "version=%3D16.3.0"; let installation_dir = "installation_dir=/tmp/postgresql"; let password_file = "password_file=/tmp/.pgpass"; let data_dir = "data_dir=/tmp/data"; @@ -300,7 +302,7 @@ mod tests { let settings = Settings::from_url(url)?; assert_eq!("https://github.com", settings.releases_url); - assert_eq!(Version::new(16, Some(3), Some(0)), settings.version); + assert_eq!(VersionReq::parse("=16.3.0")?, settings.version); assert_eq!(PathBuf::from("/tmp/postgresql"), settings.installation_dir); assert_eq!(PathBuf::from("/tmp/.pgpass"), settings.password_file); assert_eq!(PathBuf::from("/tmp/data"), settings.data_dir);