-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor AWS SigV4 to middleware approach
Signed-off-by: Thomas Farr <[email protected]>
- Loading branch information
Showing
25 changed files
with
978 additions
and
756 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,5 +2,6 @@ | |
members = [ | ||
"api_generator", | ||
"opensearch", | ||
"opensearch-auth-awssigv4", | ||
"yaml_test_runner" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
[package] | ||
name = "opensearch-auth-awssigv4" | ||
version = "2.0.0" | ||
edition = "2018" | ||
authors = ["OpenSearch Contributors"] | ||
description = "Official OpenSearch Rust client AWS SigV4 support" | ||
repository = "https://github.com/opensearch-project/opensearch-rs" | ||
keywords = ["opensearch", "elasticsearch", "search", "lucene"] | ||
categories = ["api-bindings", "database"] | ||
documentation = "https://opensearch.org/docs/latest" | ||
license = "Apache-2.0" | ||
readme = "../README.md" | ||
|
||
[dependencies] | ||
aws-config = "0.54" | ||
aws-credential-types = "0.54" | ||
aws-sigv4 = "0.54" | ||
aws-types = "0.54" | ||
http = "0.2" | ||
opensearch = { path = "../opensearch" } | ||
thiserror = "1.0" | ||
tokio = { version = "~1" } | ||
|
||
[dev-dependencies] | ||
anyhow = "1.0" | ||
serde_json = "~1" | ||
tokio = { version = "~1", features = ["full"] } | ||
wiremock = "0.5" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
use std::convert::TryFrom; | ||
|
||
use aws_config::SdkConfig; | ||
use aws_credential_types::{ | ||
provider::{ProvideCredentials, SharedCredentialsProvider}, | ||
time_source::TimeSource, | ||
Credentials, | ||
}; | ||
use aws_sigv4::http_request::{ | ||
sign, PayloadChecksumKind, SignableBody, SignableRequest, SignatureLocation, SigningParams, | ||
SigningSettings, | ||
}; | ||
use http::header::{HeaderName, CONTENT_LENGTH, USER_AGENT}; | ||
use opensearch::http::{ | ||
middleware::{async_trait, RequestHandler, RequestPipeline, RequestPipelineError}, | ||
reqwest::Response, | ||
transport::TransportBuilder, | ||
Request, | ||
}; | ||
use thiserror::Error; | ||
|
||
pub use aws_config; | ||
pub use aws_credential_types; | ||
pub use aws_types; | ||
|
||
#[derive(Error, Debug)] | ||
pub enum BuildError { | ||
#[error("the region for signing must be provided")] | ||
MissingRegion, | ||
#[error("the credentials provider for signing must be provided")] | ||
MissingCredentialsProvider, | ||
} | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct Builder { | ||
service_name: Option<String>, | ||
credentials_provider: Option<SharedCredentialsProvider>, | ||
region: Option<String>, | ||
ignored_headers: Vec<HeaderName>, | ||
time_source: Option<TimeSource>, | ||
} | ||
|
||
impl Builder { | ||
pub fn service_name(mut self, service_name: impl AsRef<str>) -> Self { | ||
self.service_name = Some(service_name.as_ref().into()); | ||
self | ||
} | ||
|
||
pub fn credentials_provider( | ||
mut self, | ||
credentials_provider: impl ProvideCredentials + 'static, | ||
) -> Self { | ||
self.credentials_provider = Some(SharedCredentialsProvider::new(credentials_provider)); | ||
self | ||
} | ||
|
||
pub fn region(mut self, region: impl AsRef<str>) -> Self { | ||
self.region = Some(region.as_ref().into()); | ||
self | ||
} | ||
|
||
pub fn ignore_header(mut self, header_name: HeaderName) -> Self { | ||
self.ignored_headers.push(header_name); | ||
self | ||
} | ||
|
||
#[doc(hidden)] | ||
pub fn time_source(mut self, time_source: TimeSource) -> Self { | ||
self.time_source = Some(time_source); | ||
self | ||
} | ||
|
||
pub fn build(self) -> Result<AwsSigV4, BuildError> { | ||
Ok(AwsSigV4 { | ||
service_name: self.service_name.unwrap_or_else(|| "es".to_string()), | ||
credentials_provider: self | ||
.credentials_provider | ||
.ok_or(BuildError::MissingCredentialsProvider)?, | ||
region: self.region.ok_or(BuildError::MissingRegion)?, | ||
ignored_headers: self.ignored_headers, | ||
time_source: self.time_source.unwrap_or_default(), | ||
}) | ||
} | ||
} | ||
|
||
impl Default for Builder { | ||
fn default() -> Self { | ||
Self { | ||
service_name: None, | ||
credentials_provider: None, | ||
region: None, | ||
ignored_headers: vec![USER_AGENT, CONTENT_LENGTH], | ||
time_source: None, | ||
} | ||
} | ||
} | ||
|
||
impl From<&SdkConfig> for Builder { | ||
fn from(value: &SdkConfig) -> Self { | ||
let credentials_provider = value.credentials_provider().cloned(); | ||
let region = value.region().map(|r| r.as_ref().into()); | ||
|
||
Self { | ||
credentials_provider, | ||
region, | ||
..Default::default() | ||
} | ||
} | ||
} | ||
|
||
impl From<SdkConfig> for Builder { | ||
fn from(value: SdkConfig) -> Self { | ||
<Self as From<&SdkConfig>>::from(&value) | ||
} | ||
} | ||
|
||
#[derive(Error, Debug)] | ||
pub enum AwsSigV4Error { | ||
#[error("invalid signing params: {0}")] | ||
InvalidSigningParams(#[from] aws_sigv4::signing_params::BuildError), | ||
#[error("unable to retrieve credentials: {0}")] | ||
FailedCredentialsRetrieval(#[from] aws_credential_types::provider::error::CredentialsError), | ||
#[error("invalid uri: {0}")] | ||
InvalidUri(#[from] http::uri::InvalidUri), | ||
#[error("unable to sign request: {0}")] | ||
FailedSigning(#[from] aws_sigv4::http_request::SigningError), | ||
} | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct AwsSigV4 { | ||
service_name: String, | ||
credentials_provider: SharedCredentialsProvider, | ||
region: String, | ||
ignored_headers: Vec<HeaderName>, | ||
time_source: TimeSource, | ||
} | ||
|
||
impl AwsSigV4 { | ||
pub fn builder() -> Builder { | ||
Builder::default() | ||
} | ||
|
||
async fn sign_request(&self, request: &mut Request) -> Result<(), AwsSigV4Error> { | ||
let credentials = self.credentials_provider.provide_credentials().await?; | ||
let params = self.build_params(&credentials)?; | ||
|
||
let uri = request.url().as_str().parse()?; | ||
|
||
let signable_request = SignableRequest::new( | ||
request.method(), | ||
&uri, | ||
request.headers(), | ||
SignableBody::Bytes(request.body().and_then(|b| b.as_bytes()).unwrap_or(&[])), | ||
); | ||
|
||
let (mut instructions, _) = sign(signable_request, ¶ms)?.into_parts(); | ||
|
||
if let Some(new_headers) = instructions.take_headers() { | ||
for (name, value) in new_headers.into_iter() { | ||
request.headers_mut().insert( | ||
name.expect("AWS SigV4 signing header name must never be None"), | ||
value, | ||
); | ||
} | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
fn build_params<'a>( | ||
&'a self, | ||
credentials: &'a Credentials, | ||
) -> Result<SigningParams<'a>, aws_sigv4::signing_params::BuildError> { | ||
let mut settings = SigningSettings::default(); | ||
settings.signature_location = SignatureLocation::Headers; | ||
settings.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; | ||
settings.excluded_headers = Some(self.ignored_headers.clone()); | ||
|
||
let mut builder = SigningParams::builder() | ||
.access_key(credentials.access_key_id()) | ||
.secret_key(credentials.secret_access_key()) | ||
.service_name(&self.service_name) | ||
.region(&self.region) | ||
.time(self.time_source.now()) | ||
.settings(settings); | ||
|
||
builder.set_security_token(credentials.session_token()); | ||
|
||
builder.build() | ||
} | ||
} | ||
|
||
impl TryFrom<&SdkConfig> for AwsSigV4 { | ||
type Error = BuildError; | ||
|
||
fn try_from(value: &SdkConfig) -> Result<Self, Self::Error> { | ||
Builder::from(value).build() | ||
} | ||
} | ||
|
||
impl TryFrom<SdkConfig> for AwsSigV4 { | ||
type Error = BuildError; | ||
|
||
fn try_from(value: SdkConfig) -> Result<Self, Self::Error> { | ||
<Self as TryFrom<&SdkConfig>>::try_from(&value) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl RequestHandler for AwsSigV4 { | ||
async fn handle( | ||
&self, | ||
mut request: Request, | ||
next: RequestPipeline<'_>, | ||
) -> Result<Response, RequestPipelineError> { | ||
self.sign_request(&mut request) | ||
.await | ||
.map_err(|e| RequestPipelineError::Pipeline(e.into()))?; | ||
next.run(request).await | ||
} | ||
} | ||
|
||
pub trait TransportBuilderExt { | ||
fn aws_sigv4(self, aws_sigv4: AwsSigV4) -> Self; | ||
} | ||
|
||
impl TransportBuilderExt for TransportBuilder { | ||
fn aws_sigv4(self, aws_sigv4: AwsSigV4) -> Self { | ||
self.with_handler(aws_sigv4) | ||
} | ||
} |
Oops, something went wrong.