diff --git a/CHANGELOG.md b/CHANGELOG.md index 0758c9b2..bcd4c69c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased] +### ⚠️ Breaking Changes ⚠️ +- Internalized the `BuildError` type, consolidating on the `Error` type ([#228](https://github.com/opensearch-project/opensearch-rs/pull/228)) ### Added diff --git a/opensearch/Cargo.toml b/opensearch/Cargo.toml index 28bb3ba3..fd4d9885 100644 --- a/opensearch/Cargo.toml +++ b/opensearch/Cargo.toml @@ -39,6 +39,7 @@ url = "2.1" serde = { version = "1", features = ["derive"] } serde_json = "1" serde_with = "3" +thiserror = "1" void = "1.0.2" aws-credential-types = { version = "1", optional = true } aws-sigv4 = { version = "1", optional = true } diff --git a/opensearch/src/auth.rs b/opensearch/src/auth.rs index 787917ba..203e3439 100644 --- a/opensearch/src/auth.rs +++ b/opensearch/src/auth.rs @@ -97,11 +97,11 @@ impl std::convert::TryFrom<&aws_types::SdkConfig> for Credentials { fn try_from(value: &aws_types::SdkConfig) -> Result { let credentials = value .credentials_provider() - .ok_or_else(|| super::error::lib("SdkConfig does not have a credentials_provider"))? + .ok_or(crate::http::aws_auth::AwsSigV4Error::MissingCredentialsProvider)? .clone(); let region = value .region() - .ok_or_else(|| super::error::lib("SdkConfig does not have a region"))? + .ok_or(crate::http::aws_auth::AwsSigV4Error::MissingRegion)? .clone(); Ok(Credentials::AwsSigV4(credentials, region)) } diff --git a/opensearch/src/cert.rs b/opensearch/src/cert.rs index d74b54e8..f9a4172b 100644 --- a/opensearch/src/cert.rs +++ b/opensearch/src/cert.rs @@ -233,8 +233,15 @@ impl Certificate { END_CERTIFICATE if begin => { begin = false; cert.push(line); - certs.push(reqwest::Certificate::from_pem(cert.join("\n").as_bytes())?); - cert = Vec::new(); + + { + let cert = reqwest::Certificate::from_pem(cert.join("\n").as_bytes()) + .map_err(CertificateError::MalformedCertificate)?; + + certs.push(cert); + } + + cert.clear(); } _ if begin => cert.push(line), _ => {} @@ -242,9 +249,7 @@ impl Certificate { } if certs.is_empty() { - Err(crate::error::lib( - "could not find PEM certificate in input data", - )) + Err(CertificateError::MissingPemCertificate.into()) } else { Ok(Self(certs)) } @@ -252,7 +257,9 @@ impl Certificate { /// Create a `Certificate` from a binary DER encoded certificate. pub fn from_der(der: &[u8]) -> Result { - Ok(Self(vec![reqwest::Certificate::from_der(der)?])) + Ok(Self(vec![ + reqwest::Certificate::from_der(der).map_err(CertificateError::MalformedCertificate)? + ])) } /// Append a `Certificate` to the chain. @@ -279,3 +286,11 @@ impl Deref for Certificate { &self.0 } } + +#[derive(Debug, thiserror::Error)] +pub(crate) enum CertificateError { + #[error("could not find PEM certificate in input data")] + MissingPemCertificate, + #[error("malformed certificate: {0}")] + MalformedCertificate(reqwest::Error), +} diff --git a/opensearch/src/error.rs b/opensearch/src/error.rs index 82383052..4fe10ae4 100644 --- a/opensearch/src/error.rs +++ b/opensearch/src/error.rs @@ -32,125 +32,77 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -use crate::http::{transport::BuildError, StatusCode}; -use std::{error, fmt, io}; + +use crate::{ + cert::CertificateError, + http::{transport, StatusCode}, +}; + +pub(crate) type BoxError<'a> = Box; /// An error with the client. /// /// Errors that can occur include IO and parsing errors, as well as specific /// errors from OpenSearch and internal errors from the client. -#[derive(Debug)] -pub struct Error { - kind: Kind, +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct Error(Kind); + +impl From for Error +where + Kind: From, +{ + fn from(error: E) -> Self { + Self(Kind::from(error)) + } } -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] enum Kind { - /// An error building the client - Build(BuildError), - - /// A general error from this library - Lib(String), + #[error("transport builder error: {0}")] + TransportBuilder(#[from] transport::BuildError), - /// HTTP library error - Http(reqwest::Error), + #[error("certificate error: {0}")] + Certificate(#[from] CertificateError), - /// IO error - Io(io::Error), + #[error("http error: {0}")] + Http(#[from] reqwest::Error), - /// JSON error - Json(serde_json::error::Error), -} + #[error("URL parse error: {0}")] + UrlParse(#[from] url::ParseError), -impl From for Error { - fn from(err: io::Error) -> Error { - Error { - kind: Kind::Io(err), - } - } -} + #[error("IO error: {0}")] + Io(#[from] std::io::Error), -impl From for Error { - fn from(err: reqwest::Error) -> Error { - Error { - kind: Kind::Http(err), - } - } -} + #[error("JSON error: {0}")] + Json(#[from] serde_json::error::Error), -impl From for Error { - fn from(err: serde_json::error::Error) -> Error { - Error { - kind: Kind::Json(err), - } - } + #[cfg(feature = "aws-auth")] + #[error("AwsSigV4 error: {0}")] + AwsSigV4(#[from] crate::http::aws_auth::AwsSigV4Error), } -impl From for Error { - fn from(err: url::ParseError) -> Error { - Error { - kind: Kind::Lib(err.to_string()), - } - } -} - -impl From for Error { - fn from(err: BuildError) -> Error { - Error { - kind: Kind::Build(err), - } - } -} - -pub(crate) fn lib(err: impl Into) -> Error { - Error { - kind: Kind::Lib(err.into()), - } -} +use Kind::*; impl Error { /// The status code, if the error was generated from a response pub fn status_code(&self) -> Option { - match &self.kind { - Kind::Http(err) => err.status(), + match &self.0 { + Http(err) => err.status(), _ => None, } } /// Returns true if the error is related to a timeout pub fn is_timeout(&self) -> bool { - match &self.kind { - Kind::Http(err) => err.is_timeout(), + match &self.0 { + Http(err) => err.is_timeout(), _ => false, } } /// Returns true if the error is related to serialization or deserialization pub fn is_json(&self) -> bool { - matches!(self.kind, Kind::Json(_)) - } -} - -impl error::Error for Error { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - match &self.kind { - Kind::Build(err) => Some(err), - Kind::Lib(_) => None, - Kind::Http(err) => Some(err), - Kind::Io(err) => Some(err), - Kind::Json(err) => Some(err), - } - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match &self.kind { - Kind::Build(err) => err.fmt(f), - Kind::Lib(err) => err.fmt(f), - Kind::Http(err) => err.fmt(f), - Kind::Io(err) => err.fmt(f), - Kind::Json(err) => err.fmt(f), - } + matches!(self.0, Json(_)) } } diff --git a/opensearch/src/http/aws_auth.rs b/opensearch/src/http/aws_auth.rs index 6150ffd0..a4d7fa9a 100644 --- a/opensearch/src/http/aws_auth.rs +++ b/opensearch/src/http/aws_auth.rs @@ -9,7 +9,7 @@ * GitHub history for details. */ -use crate::http::headers::HeaderValue; +use crate::{http::headers::HeaderValue, BoxError}; use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider}; use aws_sigv4::{ http_request::{ @@ -21,15 +21,32 @@ use aws_smithy_runtime_api::client::identity::Identity; use aws_types::{region::Region, sdk_config::SharedTimeSource}; use reqwest::Request; -pub async fn sign_request( +#[derive(Debug, thiserror::Error)] +pub(crate) enum AwsSigV4Error { + #[error("SdkConfig is does not have a credentials provider configured")] + MissingCredentialsProvider, + #[error("SdkConfig is does not have a region configured")] + MissingRegion, + #[error("signing error: {0}")] + SigningError(#[source] BoxError<'static>), +} + +fn signing_error>>(e: E) -> AwsSigV4Error { + AwsSigV4Error::SigningError(e.into()) +} + +pub(crate) async fn sign_request( request: &mut Request, credentials_provider: &SharedCredentialsProvider, service_name: &str, region: &Region, time_source: &SharedTimeSource, -) -> Result<(), Box> { +) -> Result<(), AwsSigV4Error> { let identity = { - let c = credentials_provider.provide_credentials().await?; + let c = credentials_provider + .provide_credentials() + .await + .map_err(signing_error)?; let e = c.expiry(); Identity::new(c, e) }; @@ -47,7 +64,8 @@ pub async fn sign_request( .region(region.as_ref()) .time(time_source.now()) .settings(signing_settings) - .build()?; + .build() + .map_err(signing_error)?; SigningParams::V4(p) }; @@ -68,11 +86,13 @@ pub async fn sign_request( None => SignableBody::Bytes(&[]), }; - SignableRequest::new(method, uri, headers, body)? + SignableRequest::new(method, uri, headers, body).map_err(signing_error)? }; let (new_headers, new_query_params) = { - let (instructions, _) = sign(signable_request, ¶ms)?.into_parts(); + let (instructions, _) = sign(signable_request, ¶ms) + .map_err(signing_error)? + .into_parts(); instructions.into_parts() }; diff --git a/opensearch/src/http/mod.rs b/opensearch/src/http/mod.rs index b228c9a8..7e131c8f 100644 --- a/opensearch/src/http/mod.rs +++ b/opensearch/src/http/mod.rs @@ -38,7 +38,7 @@ pub mod request; pub mod response; pub mod transport; -pub use reqwest::StatusCode; +pub use reqwest::{self, Request, StatusCode}; pub use url::Url; /// Http methods supported by OpenSearch diff --git a/opensearch/src/http/transport.rs b/opensearch/src/http/transport.rs index a9a49cb2..e8cc4822 100644 --- a/opensearch/src/http/transport.rs +++ b/opensearch/src/http/transport.rs @@ -36,6 +36,7 @@ use crate::auth::ClientCertificate; use crate::cert::CertificateValidation; use crate::{ auth::Credentials, + cert::CertificateError, error::Error, http::{ headers::{ @@ -54,60 +55,15 @@ use bytes::BytesMut; use dyn_clone::clone_trait_object; use lazy_static::lazy_static; use serde::Serialize; -use std::{ - error, fmt, - fmt::Debug, - io::{self, Write}, - time::Duration, -}; +use std::{fmt::Debug, io::Write, time::Duration}; use url::Url; -/// Error that can occur when building a [Transport] -#[derive(Debug)] -pub enum BuildError { - /// IO error - Io(io::Error), - - /// Certificate error - Cert(reqwest::Error), -} - -impl From for BuildError { - fn from(err: io::Error) -> BuildError { - BuildError::Io(err) - } -} - -impl From for BuildError { - fn from(err: reqwest::Error) -> BuildError { - BuildError::Cert(err) - } -} - -impl error::Error for BuildError { - #[allow(warnings)] - fn description(&self) -> &str { - match *self { - BuildError::Io(ref err) => err.description(), - BuildError::Cert(ref err) => err.description(), - } - } - - fn cause(&self) -> Option<&dyn error::Error> { - match *self { - BuildError::Io(ref err) => Some(err as &dyn error::Error), - BuildError::Cert(ref err) => Some(err as &dyn error::Error), - } - } -} - -impl fmt::Display for BuildError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - BuildError::Io(ref err) => fmt::Display::fmt(err, f), - BuildError::Cert(ref err) => fmt::Display::fmt(err, f), - } - } +#[derive(Debug, thiserror::Error)] +pub(crate) enum BuildError { + #[error("proxy configuration error: {0}")] + Proxy(#[source] reqwest::Error), + #[error("client configuration error: {0}")] + ClientBuilder(#[source] reqwest::Error), } /// Default address to OpenSearch running on `http://localhost:9200` @@ -270,7 +226,7 @@ impl TransportBuilder { } /// Builds a [Transport] to use to send API calls to OpenSearch. - pub fn build(self) -> Result { + pub fn build(self) -> Result { let mut client_builder = self.client_builder; if let Some(t) = self.timeout { @@ -287,12 +243,14 @@ impl TransportBuilder { Some(pass) => pass.as_str(), None => "", }; - let pkcs12 = reqwest::Identity::from_pkcs12_der(b, password)?; + let pkcs12 = reqwest::Identity::from_pkcs12_der(b, password) + .map_err(CertificateError::MalformedCertificate)?; client_builder.identity(pkcs12) } #[cfg(feature = "rustls-tls")] ClientCertificate::Pem(b) => { - let pem = reqwest::Identity::from_pem(b)?; + let pem = reqwest::Identity::from_pem(b) + .map_err(CertificateError::MalformedCertificate)?; client_builder.identity(pem) } } @@ -322,7 +280,7 @@ impl TransportBuilder { if self.disable_proxy { client_builder = client_builder.no_proxy(); } else if let Some(url) = self.proxy { - let mut proxy = reqwest::Proxy::all(url)?; + let mut proxy = reqwest::Proxy::all(url).map_err(BuildError::Proxy)?; if let Some(c) = self.proxy_credentials { proxy = match c { Credentials::Basic(u, p) => proxy.basic_auth(&u, &p), @@ -332,7 +290,7 @@ impl TransportBuilder { client_builder = client_builder.proxy(proxy); } - let client = client_builder.build()?; + let client = client_builder.build().map_err(BuildError::ClientBuilder)?; Ok(Transport { client, conn_pool: self.conn_pool, @@ -505,15 +463,12 @@ impl Transport { region, &self.sigv4_time_source, ) - .await - .map_err(|e| crate::error::lib(format!("AWSV4 Signing Failed: {}", e)))?; + .await?; } - let response = self.client.execute(request).await; - match response { - Ok(r) => Ok(Response::new(r, method)), - Err(e) => Err(e.into()), - } + let response = self.client.execute(request).await?; + + Ok(Response::new(response, method)) } }