From 92b640edae1bb01e9224778691b216e72bbb7e34 Mon Sep 17 00:00:00 2001 From: brianheineman Date: Sat, 29 Jun 2024 12:45:37 -0600 Subject: [PATCH 1/3] feat: add hasher and matcher supports function --- postgresql_archive/src/error.rs | 6 + postgresql_archive/src/hasher/registry.rs | 126 +++++++----------- postgresql_archive/src/matcher/registry.rs | 121 ++++++----------- .../src/repository/github/repository.rs | 2 +- postgresql_archive/src/repository/registry.rs | 93 +++++++------ 5 files changed, 143 insertions(+), 205 deletions(-) diff --git a/postgresql_archive/src/error.rs b/postgresql_archive/src/error.rs index e6683d4..22df208 100644 --- a/postgresql_archive/src/error.rs +++ b/postgresql_archive/src/error.rs @@ -31,6 +31,12 @@ pub enum Error { /// Unexpected error #[error("{0}")] Unexpected(String), + /// Unsupported hasher + #[error("unsupported hasher for '{0}'")] + UnsupportedHasher(String), + /// Unsupported hasher + #[error("unsupported matcher for '{0}'")] + UnsupportedMatcher(String), /// Unsupported repository #[error("unsupported repository for '{0}'")] UnsupportedRepository(String), diff --git a/postgresql_archive/src/hasher/registry.rs b/postgresql_archive/src/hasher/registry.rs index 90f3264..e937001 100644 --- a/postgresql_archive/src/hasher/registry.rs +++ b/postgresql_archive/src/hasher/registry.rs @@ -1,8 +1,7 @@ use crate::hasher::{blake2b_512, blake2s_256, sha2_256, sha2_512, sha3_256, sha3_512}; -use crate::Error::PoisonedLock; +use crate::Error::{PoisonedLock, UnsupportedRepository}; use crate::Result; use lazy_static::lazy_static; -use std::collections::HashMap; use std::sync::{Arc, Mutex, RwLock}; lazy_static! { @@ -10,53 +9,55 @@ lazy_static! { Arc::new(Mutex::new(HasherRegistry::default())); } +pub type SupportsFn = fn(&str, &str) -> Result; pub type HasherFn = fn(&Vec) -> Result; /// Singleton struct to store hashers +#[allow(clippy::type_complexity)] struct HasherRegistry { - hashers: HashMap>>, + hashers: Vec<(Arc>, Arc>)>, } impl HasherRegistry { /// Creates a new hasher registry. fn new() -> Self { Self { - hashers: HashMap::new(), + hashers: Vec::new(), } } - /// Registers a hasher for an extension. Newly registered hashers with the same extension will - /// override existing ones. - fn register>(&mut self, extension: S, hasher_fn: HasherFn) { - let extension = extension.as_ref().to_string(); - self.hashers - .insert(extension, Arc::new(RwLock::new(hasher_fn))); + /// Registers a hasher for a supports function. Newly registered hashers will take precedence + /// over existing ones. + fn register(&mut self, supports_fn: SupportsFn, hasher_fn: HasherFn) { + self.hashers.insert( + 0, + ( + Arc::new(RwLock::new(supports_fn)), + Arc::new(RwLock::new(hasher_fn)), + ), + ); } - /// Get a hasher for the specified extension. + /// Get a hasher for the specified url and extension. /// /// # Errors /// * If the registry is poisoned. - fn get>(&self, extension: S) -> Result> { - let extension = extension.as_ref().to_string(); - if let Some(hasher) = self.hashers.get(&extension) { - let hasher = *hasher + fn get>(&self, url: S, extension: S) -> Result { + let url = url.as_ref(); + let extension = extension.as_ref(); + for (supports_fn, hasher_fn) in &self.hashers { + let supports_function = supports_fn .read() .map_err(|error| PoisonedLock(error.to_string()))?; - return Ok(Some(hasher)); + if supports_function(url, extension)? { + let hasher_function = hasher_fn + .read() + .map_err(|error| PoisonedLock(error.to_string()))?; + return Ok(*hasher_function); + } } - Ok(None) - } - - /// Get the number of hashers in the registry. - fn len(&self) -> usize { - self.hashers.len() - } - - /// Check if the registry is empty. - fn is_empty(&self) -> bool { - self.hashers.is_empty() + Err(UnsupportedRepository(url.to_string())) } } @@ -64,61 +65,39 @@ impl Default for HasherRegistry { /// Creates a new hasher registry with the default hashers registered. fn default() -> Self { let mut registry = Self::new(); - registry.register("blake2s", blake2s_256::hash); - registry.register("blake2b", blake2b_512::hash); - registry.register("sha256", sha2_256::hash); - registry.register("sha512", sha2_512::hash); - registry.register("sha3-256", sha3_256::hash); - registry.register("sha3-512", sha3_512::hash); + registry.register(|_, extension| Ok(extension == "blake2s"), blake2s_256::hash); + registry.register(|_, extension| Ok(extension == "blake2b"), blake2b_512::hash); + registry.register(|_, extension| Ok(extension == "sha256"), sha2_256::hash); + registry.register(|_, extension| Ok(extension == "sha512"), sha2_512::hash); + registry.register(|_, extension| Ok(extension == "sha3-256"), sha3_256::hash); + registry.register(|_, extension| Ok(extension == "sha3-512"), sha3_512::hash); registry } } -/// Registers a hasher for an extension. Newly registered hashers with the same extension will -/// override existing ones. +/// Registers a hasher for a supports function. Newly registered hashers will take precedence +/// over existing ones. /// /// # Errors /// * If the registry is poisoned. #[allow(dead_code)] -pub fn register>(extension: S, hasher_fn: HasherFn) -> Result<()> { +pub fn register(supports_fn: SupportsFn, hasher_fn: HasherFn) -> Result<()> { let mut registry = REGISTRY .lock() .map_err(|error| PoisonedLock(error.to_string()))?; - registry.register(extension, hasher_fn); + registry.register(supports_fn, hasher_fn); Ok(()) } -/// Get a hasher for the specified extension. -/// -/// # Errors -/// * If the registry is poisoned. -pub fn get>(extension: S) -> Result> { - let registry = REGISTRY - .lock() - .map_err(|error| PoisonedLock(error.to_string()))?; - registry.get(extension) -} - -/// Get the number of matchers in the registry. +/// Get a hasher for the specified url and extension. /// /// # Errors /// * If the registry is poisoned. -pub fn len() -> Result { +pub fn get>(url: S, extension: S) -> Result { let registry = REGISTRY .lock() .map_err(|error| PoisonedLock(error.to_string()))?; - Ok(registry.len()) -} - -/// Check if the registry is empty. -/// -/// # Errors -/// * If the registry is poisoned. -pub fn is_empty() -> Result { - let registry = REGISTRY - .lock() - .map_err(|error| PoisonedLock(error.to_string()))?; - Ok(registry.is_empty()) + registry.get(url, extension) } #[cfg(test)] @@ -126,7 +105,7 @@ mod tests { use super::*; fn test_hasher(extension: &str, expected: &str) -> Result<()> { - let hasher = get(extension)?.unwrap(); + let hasher = get("https://foo.com", extension)?; let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]; let hash = hasher(&data)?; assert_eq!(expected, hash); @@ -135,22 +114,11 @@ mod tests { #[test] fn test_register() -> Result<()> { - let extension = "sha256"; - let hashers = len()?; - assert!(!is_empty()?); - REGISTRY - .lock() - .map_err(|error| PoisonedLock(error.to_string()))? - .hashers - .remove(extension); - assert_ne!(hashers, len()?); - register(extension, sha2_256::hash)?; - assert_eq!(hashers, len()?); - - test_hasher( - extension, - "9a89c68c4c5e28b8c4a5567673d462fff515db46116f9900624d09c474f593fb", - ) + register( + |_, extension| Ok(extension == "foo"), + |_| Ok("42".to_string()), + )?; + test_hasher("foo", "42") } #[test] diff --git a/postgresql_archive/src/matcher/registry.rs b/postgresql_archive/src/matcher/registry.rs index 98671bb..7059344 100644 --- a/postgresql_archive/src/matcher/registry.rs +++ b/postgresql_archive/src/matcher/registry.rs @@ -1,9 +1,8 @@ use crate::matcher::{default, postgresql_binaries}; -use crate::Error::PoisonedLock; +use crate::Error::{PoisonedLock, UnsupportedMatcher}; use crate::{Result, DEFAULT_POSTGRESQL_URL}; use lazy_static::lazy_static; use semver::Version; -use std::collections::HashMap; use std::sync::{Arc, Mutex, RwLock}; lazy_static! { @@ -11,59 +10,54 @@ lazy_static! { Arc::new(Mutex::new(MatchersRegistry::default())); } +pub type SupportsFn = fn(&str) -> Result; pub type MatcherFn = fn(&str, &Version) -> Result; /// Singleton struct to store matchers +#[allow(clippy::type_complexity)] struct MatchersRegistry { - matchers: HashMap, Arc>>, + matchers: Vec<(Arc>, Arc>)>, } impl MatchersRegistry { /// Creates a new matcher registry. fn new() -> Self { Self { - matchers: HashMap::new(), + matchers: Vec::new(), } } - /// Registers a matcher for a URL. Newly registered matchers with the same url will override - /// existing ones. - fn register>(&mut self, url: Option, matcher_fn: MatcherFn) { - let url: Option = url.map(|s| s.as_ref().to_string()); - self.matchers.insert(url, Arc::new(RwLock::new(matcher_fn))); + /// Registers a matcher for a supports function. Newly registered matchers with the take + /// precedence over existing ones. + fn register(&mut self, supports_fn: SupportsFn, matcher_fn: MatcherFn) { + self.matchers.insert( + 0, + ( + Arc::new(RwLock::new(supports_fn)), + Arc::new(RwLock::new(matcher_fn)), + ), + ); } - /// Get a matcher for the specified URL, or the default matcher if no matcher is - /// registered for the URL. + /// Get a matcher for the specified URL. /// /// # Errors /// * If the registry is poisoned. fn get>(&self, url: S) -> Result { - let url = Some(url.as_ref().to_string()); - if let Some(matcher) = self.matchers.get(&url) { - let matcher = *matcher + let url = url.as_ref(); + for (supports_fn, matcher_fn) in &self.matchers { + let supports_function = supports_fn .read() .map_err(|error| PoisonedLock(error.to_string()))?; - return Ok(matcher); + if supports_function(url)? { + let matcher_function = matcher_fn + .read() + .map_err(|error| PoisonedLock(error.to_string()))?; + return Ok(*matcher_function); + } } - let matcher = match self.matchers.get(&None) { - Some(matcher) => *matcher - .read() - .map_err(|error| PoisonedLock(error.to_string()))?, - None => default::matcher, - }; - Ok(matcher) - } - - /// Get the number of matchers in the registry. - fn len(&self) -> usize { - self.matchers.len() - } - - /// Check if the registry is empty. - fn is_empty(&self) -> bool { - self.matchers.is_empty() + Err(UnsupportedMatcher(url.to_string())) } } @@ -71,28 +65,30 @@ impl Default for MatchersRegistry { /// Creates a new matcher registry with the default matchers registered. fn default() -> Self { let mut registry = Self::new(); - registry.register(None::<&str>, default::matcher); - registry.register(Some(DEFAULT_POSTGRESQL_URL), postgresql_binaries::matcher); + registry.register( + |url| Ok(url == DEFAULT_POSTGRESQL_URL), + postgresql_binaries::matcher, + ); + registry.register(|_| Ok(true), default::matcher); registry } } -/// Registers a matcher for a URL. Newly registered matchers with the same url will override -/// existing ones. +/// Registers a matcher for a supports function. Newly registered matchers with the take +/// precedence over existing ones. /// /// # Errors /// * If the registry is poisoned. #[allow(dead_code)] -pub fn register>(url: Option, matcher_fn: MatcherFn) -> Result<()> { +pub fn register(supports_fn: SupportsFn, matcher_fn: MatcherFn) -> Result<()> { let mut registry = REGISTRY .lock() .map_err(|error| PoisonedLock(error.to_string()))?; - registry.register(url, matcher_fn); + registry.register(supports_fn, matcher_fn); Ok(()) } -/// Get a matcher for the specified URL, or the default matcher if no matcher is -/// registered for the URL. +/// Get a matcher for the specified URL. /// /// # Errors /// * If the registry is poisoned. @@ -103,53 +99,22 @@ pub fn get>(url: S) -> Result { registry.get(url) } -/// Get the number of matchers in the registry. -/// -/// # Errors -/// * If the registry is poisoned. -pub fn len() -> Result { - let registry = REGISTRY - .lock() - .map_err(|error| PoisonedLock(error.to_string()))?; - Ok(registry.len()) -} - -/// Check if the registry is empty. -/// -/// # Errors -/// * If the registry is poisoned. -pub fn is_empty() -> Result { - let registry = REGISTRY - .lock() - .map_err(|error| PoisonedLock(error.to_string()))?; - Ok(registry.is_empty()) -} - #[cfg(test)] mod tests { use super::*; - use crate::Error::PoisonedLock; use std::env; #[test] fn test_register() -> Result<()> { - let matchers = len()?; - assert!(!is_empty()?); - REGISTRY - .lock() - .map_err(|error| PoisonedLock(error.to_string()))? - .matchers - .remove(&None::); - assert_ne!(matchers, len()?); - register(None::<&str>, default::matcher)?; - assert_eq!(matchers, len()?); - - let matcher = get(DEFAULT_POSTGRESQL_URL)?; + register( + |url| Ok(url == "https://foo.com"), + |name, _| Ok(name == "foo"), + )?; + + let matcher = get("https://foo.com")?; let version = Version::new(16, 3, 0); - let target = target_triple::TARGET; - let name = format!("postgresql-{version}-{target}.tar.gz"); - assert!(matcher(name.as_str(), &version)?, "{}", name); + assert!(matcher("foo", &version)?); Ok(()) } diff --git a/postgresql_archive/src/repository/github/repository.rs b/postgresql_archive/src/repository/github/repository.rs index c97f31e..b072cd8 100644 --- a/postgresql_archive/src/repository/github/repository.rs +++ b/postgresql_archive/src/repository/github/repository.rs @@ -199,7 +199,7 @@ impl GitHub { .strip_prefix(format!("{}.", asset.name.as_str()).as_str()) .unwrap_or_default(); - if let Some(hasher_fn) = hasher::registry::get(extension)? { + if let Ok(hasher_fn) = hasher::registry::get(&self.url, &extension.to_string()) { asset_hash = Some(release_asset.clone()); asset_hasher_fn = Some(hasher_fn); break; diff --git a/postgresql_archive/src/repository/registry.rs b/postgresql_archive/src/repository/registry.rs index 5406645..9b9019e 100644 --- a/postgresql_archive/src/repository/registry.rs +++ b/postgresql_archive/src/repository/registry.rs @@ -27,7 +27,7 @@ impl RepositoryRegistry { } } - /// Registers a repository. Newly registered repositories can override existing ones. + /// Registers a repository. Newly registered repositories take precedence over existing ones. fn register(&mut self, supports_fn: Box, new_fn: Box) { self.repositories.insert( 0, @@ -57,16 +57,6 @@ impl RepositoryRegistry { Err(UnsupportedRepository(url.to_string())) } - - /// Get the number of repositories in the registry. - fn len(&self) -> usize { - self.repositories.len() - } - - /// Check if the registry is empty. - fn is_empty(&self) -> bool { - self.repositories.is_empty() - } } impl Default for RepositoryRegistry { @@ -102,48 +92,57 @@ pub fn get(url: &str) -> Result> { registry.get(url) } -/// Get the number of repositories in the registry. -/// -/// # Errors -/// * If the registry is poisoned. -pub fn len() -> Result { - let registry = REGISTRY - .lock() - .map_err(|error| PoisonedLock(error.to_string()))?; - Ok(registry.len()) -} - -/// Check if the registry is empty. -/// -/// # Errors -/// * If the registry is poisoned. -pub fn is_empty() -> Result { - let registry = REGISTRY - .lock() - .map_err(|error| PoisonedLock(error.to_string()))?; - Ok(registry.is_empty()) -} - #[cfg(test)] mod tests { use super::*; + use crate::repository::Archive; + use async_trait::async_trait; + use semver::{Version, VersionReq}; + use std::fmt::Debug; + + #[derive(Debug)] + struct TestRepository; + + impl TestRepository { + #[allow(clippy::new_ret_no_self)] + #[allow(clippy::unnecessary_wraps)] + fn new(_url: &str) -> Result> { + Ok(Box::new(Self)) + } + + fn supports(url: &str) -> bool { + url == "https://foo.com" + } + } + + #[async_trait] + impl Repository for TestRepository { + fn name(&self) -> &str { + "test" + } + + async fn get_version(&self, _version_req: &VersionReq) -> Result { + Ok(Version::new(0, 0, 42)) + } + + async fn get_archive(&self, _version_req: &VersionReq) -> Result { + Ok(Archive::new( + "test".to_string(), + Version::new(0, 0, 42), + Vec::new(), + )) + } + } #[tokio::test] async fn test_register() -> Result<()> { - let repositories = len()?; - assert!(!is_empty()?); - REGISTRY - .lock() - .map_err(|error| PoisonedLock(error.to_string()))? - .repositories - .truncate(0); - assert_ne!(repositories, len()?); - register(Box::new(GitHub::supports), Box::new(GitHub::new))?; - assert_eq!(repositories, len()?); - - let url = "https://github.com/theseus-rs/postgresql-binaries"; - let result = get(url); - assert!(result.is_ok()); + register( + Box::new(TestRepository::supports), + Box::new(TestRepository::new), + )?; + let url = "https://foo.com"; + let repository = get(url)?; + assert_eq!("test", repository.name()); Ok(()) } From fccea3463c7ffc817ee09ab4c182f07d326d7b4e Mon Sep 17 00:00:00 2001 From: brianheineman Date: Sat, 29 Jun 2024 13:20:58 -0600 Subject: [PATCH 2/3] test: improve test coverage --- postgresql_archive/src/repository/registry.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/postgresql_archive/src/repository/registry.rs b/postgresql_archive/src/repository/registry.rs index 9b9019e..cafefb7 100644 --- a/postgresql_archive/src/repository/registry.rs +++ b/postgresql_archive/src/repository/registry.rs @@ -143,6 +143,8 @@ mod tests { let url = "https://foo.com"; let repository = get(url)?; assert_eq!("test", repository.name()); + assert!(repository.get_version(&VersionReq::STAR).await.is_ok()); + assert!(repository.get_archive(&VersionReq::STAR).await.is_ok()); Ok(()) } From 9dfc01188edba6ebcd23dbe8fc943153ddd97a6c Mon Sep 17 00:00:00 2001 From: brianheineman Date: Sat, 29 Jun 2024 13:56:06 -0600 Subject: [PATCH 3/3] refactor: remove default registry values --- examples/archive_async/src/main.rs | 7 +- examples/archive_sync/src/main.rs | 4 +- postgresql_archive/benches/archive.rs | 4 +- postgresql_archive/src/archive.rs | 7 +- postgresql_archive/src/hasher/registry.rs | 77 +++++++------------ postgresql_archive/src/lib.rs | 10 +-- postgresql_archive/src/matcher/mod.rs | 2 +- postgresql_archive/src/matcher/registry.rs | 23 +++--- .../matcher/{default.rs => target_os_arch.rs} | 0 .../src/repository/github/repository.rs | 56 ++------------ postgresql_archive/src/repository/registry.rs | 40 ++++------ postgresql_archive/tests/archive.rs | 11 +-- postgresql_archive/tests/blocking.rs | 6 +- postgresql_embedded/build/bundle.rs | 4 +- postgresql_embedded/src/settings.rs | 4 +- 15 files changed, 91 insertions(+), 164 deletions(-) rename postgresql_archive/src/matcher/{default.rs => target_os_arch.rs} (100%) diff --git a/examples/archive_async/src/main.rs b/examples/archive_async/src/main.rs index 33d12c9..7c3a634 100644 --- a/examples/archive_async/src/main.rs +++ b/examples/archive_async/src/main.rs @@ -1,12 +1,15 @@ #![forbid(unsafe_code)] #![deny(clippy::pedantic)] -use postgresql_archive::{extract, get_archive, Result, VersionReq, DEFAULT_POSTGRESQL_URL}; +use postgresql_archive::{ + extract, get_archive, Result, VersionReq, THESEUS_POSTGRESQL_BINARIES_URL, +}; #[tokio::main] async fn main() -> Result<()> { let version_req = VersionReq::STAR; - let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req).await?; + let (archive_version, archive) = + get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req).await?; let out_dir = tempfile::tempdir()?.into_path(); extract(&archive, &out_dir).await?; println!( diff --git a/examples/archive_sync/src/main.rs b/examples/archive_sync/src/main.rs index 4214e0b..5d9d7dc 100644 --- a/examples/archive_sync/src/main.rs +++ b/examples/archive_sync/src/main.rs @@ -2,11 +2,11 @@ #![deny(clippy::pedantic)] use postgresql_archive::blocking::{extract, get_archive}; -use postgresql_archive::{Result, VersionReq, DEFAULT_POSTGRESQL_URL}; +use postgresql_archive::{Result, VersionReq, THESEUS_POSTGRESQL_BINARIES_URL}; fn main() -> Result<()> { let version_req = VersionReq::STAR; - let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req)?; + let (archive_version, archive) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req)?; let out_dir = tempfile::tempdir()?.into_path(); extract(&archive, &out_dir)?; println!( diff --git a/postgresql_archive/benches/archive.rs b/postgresql_archive/benches/archive.rs index e054584..303f4dd 100644 --- a/postgresql_archive/benches/archive.rs +++ b/postgresql_archive/benches/archive.rs @@ -1,6 +1,6 @@ use criterion::{criterion_group, criterion_main, Criterion}; use postgresql_archive::blocking::{extract, get_archive}; -use postgresql_archive::{Result, VersionReq, DEFAULT_POSTGRESQL_URL}; +use postgresql_archive::{Result, VersionReq, THESEUS_POSTGRESQL_BINARIES_URL}; use std::fs::{create_dir_all, remove_dir_all}; use std::time::Duration; @@ -10,7 +10,7 @@ fn benchmarks(criterion: &mut Criterion) { fn bench_extract(criterion: &mut Criterion) -> Result<()> { let version_req = VersionReq::STAR; - let (_archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req)?; + let (_archive_version, archive) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req)?; criterion.bench_function("extract", |bencher| { bencher.iter(|| { diff --git a/postgresql_archive/src/archive.rs b/postgresql_archive/src/archive.rs index 2b381ad..a3e9e3b 100644 --- a/postgresql_archive/src/archive.rs +++ b/postgresql_archive/src/archive.rs @@ -16,7 +16,8 @@ use std::time::Duration; use tar::Archive; use tracing::{debug, instrument, warn}; -pub const DEFAULT_POSTGRESQL_URL: &str = "https://github.com/theseus-rs/postgresql-binaries"; +pub const THESEUS_POSTGRESQL_BINARIES_URL: &str = + "https://github.com/theseus-rs/postgresql-binaries"; /// Gets the version for the specified [version requirement](VersionReq). If a version for the /// [version requirement](VersionReq) is not found, then an error is returned. @@ -213,7 +214,7 @@ mod tests { #[tokio::test] async fn test_get_version() -> Result<()> { let version_req = VersionReq::parse("=16.3.0")?; - let version = get_version(DEFAULT_POSTGRESQL_URL, &version_req).await?; + let version = get_version(THESEUS_POSTGRESQL_BINARIES_URL, &version_req).await?; assert_eq!(Version::new(16, 3, 0), version); Ok(()) } @@ -221,7 +222,7 @@ mod tests { #[tokio::test] async fn test_get_archive() -> Result<()> { let version_req = VersionReq::parse("=16.3.0")?; - let (version, bytes) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req).await?; + let (version, bytes) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req).await?; assert_eq!(Version::new(16, 3, 0), version); assert!(!bytes.is_empty()); Ok(()) diff --git a/postgresql_archive/src/hasher/registry.rs b/postgresql_archive/src/hasher/registry.rs index e937001..adc258f 100644 --- a/postgresql_archive/src/hasher/registry.rs +++ b/postgresql_archive/src/hasher/registry.rs @@ -1,6 +1,6 @@ -use crate::hasher::{blake2b_512, blake2s_256, sha2_256, sha2_512, sha3_256, sha3_512}; -use crate::Error::{PoisonedLock, UnsupportedRepository}; -use crate::Result; +use crate::hasher::sha2_256; +use crate::Error::{PoisonedLock, UnsupportedHasher}; +use crate::{Result, THESEUS_POSTGRESQL_BINARIES_URL}; use lazy_static::lazy_static; use std::sync::{Arc, Mutex, RwLock}; @@ -57,7 +57,7 @@ impl HasherRegistry { } } - Err(UnsupportedRepository(url.to_string())) + Err(UnsupportedHasher(url.to_string())) } } @@ -65,12 +65,12 @@ impl Default for HasherRegistry { /// Creates a new hasher registry with the default hashers registered. fn default() -> Self { let mut registry = Self::new(); - registry.register(|_, extension| Ok(extension == "blake2s"), blake2s_256::hash); - registry.register(|_, extension| Ok(extension == "blake2b"), blake2b_512::hash); - registry.register(|_, extension| Ok(extension == "sha256"), sha2_256::hash); - registry.register(|_, extension| Ok(extension == "sha512"), sha2_512::hash); - registry.register(|_, extension| Ok(extension == "sha3-256"), sha3_256::hash); - registry.register(|_, extension| Ok(extension == "sha3-512"), sha3_512::hash); + registry.register( + |url, extension| { + Ok(url.starts_with(THESEUS_POSTGRESQL_BINARIES_URL) && extension == "sha256") + }, + sha2_256::hash, + ); registry } } @@ -115,57 +115,32 @@ mod tests { #[test] fn test_register() -> Result<()> { register( - |_, extension| Ok(extension == "foo"), + |_, extension| Ok(extension == "test"), |_| Ok("42".to_string()), )?; - test_hasher("foo", "42") - } - - #[test] - fn test_sha2_256() -> Result<()> { - test_hasher( - "sha256", - "9a89c68c4c5e28b8c4a5567673d462fff515db46116f9900624d09c474f593fb", - ) - } - - #[test] - fn test_sha2_512() -> Result<()> { - test_hasher( - "sha512", - "3ad3f36979450d4f53366244ecf1010f4f9121d6888285ff14104fd5aded85d48aa171bf1e33a112602f92b7a7088b298789012fb87b9056321241a19fb74e0b", - ) - } - - #[test] - fn test_sha3_256() -> Result<()> { - test_hasher( - "sha3-256", - "c0188232190e0427fc9cc78597221c76c799528660889bd6ce1f3563148ff84d", - ) + test_hasher("test", "42") } #[test] - fn test_sha3_512() -> Result<()> { - test_hasher( - "sha3-512", - "9429fc1f9772cc1d8039fe75cc1b033cd60f0ec4face0f8a514d25b0649ba8a5954b6c7a41cc3697a56db3ff321475be1fa14b70c7eb78fec6ce62dbfc54c9d3", - ) + fn test_get_invalid_url_error() { + let error = get("https://foo.com", "foo").unwrap_err(); + assert_eq!( + "unsupported hasher for 'https://foo.com'", + error.to_string() + ); } #[test] - fn test_blake2s_256() -> Result<()> { - test_hasher( - "blake2s", - "7125921e06071710350390fe902856dbea366a5d6f5ee26c18e741143ac80061", - ) + fn test_get_invalid_extension_error() { + let error = get(THESEUS_POSTGRESQL_BINARIES_URL, "foo").unwrap_err(); + assert_eq!( + format!("unsupported hasher for '{THESEUS_POSTGRESQL_BINARIES_URL}'"), + error.to_string() + ); } #[test] - fn test_blake2b_512() -> Result<()> { - test_hasher( - "blake2b", - "67767f1cab415502dcceec9f099fb84539b1c73c5ebdcfe1bb8ca7411e3b6cb33e304f49222edac9bdaa74129e9e13f11f215b8560f9081f0e8f1f869162bf46", - ) + fn test_get_theseus_postgresql_binaries() { + assert!(get(THESEUS_POSTGRESQL_BINARIES_URL, "sha256").is_ok()); } } diff --git a/postgresql_archive/src/lib.rs b/postgresql_archive/src/lib.rs index 207c8e9..108b6e8 100644 --- a/postgresql_archive/src/lib.rs +++ b/postgresql_archive/src/lib.rs @@ -21,11 +21,11 @@ //! ### Asynchronous API //! //! ```no_run -//! use postgresql_archive::{extract, get_archive, Result, VersionReq, DEFAULT_POSTGRESQL_URL}; +//! use postgresql_archive::{extract, get_archive, Result, VersionReq, THESEUS_POSTGRESQL_BINARIES_URL}; //! //! #[tokio::main] //! async fn main() -> Result<()> { -//! let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &VersionReq::STAR).await?; +//! let (archive_version, archive) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &VersionReq::STAR).await?; //! let out_dir = std::env::temp_dir(); //! extract(&archive, &out_dir).await //! } @@ -34,10 +34,10 @@ //! ### Synchronous API //! ```no_run //! #[cfg(feature = "blocking")] { -//! use postgresql_archive::{VersionReq, DEFAULT_POSTGRESQL_URL}; +//! use postgresql_archive::{VersionReq, THESEUS_POSTGRESQL_BINARIES_URL}; //! use postgresql_archive::blocking::{extract, get_archive}; //! -//! let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &VersionReq::STAR).unwrap(); +//! let (archive_version, archive) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &VersionReq::STAR).unwrap(); //! let out_dir = std::env::temp_dir(); //! let result = extract(&archive, &out_dir).unwrap(); //! } @@ -118,7 +118,7 @@ pub mod matcher; pub mod repository; mod version; -pub use archive::DEFAULT_POSTGRESQL_URL; +pub use archive::THESEUS_POSTGRESQL_BINARIES_URL; pub use archive::{extract, get_archive, get_version}; pub use error::{Error, Result}; pub use semver::{Version, VersionReq}; diff --git a/postgresql_archive/src/matcher/mod.rs b/postgresql_archive/src/matcher/mod.rs index 8fb17cd..dc1f02f 100644 --- a/postgresql_archive/src/matcher/mod.rs +++ b/postgresql_archive/src/matcher/mod.rs @@ -1,3 +1,3 @@ -pub mod default; pub mod postgresql_binaries; pub mod registry; +pub mod target_os_arch; diff --git a/postgresql_archive/src/matcher/registry.rs b/postgresql_archive/src/matcher/registry.rs index 7059344..03c30d1 100644 --- a/postgresql_archive/src/matcher/registry.rs +++ b/postgresql_archive/src/matcher/registry.rs @@ -1,6 +1,6 @@ -use crate::matcher::{default, postgresql_binaries}; +use crate::matcher::postgresql_binaries; use crate::Error::{PoisonedLock, UnsupportedMatcher}; -use crate::{Result, DEFAULT_POSTGRESQL_URL}; +use crate::{Result, THESEUS_POSTGRESQL_BINARIES_URL}; use lazy_static::lazy_static; use semver::Version; use std::sync::{Arc, Mutex, RwLock}; @@ -66,10 +66,9 @@ impl Default for MatchersRegistry { fn default() -> Self { let mut registry = Self::new(); registry.register( - |url| Ok(url == DEFAULT_POSTGRESQL_URL), + |url| Ok(url == THESEUS_POSTGRESQL_BINARIES_URL), postgresql_binaries::matcher, ); - registry.register(|_| Ok(true), default::matcher); registry } } @@ -102,7 +101,6 @@ pub fn get>(url: S) -> Result { #[cfg(test)] mod tests { use super::*; - use std::env; #[test] fn test_register() -> Result<()> { @@ -119,14 +117,13 @@ mod tests { } #[test] - fn test_default_matcher() -> Result<()> { - let matcher = get("https://foo.com")?; - let version = Version::new(16, 3, 0); - let os = env::consts::OS; - let arch = env::consts::ARCH; - let name = format!("plugin_csv.pg16-{os}_{arch}.tar.gz"); + fn test_get_error() { + let result = get("foo").unwrap_err(); + assert_eq!("unsupported matcher for 'foo'", result.to_string()); + } - assert!(matcher(name.as_str(), &version)?, "{}", name); - Ok(()) + #[test] + fn test_get_theseus_postgresql_binaries() { + assert!(get(THESEUS_POSTGRESQL_BINARIES_URL).is_ok()); } } diff --git a/postgresql_archive/src/matcher/default.rs b/postgresql_archive/src/matcher/target_os_arch.rs similarity index 100% rename from postgresql_archive/src/matcher/default.rs rename to postgresql_archive/src/matcher/target_os_arch.rs diff --git a/postgresql_archive/src/repository/github/repository.rs b/postgresql_archive/src/repository/github/repository.rs index b072cd8..ebae93f 100644 --- a/postgresql_archive/src/repository/github/repository.rs +++ b/postgresql_archive/src/repository/github/repository.rs @@ -80,18 +80,6 @@ impl GitHub { })) } - /// Determines if the specified URL is supported by the GitHub repository. - /// - /// # Errors - /// * If the URL cannot be parsed. - pub fn supports(url: &str) -> bool { - let Ok(parsed_url) = Url::parse(url) else { - return false; - }; - let host = parsed_url.host_str().unwrap_or_default(); - host.contains("github.com") - } - /// Gets the version from the specified tag name. /// /// # Errors @@ -334,22 +322,11 @@ fn reqwest_client() -> ClientWithMiddleware { #[cfg(test)] mod tests { use super::*; - - const URL: &str = "https://github.com/theseus-rs/postgresql-binaries"; - - #[test] - fn test_supports() { - assert!(GitHub::supports(URL)); - } - - #[test] - fn test_supports_error() { - assert!(!GitHub::supports("https://foo.com")); - } + use crate::THESEUS_POSTGRESQL_BINARIES_URL; #[test] fn test_name() { - let github = GitHub::new(URL).unwrap(); + let github = GitHub::new(THESEUS_POSTGRESQL_BINARIES_URL).unwrap(); assert_eq!("GitHub", github.name()); } @@ -379,7 +356,7 @@ mod tests { #[tokio::test] async fn test_get_version() -> Result<()> { - let github = GitHub::new(URL)?; + let github = GitHub::new(THESEUS_POSTGRESQL_BINARIES_URL)?; let version_req = VersionReq::STAR; let version = github.get_version(&version_req).await?; assert!(version > Version::new(0, 0, 0)); @@ -388,7 +365,7 @@ mod tests { #[tokio::test] async fn test_get_specific_version() -> Result<()> { - let github = GitHub::new(URL)?; + let github = GitHub::new(THESEUS_POSTGRESQL_BINARIES_URL)?; let version_req = VersionReq::parse("=16.3.0")?; let version = github.get_version(&version_req).await?; assert_eq!(Version::new(16, 3, 0), version); @@ -397,7 +374,7 @@ mod tests { #[tokio::test] async fn test_get_specific_not_found() -> Result<()> { - let github = GitHub::new(URL)?; + let github = GitHub::new(THESEUS_POSTGRESQL_BINARIES_URL)?; let version_req = VersionReq::parse("=0.0.0")?; let error = github.get_version(&version_req).await.unwrap_err(); assert_eq!("version not found for '=0.0.0'", error.to_string()); @@ -410,7 +387,7 @@ mod tests { #[tokio::test] async fn test_get_archive() -> Result<()> { - let github = GitHub::new(URL)?; + let github = GitHub::new(THESEUS_POSTGRESQL_BINARIES_URL)?; let version_req = VersionReq::parse("=16.3.0")?; let archive = github.get_archive(&version_req).await?; assert_eq!( @@ -436,25 +413,4 @@ mod tests { assert_eq!(Version::new(0, 12, 0), version); Ok(()) } - - /// Test that a version with a 'v' prefix is correctly parsed; this is a common convention - /// for GitHub releases. Use a known PostgreSQL plugin repository for the test. - #[tokio::test] - async fn test_get_archive_with_v_prefix() -> Result<()> { - let github = GitHub::new("https://github.com/turbot/steampipe-plugin-csv")?; - let version_req = VersionReq::parse("=0.12.0")?; - let archive = github.get_archive(&version_req).await?; - let name = archive.name(); - // Note: this plugin repository has 3 artifacts that can match: - // steampipe_export... - // steampipe_postgres... - // steampipe_sqlite... - // custom matchers will be needed to disambiguate plugins - assert!(name.starts_with("steampipe_")); - assert!(name.contains("csv")); - assert!(name.ends_with(".tar.gz")); - assert_eq!(&Version::new(0, 12, 0), archive.version()); - assert!(!archive.bytes().is_empty()); - Ok(()) - } } diff --git a/postgresql_archive/src/repository/registry.rs b/postgresql_archive/src/repository/registry.rs index cafefb7..8bb5eb7 100644 --- a/postgresql_archive/src/repository/registry.rs +++ b/postgresql_archive/src/repository/registry.rs @@ -1,7 +1,7 @@ use crate::repository::github::repository::GitHub; use crate::repository::model::Repository; use crate::Error::{PoisonedLock, UnsupportedRepository}; -use crate::Result; +use crate::{Result, THESEUS_POSTGRESQL_BINARIES_URL}; use lazy_static::lazy_static; use std::sync::{Arc, Mutex, RwLock}; @@ -10,7 +10,7 @@ lazy_static! { Arc::new(Mutex::new(RepositoryRegistry::default())); } -type SupportsFn = dyn Fn(&str) -> bool + Send + Sync; +type SupportsFn = fn(&str) -> Result; type NewFn = dyn Fn(&str) -> Result> + Send + Sync; /// Singleton struct to store repositories @@ -28,7 +28,7 @@ impl RepositoryRegistry { } /// Registers a repository. Newly registered repositories take precedence over existing ones. - fn register(&mut self, supports_fn: Box, new_fn: Box) { + fn register(&mut self, supports_fn: SupportsFn, new_fn: Box) { self.repositories.insert( 0, ( @@ -47,7 +47,7 @@ impl RepositoryRegistry { let supports_function = supports_fn .read() .map_err(|error| PoisonedLock(error.to_string()))?; - if supports_function(url) { + if supports_function(url)? { let new_function = new_fn .read() .map_err(|error| PoisonedLock(error.to_string()))?; @@ -63,7 +63,10 @@ impl Default for RepositoryRegistry { /// Creates a new repository registry with the default repositories registered. fn default() -> Self { let mut registry = Self::new(); - registry.register(Box::new(GitHub::supports), Box::new(GitHub::new)); + registry.register( + |url| Ok(url.starts_with(THESEUS_POSTGRESQL_BINARIES_URL)), + Box::new(GitHub::new), + ); registry } } @@ -73,7 +76,7 @@ impl Default for RepositoryRegistry { /// # Errors /// * If the registry is poisoned. #[allow(dead_code)] -pub fn register(supports_fn: Box, new_fn: Box) -> Result<()> { +pub fn register(supports_fn: SupportsFn, new_fn: Box) -> Result<()> { let mut registry = REGISTRY .lock() .map_err(|error| PoisonedLock(error.to_string()))?; @@ -109,10 +112,6 @@ mod tests { fn new(_url: &str) -> Result> { Ok(Box::new(Self)) } - - fn supports(url: &str) -> bool { - url == "https://foo.com" - } } #[async_trait] @@ -137,7 +136,7 @@ mod tests { #[tokio::test] async fn test_register() -> Result<()> { register( - Box::new(TestRepository::supports), + |url| Ok(url == "https://foo.com"), Box::new(TestRepository::new), )?; let url = "https://foo.com"; @@ -148,19 +147,14 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_get_no_host() -> Result<()> { - let url = "https://"; - let error = get(url).err().unwrap(); - assert_eq!("unsupported repository for 'https://'", error.to_string()); - Ok(()) + #[test] + fn test_get_error() { + let error = get("foo").unwrap_err(); + assert_eq!("unsupported repository for 'foo'", error.to_string()); } - #[tokio::test] - async fn test_get_github() -> Result<()> { - let url = "https://github.com/theseus-rs/postgresql-binaries"; - let result = get(url); - assert!(result.is_ok()); - Ok(()) + #[test] + fn test_get_theseus_postgresql_binaries() { + assert!(get(THESEUS_POSTGRESQL_BINARIES_URL).is_ok()); } } diff --git a/postgresql_archive/tests/archive.rs b/postgresql_archive/tests/archive.rs index fdb6363..1d515c8 100644 --- a/postgresql_archive/tests/archive.rs +++ b/postgresql_archive/tests/archive.rs @@ -1,6 +1,6 @@ #[allow(deprecated)] use postgresql_archive::extract; -use postgresql_archive::{get_archive, get_version, DEFAULT_POSTGRESQL_URL}; +use postgresql_archive::{get_archive, get_version, THESEUS_POSTGRESQL_BINARIES_URL}; use semver::VersionReq; use std::fs::{create_dir_all, remove_dir_all}; use test_log::test; @@ -8,7 +8,7 @@ use test_log::test; #[test(tokio::test)] async fn test_get_version_not_found() -> postgresql_archive::Result<()> { let invalid_version_req = VersionReq::parse("=1.0.0")?; - let result = get_version(DEFAULT_POSTGRESQL_URL, &invalid_version_req).await; + let result = get_version(THESEUS_POSTGRESQL_BINARIES_URL, &invalid_version_req).await; assert!(result.is_err()); Ok(()) @@ -17,7 +17,7 @@ async fn test_get_version_not_found() -> postgresql_archive::Result<()> { #[test(tokio::test)] async fn test_get_version() -> anyhow::Result<()> { let version_req = VersionReq::parse("=16.3.0")?; - let latest_version = get_version(DEFAULT_POSTGRESQL_URL, &version_req).await?; + let latest_version = get_version(THESEUS_POSTGRESQL_BINARIES_URL, &version_req).await?; assert!(version_req.matches(&latest_version)); Ok(()) @@ -26,7 +26,8 @@ async fn test_get_version() -> anyhow::Result<()> { #[test(tokio::test)] async fn test_get_archive_and_extract() -> anyhow::Result<()> { let version_req = VersionReq::STAR; - let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req).await?; + let (archive_version, archive) = + get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req).await?; assert!(version_req.matches(&archive_version)); @@ -40,7 +41,7 @@ async fn test_get_archive_and_extract() -> anyhow::Result<()> { #[test(tokio::test)] async fn test_get_archive_version_not_found() -> postgresql_archive::Result<()> { let invalid_version_req = VersionReq::parse("=1.0.0")?; - let result = get_archive(DEFAULT_POSTGRESQL_URL, &invalid_version_req).await; + let result = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &invalid_version_req).await; assert!(result.is_err()); Ok(()) diff --git a/postgresql_archive/tests/blocking.rs b/postgresql_archive/tests/blocking.rs index cfeff47..c557118 100644 --- a/postgresql_archive/tests/blocking.rs +++ b/postgresql_archive/tests/blocking.rs @@ -1,7 +1,7 @@ #[cfg(feature = "blocking")] use postgresql_archive::blocking::{extract, get_archive, get_version}; #[cfg(feature = "blocking")] -use postgresql_archive::{VersionReq, DEFAULT_POSTGRESQL_URL}; +use postgresql_archive::{VersionReq, THESEUS_POSTGRESQL_BINARIES_URL}; #[cfg(feature = "blocking")] use std::fs::{create_dir_all, remove_dir_all}; #[cfg(feature = "blocking")] @@ -11,7 +11,7 @@ use test_log::test; #[test] fn test_get_version() -> anyhow::Result<()> { let version_req = VersionReq::STAR; - let latest_version = get_version(DEFAULT_POSTGRESQL_URL, &version_req)?; + let latest_version = get_version(THESEUS_POSTGRESQL_BINARIES_URL, &version_req)?; assert!(version_req.matches(&latest_version)); Ok(()) @@ -22,7 +22,7 @@ fn test_get_version() -> anyhow::Result<()> { #[allow(deprecated)] fn test_get_archive_and_extract() -> anyhow::Result<()> { let version_req = &VersionReq::STAR; - let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, version_req)?; + let (archive_version, archive) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, version_req)?; assert!(version_req.matches(&archive_version)); diff --git a/postgresql_embedded/build/bundle.rs b/postgresql_embedded/build/bundle.rs index 3afc2b5..da00394 100644 --- a/postgresql_embedded/build/bundle.rs +++ b/postgresql_embedded/build/bundle.rs @@ -2,7 +2,7 @@ use anyhow::Result; use postgresql_archive::VersionReq; -use postgresql_archive::{get_archive, DEFAULT_POSTGRESQL_URL}; +use postgresql_archive::{get_archive, THESEUS_POSTGRESQL_BINARIES_URL}; use std::fs::File; use std::io::Write; use std::path::PathBuf; @@ -15,7 +15,7 @@ use std::{env, fs}; /// downloaded at runtime. pub(crate) async fn stage_postgresql_archive() -> Result<()> { let releases_url = - env::var("POSTGRESQL_RELEASES_URL").unwrap_or(DEFAULT_POSTGRESQL_URL.to_string()); + env::var("POSTGRESQL_RELEASES_URL").unwrap_or(THESEUS_POSTGRESQL_BINARIES_URL.to_string()); println!("PostgreSQL releases URL: {releases_url}"); let postgres_version_req = env::var("POSTGRESQL_VERSION").unwrap_or("*".to_string()); let version_req = VersionReq::from_str(postgres_version_req.as_str())?; diff --git a/postgresql_embedded/src/settings.rs b/postgresql_embedded/src/settings.rs index 42bb3a3..54d7028 100644 --- a/postgresql_embedded/src/settings.rs +++ b/postgresql_embedded/src/settings.rs @@ -1,6 +1,6 @@ use crate::error::{Error, Result}; use home::home_dir; -use postgresql_archive::{VersionReq, DEFAULT_POSTGRESQL_URL}; +use postgresql_archive::{VersionReq, THESEUS_POSTGRESQL_BINARIES_URL}; use rand::distributions::Alphanumeric; use rand::Rng; use std::collections::HashMap; @@ -92,7 +92,7 @@ impl Settings { .collect(); Self { - releases_url: DEFAULT_POSTGRESQL_URL.to_string(), + releases_url: THESEUS_POSTGRESQL_BINARIES_URL.to_string(), version: default_version(), installation_dir: home_dir.join(".theseus").join("postgresql"), password_file,