Skip to content

Commit

Permalink
Merge pull request #91 from theseus-rs/add-hasher-and-matcher-support…
Browse files Browse the repository at this point in the history
…s-fn

feat: add hasher and matcher supports function
  • Loading branch information
brianheineman authored Jun 29, 2024
2 parents e8ef1ec + 9dfc011 commit 3d74b9a
Show file tree
Hide file tree
Showing 16 changed files with 219 additions and 352 deletions.
7 changes: 5 additions & 2 deletions examples/archive_async/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#![forbid(unsafe_code)]
#![deny(clippy::pedantic)]

use postgresql_archive::{extract, get_archive, Result, VersionReq, DEFAULT_POSTGRESQL_URL};
use postgresql_archive::{
extract, get_archive, Result, VersionReq, THESEUS_POSTGRESQL_BINARIES_URL,
};

#[tokio::main]
async fn main() -> Result<()> {
let version_req = VersionReq::STAR;
let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req).await?;
let (archive_version, archive) =
get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req).await?;
let out_dir = tempfile::tempdir()?.into_path();
extract(&archive, &out_dir).await?;
println!(
Expand Down
4 changes: 2 additions & 2 deletions examples/archive_sync/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
#![deny(clippy::pedantic)]

use postgresql_archive::blocking::{extract, get_archive};
use postgresql_archive::{Result, VersionReq, DEFAULT_POSTGRESQL_URL};
use postgresql_archive::{Result, VersionReq, THESEUS_POSTGRESQL_BINARIES_URL};

fn main() -> Result<()> {
let version_req = VersionReq::STAR;
let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req)?;
let (archive_version, archive) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req)?;
let out_dir = tempfile::tempdir()?.into_path();
extract(&archive, &out_dir)?;
println!(
Expand Down
4 changes: 2 additions & 2 deletions postgresql_archive/benches/archive.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use criterion::{criterion_group, criterion_main, Criterion};
use postgresql_archive::blocking::{extract, get_archive};
use postgresql_archive::{Result, VersionReq, DEFAULT_POSTGRESQL_URL};
use postgresql_archive::{Result, VersionReq, THESEUS_POSTGRESQL_BINARIES_URL};
use std::fs::{create_dir_all, remove_dir_all};
use std::time::Duration;

Expand All @@ -10,7 +10,7 @@ fn benchmarks(criterion: &mut Criterion) {

fn bench_extract(criterion: &mut Criterion) -> Result<()> {
let version_req = VersionReq::STAR;
let (_archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req)?;
let (_archive_version, archive) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req)?;

criterion.bench_function("extract", |bencher| {
bencher.iter(|| {
Expand Down
7 changes: 4 additions & 3 deletions postgresql_archive/src/archive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use std::time::Duration;
use tar::Archive;
use tracing::{debug, instrument, warn};

pub const DEFAULT_POSTGRESQL_URL: &str = "https://github.com/theseus-rs/postgresql-binaries";
pub const THESEUS_POSTGRESQL_BINARIES_URL: &str =
"https://github.com/theseus-rs/postgresql-binaries";

/// Gets the version for the specified [version requirement](VersionReq). If a version for the
/// [version requirement](VersionReq) is not found, then an error is returned.
Expand Down Expand Up @@ -213,15 +214,15 @@ mod tests {
#[tokio::test]
async fn test_get_version() -> Result<()> {
let version_req = VersionReq::parse("=16.3.0")?;
let version = get_version(DEFAULT_POSTGRESQL_URL, &version_req).await?;
let version = get_version(THESEUS_POSTGRESQL_BINARIES_URL, &version_req).await?;
assert_eq!(Version::new(16, 3, 0), version);
Ok(())
}

#[tokio::test]
async fn test_get_archive() -> Result<()> {
let version_req = VersionReq::parse("=16.3.0")?;
let (version, bytes) = get_archive(DEFAULT_POSTGRESQL_URL, &version_req).await?;
let (version, bytes) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req).await?;
assert_eq!(Version::new(16, 3, 0), version);
assert!(!bytes.is_empty());
Ok(())
Expand Down
6 changes: 6 additions & 0 deletions postgresql_archive/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ pub enum Error {
/// Unexpected error
#[error("{0}")]
Unexpected(String),
/// Unsupported hasher
#[error("unsupported hasher for '{0}'")]
UnsupportedHasher(String),
/// Unsupported hasher
#[error("unsupported matcher for '{0}'")]
UnsupportedMatcher(String),
/// Unsupported repository
#[error("unsupported repository for '{0}'")]
UnsupportedRepository(String),
Expand Down
183 changes: 63 additions & 120 deletions postgresql_archive/src/hasher/registry.rs
Original file line number Diff line number Diff line change
@@ -1,132 +1,111 @@
use crate::hasher::{blake2b_512, blake2s_256, sha2_256, sha2_512, sha3_256, sha3_512};
use crate::Error::PoisonedLock;
use crate::Result;
use crate::hasher::sha2_256;
use crate::Error::{PoisonedLock, UnsupportedHasher};
use crate::{Result, THESEUS_POSTGRESQL_BINARIES_URL};
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};

lazy_static! {
static ref REGISTRY: Arc<Mutex<HasherRegistry>> =
Arc::new(Mutex::new(HasherRegistry::default()));
}

pub type SupportsFn = fn(&str, &str) -> Result<bool>;
pub type HasherFn = fn(&Vec<u8>) -> Result<String>;

/// Singleton struct to store hashers
#[allow(clippy::type_complexity)]
struct HasherRegistry {
hashers: HashMap<String, Arc<RwLock<HasherFn>>>,
hashers: Vec<(Arc<RwLock<SupportsFn>>, Arc<RwLock<HasherFn>>)>,
}

impl HasherRegistry {
/// Creates a new hasher registry.
fn new() -> Self {
Self {
hashers: HashMap::new(),
hashers: Vec::new(),
}
}

/// Registers a hasher for an extension. Newly registered hashers with the same extension will
/// override existing ones.
fn register<S: AsRef<str>>(&mut self, extension: S, hasher_fn: HasherFn) {
let extension = extension.as_ref().to_string();
self.hashers
.insert(extension, Arc::new(RwLock::new(hasher_fn)));
/// Registers a hasher for a supports function. Newly registered hashers will take precedence
/// over existing ones.
fn register(&mut self, supports_fn: SupportsFn, hasher_fn: HasherFn) {
self.hashers.insert(
0,
(
Arc::new(RwLock::new(supports_fn)),
Arc::new(RwLock::new(hasher_fn)),
),
);
}

/// Get a hasher for the specified extension.
/// Get a hasher for the specified url and extension.
///
/// # Errors
/// * If the registry is poisoned.
fn get<S: AsRef<str>>(&self, extension: S) -> Result<Option<HasherFn>> {
let extension = extension.as_ref().to_string();
if let Some(hasher) = self.hashers.get(&extension) {
let hasher = *hasher
fn get<S: AsRef<str>>(&self, url: S, extension: S) -> Result<HasherFn> {
let url = url.as_ref();
let extension = extension.as_ref();
for (supports_fn, hasher_fn) in &self.hashers {
let supports_function = supports_fn
.read()
.map_err(|error| PoisonedLock(error.to_string()))?;
return Ok(Some(hasher));
if supports_function(url, extension)? {
let hasher_function = hasher_fn
.read()
.map_err(|error| PoisonedLock(error.to_string()))?;
return Ok(*hasher_function);
}
}

Ok(None)
}

/// Get the number of hashers in the registry.
fn len(&self) -> usize {
self.hashers.len()
}

/// Check if the registry is empty.
fn is_empty(&self) -> bool {
self.hashers.is_empty()
Err(UnsupportedHasher(url.to_string()))
}
}

impl Default for HasherRegistry {
/// Creates a new hasher registry with the default hashers registered.
fn default() -> Self {
let mut registry = Self::new();
registry.register("blake2s", blake2s_256::hash);
registry.register("blake2b", blake2b_512::hash);
registry.register("sha256", sha2_256::hash);
registry.register("sha512", sha2_512::hash);
registry.register("sha3-256", sha3_256::hash);
registry.register("sha3-512", sha3_512::hash);
registry.register(
|url, extension| {
Ok(url.starts_with(THESEUS_POSTGRESQL_BINARIES_URL) && extension == "sha256")
},
sha2_256::hash,
);
registry
}
}

/// Registers a hasher for an extension. Newly registered hashers with the same extension will
/// override existing ones.
/// Registers a hasher for a supports function. Newly registered hashers will take precedence
/// over existing ones.
///
/// # Errors
/// * If the registry is poisoned.
#[allow(dead_code)]
pub fn register<S: AsRef<str>>(extension: S, hasher_fn: HasherFn) -> Result<()> {
pub fn register(supports_fn: SupportsFn, hasher_fn: HasherFn) -> Result<()> {
let mut registry = REGISTRY
.lock()
.map_err(|error| PoisonedLock(error.to_string()))?;
registry.register(extension, hasher_fn);
registry.register(supports_fn, hasher_fn);
Ok(())
}

/// Get a hasher for the specified extension.
/// Get a hasher for the specified url and extension.
///
/// # Errors
/// * If the registry is poisoned.
pub fn get<S: AsRef<str>>(extension: S) -> Result<Option<HasherFn>> {
pub fn get<S: AsRef<str>>(url: S, extension: S) -> Result<HasherFn> {
let registry = REGISTRY
.lock()
.map_err(|error| PoisonedLock(error.to_string()))?;
registry.get(extension)
}

/// Get the number of matchers in the registry.
///
/// # Errors
/// * If the registry is poisoned.
pub fn len() -> Result<usize> {
let registry = REGISTRY
.lock()
.map_err(|error| PoisonedLock(error.to_string()))?;
Ok(registry.len())
}

/// Check if the registry is empty.
///
/// # Errors
/// * If the registry is poisoned.
pub fn is_empty() -> Result<bool> {
let registry = REGISTRY
.lock()
.map_err(|error| PoisonedLock(error.to_string()))?;
Ok(registry.is_empty())
registry.get(url, extension)
}

#[cfg(test)]
mod tests {
use super::*;

fn test_hasher(extension: &str, expected: &str) -> Result<()> {
let hasher = get(extension)?.unwrap();
let hasher = get("https://foo.com", extension)?;
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0];
let hash = hasher(&data)?;
assert_eq!(expected, hash);
Expand All @@ -135,69 +114,33 @@ mod tests {

#[test]
fn test_register() -> Result<()> {
let extension = "sha256";
let hashers = len()?;
assert!(!is_empty()?);
REGISTRY
.lock()
.map_err(|error| PoisonedLock(error.to_string()))?
.hashers
.remove(extension);
assert_ne!(hashers, len()?);
register(extension, sha2_256::hash)?;
assert_eq!(hashers, len()?);

test_hasher(
extension,
"9a89c68c4c5e28b8c4a5567673d462fff515db46116f9900624d09c474f593fb",
)
}

#[test]
fn test_sha2_256() -> Result<()> {
test_hasher(
"sha256",
"9a89c68c4c5e28b8c4a5567673d462fff515db46116f9900624d09c474f593fb",
)
}

#[test]
fn test_sha2_512() -> Result<()> {
test_hasher(
"sha512",
"3ad3f36979450d4f53366244ecf1010f4f9121d6888285ff14104fd5aded85d48aa171bf1e33a112602f92b7a7088b298789012fb87b9056321241a19fb74e0b",
)
}

#[test]
fn test_sha3_256() -> Result<()> {
test_hasher(
"sha3-256",
"c0188232190e0427fc9cc78597221c76c799528660889bd6ce1f3563148ff84d",
)
register(
|_, extension| Ok(extension == "test"),
|_| Ok("42".to_string()),
)?;
test_hasher("test", "42")
}

#[test]
fn test_sha3_512() -> Result<()> {
test_hasher(
"sha3-512",
"9429fc1f9772cc1d8039fe75cc1b033cd60f0ec4face0f8a514d25b0649ba8a5954b6c7a41cc3697a56db3ff321475be1fa14b70c7eb78fec6ce62dbfc54c9d3",
)
fn test_get_invalid_url_error() {
let error = get("https://foo.com", "foo").unwrap_err();
assert_eq!(
"unsupported hasher for 'https://foo.com'",
error.to_string()
);
}

#[test]
fn test_blake2s_256() -> Result<()> {
test_hasher(
"blake2s",
"7125921e06071710350390fe902856dbea366a5d6f5ee26c18e741143ac80061",
)
fn test_get_invalid_extension_error() {
let error = get(THESEUS_POSTGRESQL_BINARIES_URL, "foo").unwrap_err();
assert_eq!(
format!("unsupported hasher for '{THESEUS_POSTGRESQL_BINARIES_URL}'"),
error.to_string()
);
}

#[test]
fn test_blake2b_512() -> Result<()> {
test_hasher(
"blake2b",
"67767f1cab415502dcceec9f099fb84539b1c73c5ebdcfe1bb8ca7411e3b6cb33e304f49222edac9bdaa74129e9e13f11f215b8560f9081f0e8f1f869162bf46",
)
fn test_get_theseus_postgresql_binaries() {
assert!(get(THESEUS_POSTGRESQL_BINARIES_URL, "sha256").is_ok());
}
}
10 changes: 5 additions & 5 deletions postgresql_archive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
//! ### Asynchronous API
//!
//! ```no_run
//! use postgresql_archive::{extract, get_archive, Result, VersionReq, DEFAULT_POSTGRESQL_URL};
//! use postgresql_archive::{extract, get_archive, Result, VersionReq, THESEUS_POSTGRESQL_BINARIES_URL};
//!
//! #[tokio::main]
//! async fn main() -> Result<()> {
//! let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &VersionReq::STAR).await?;
//! let (archive_version, archive) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &VersionReq::STAR).await?;
//! let out_dir = std::env::temp_dir();
//! extract(&archive, &out_dir).await
//! }
Expand All @@ -34,10 +34,10 @@
//! ### Synchronous API
//! ```no_run
//! #[cfg(feature = "blocking")] {
//! use postgresql_archive::{VersionReq, DEFAULT_POSTGRESQL_URL};
//! use postgresql_archive::{VersionReq, THESEUS_POSTGRESQL_BINARIES_URL};
//! use postgresql_archive::blocking::{extract, get_archive};
//!
//! let (archive_version, archive) = get_archive(DEFAULT_POSTGRESQL_URL, &VersionReq::STAR).unwrap();
//! let (archive_version, archive) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &VersionReq::STAR).unwrap();
//! let out_dir = std::env::temp_dir();
//! let result = extract(&archive, &out_dir).unwrap();
//! }
Expand Down Expand Up @@ -118,7 +118,7 @@ pub mod matcher;
pub mod repository;
mod version;

pub use archive::DEFAULT_POSTGRESQL_URL;
pub use archive::THESEUS_POSTGRESQL_BINARIES_URL;
pub use archive::{extract, get_archive, get_version};
pub use error::{Error, Result};
pub use semver::{Version, VersionReq};
Expand Down
2 changes: 1 addition & 1 deletion postgresql_archive/src/matcher/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pub mod default;
pub mod postgresql_binaries;
pub mod registry;
pub mod target_os_arch;
Loading

0 comments on commit 3d74b9a

Please sign in to comment.