Skip to content

Commit

Permalink
Refactor AWS SigV4 to middleware approach
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Farr <[email protected]>
  • Loading branch information
Xtansia committed Mar 23, 2023
1 parent 8885a98 commit b33e9c8
Show file tree
Hide file tree
Showing 25 changed files with 978 additions and 756 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
test-args:
- "--features aws-auth"
- ""
- "--no-default-features --features rustls-tls --package opensearch --test cert"
runs-on: ${{ matrix.os }}
steps:
Expand All @@ -33,7 +33,7 @@ jobs:
version: 2.6.0
secured: true

- name: Run Tests (${{ matrix.test-args }})
- name: Run Tests
working-directory: client
run: cargo make test ${{ matrix.test-args }}
env:
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
members = [
"api_generator",
"opensearch",
"opensearch-auth-awssigv4",
"yaml_test_runner"
]
11 changes: 4 additions & 7 deletions Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,15 @@ OPENSEARCH_URL = { value = "${OPENSEARCH_PROTOCOL}://localhost:9200", condition
[tasks.generate-certs]
category = "OpenSearch"
description = "Generates SSL certificates used for integration tests"
command = "bash ./.ci/generate-certs.sh"
command = "bash"
args = ["./.ci/generate-certs.sh"]

[tasks.run-opensearch]
category = "OpenSearch"
private = true
condition = { env_set = [ "STACK_VERSION"], env_false = ["CARGO_MAKE_CI"] }

[tasks.run-opensearch.linux]
command = "./.ci/run-opensearch.sh"

[tasks.run-opensearch.mac]
command = "./.ci/run-opensearch.sh"
command = "bash"
args = ["./.ci/run-opensearch.sh"]

[tasks.run-opensearch.windows]
script_runner = "cmd"
Expand Down
2 changes: 1 addition & 1 deletion api_generator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ semver = "1.0.14"
serde = { version = "~1", features = ["derive"] }
serde_json = "~1"
simple_logger = "4.0.0"
syn = { version = "~1.0", features = ["full"] }
syn = { version = "~1.0", features = ["full", "extra-traits"] }
tar = "~0.4"
toml = "0.7.1"
url = "2.1.1"
Expand Down
28 changes: 28 additions & 0 deletions opensearch-auth-awssigv4/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
[package]
name = "opensearch-auth-awssigv4"
version = "2.0.0"
edition = "2018"
authors = ["OpenSearch Contributors"]
description = "Official OpenSearch Rust client AWS SigV4 support"
repository = "https://github.com/opensearch-project/opensearch-rs"
keywords = ["opensearch", "elasticsearch", "search", "lucene"]
categories = ["api-bindings", "database"]
documentation = "https://opensearch.org/docs/latest"
license = "Apache-2.0"
readme = "../README.md"

[dependencies]
aws-config = "0.54"
aws-credential-types = "0.54"
aws-sigv4 = "0.54"
aws-types = "0.54"
http = "0.2"
opensearch = { path = "../opensearch" }
thiserror = "1.0"
tokio = { version = "~1" }

[dev-dependencies]
anyhow = "1.0"
serde_json = "~1"
tokio = { version = "~1", features = ["full"] }
wiremock = "0.5"
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,25 @@
* GitHub history for details.
*/

#[tokio::main]
#[cfg(feature = "aws-auth")]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
use std::convert::TryInto;
use std::convert::TryInto;

use opensearch::{
cat::CatIndicesParts,
http::transport::{SingleNodeConnectionPool, TransportBuilder},
OpenSearch,
};
use url::Url;
use opensearch::{
cat::CatIndicesParts,
http::{
transport::{SingleNodeConnectionPool, TransportBuilder},
Url,
},
OpenSearch,
};
use opensearch_auth_awssigv4::TransportBuilderExt;

#[tokio::main]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
let aws_config = aws_config::load_from_env().await;

let host = ""; // e.g. https://search-mydomain.us-west-1.es.amazonaws.com
let transport = TransportBuilder::new(SingleNodeConnectionPool::new(Url::parse(host).unwrap()))
.auth(aws_config.try_into()?)
.aws_sigv4(aws_config.try_into()?)
.build()?;
let client = OpenSearch::new(transport);

Expand All @@ -40,8 +42,3 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("{}", text);
Ok(())
}

#[cfg(not(feature = "aws-auth"))]
pub fn main() {
panic!("Requires the `aws-auth` feature to be enabled")
}
242 changes: 242 additions & 0 deletions opensearch-auth-awssigv4/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

use std::convert::TryFrom;

use aws_config::SdkConfig;
use aws_credential_types::{
provider::{ProvideCredentials, SharedCredentialsProvider},
time_source::TimeSource,
Credentials,
};
use aws_sigv4::http_request::{
sign, PayloadChecksumKind, SignableBody, SignableRequest, SignatureLocation, SigningParams,
SigningSettings,
};
use http::header::{HeaderName, CONTENT_LENGTH, USER_AGENT};
use opensearch::http::{
middleware::{async_trait, RequestHandler, RequestPipeline, RequestPipelineError},
reqwest::Response,
transport::TransportBuilder,
Request,
};
use thiserror::Error;

pub use aws_config;
pub use aws_credential_types;
pub use aws_types;

#[derive(Error, Debug)]
pub enum BuildError {
#[error("the region for signing must be provided")]
MissingRegion,
#[error("the credentials provider for signing must be provided")]
MissingCredentialsProvider,
}

#[derive(Debug, Clone)]
pub struct Builder {
service_name: Option<String>,
credentials_provider: Option<SharedCredentialsProvider>,
region: Option<String>,
ignored_headers: Vec<HeaderName>,
time_source: Option<TimeSource>,
}

impl Builder {
pub fn service_name(mut self, service_name: impl AsRef<str>) -> Self {
self.service_name = Some(service_name.as_ref().into());
self
}

pub fn credentials_provider(
mut self,
credentials_provider: impl ProvideCredentials + 'static,
) -> Self {
self.credentials_provider = Some(SharedCredentialsProvider::new(credentials_provider));
self
}

pub fn region(mut self, region: impl AsRef<str>) -> Self {
self.region = Some(region.as_ref().into());
self
}

pub fn ignore_header(mut self, header_name: HeaderName) -> Self {
self.ignored_headers.push(header_name);
self
}

#[doc(hidden)]
pub fn time_source(mut self, time_source: TimeSource) -> Self {
self.time_source = Some(time_source);
self
}

pub fn build(self) -> Result<AwsSigV4, BuildError> {
Ok(AwsSigV4 {
service_name: self.service_name.unwrap_or_else(|| "es".to_string()),
credentials_provider: self
.credentials_provider
.ok_or(BuildError::MissingCredentialsProvider)?,
region: self.region.ok_or(BuildError::MissingRegion)?,
ignored_headers: self.ignored_headers,
time_source: self.time_source.unwrap_or_default(),
})
}
}

impl Default for Builder {
fn default() -> Self {
Self {
service_name: None,
credentials_provider: None,
region: None,
ignored_headers: vec![USER_AGENT, CONTENT_LENGTH],
time_source: None,
}
}
}

impl From<&SdkConfig> for Builder {
fn from(value: &SdkConfig) -> Self {
let credentials_provider = value.credentials_provider().cloned();
let region = value.region().map(|r| r.as_ref().into());

Self {
credentials_provider,
region,
..Default::default()
}
}
}

impl From<SdkConfig> for Builder {
fn from(value: SdkConfig) -> Self {
<Self as From<&SdkConfig>>::from(&value)
}
}

#[derive(Error, Debug)]
pub enum AwsSigV4Error {
#[error("invalid signing params: {0}")]
InvalidSigningParams(#[from] aws_sigv4::signing_params::BuildError),
#[error("unable to retrieve credentials: {0}")]
FailedCredentialsRetrieval(#[from] aws_credential_types::provider::error::CredentialsError),
#[error("invalid uri: {0}")]
InvalidUri(#[from] http::uri::InvalidUri),
#[error("unable to sign request: {0}")]
FailedSigning(#[from] aws_sigv4::http_request::SigningError),
}

#[derive(Debug, Clone)]
pub struct AwsSigV4 {
service_name: String,
credentials_provider: SharedCredentialsProvider,
region: String,
ignored_headers: Vec<HeaderName>,
time_source: TimeSource,
}

impl AwsSigV4 {
pub fn builder() -> Builder {
Builder::default()
}

async fn sign_request(&self, request: &mut Request) -> Result<(), AwsSigV4Error> {
let credentials = self.credentials_provider.provide_credentials().await?;
let params = self.build_params(&credentials)?;

let uri = request.url().as_str().parse()?;

let signable_request = SignableRequest::new(
request.method(),
&uri,
request.headers(),
SignableBody::Bytes(request.body().and_then(|b| b.as_bytes()).unwrap_or(&[])),
);

let (mut instructions, _) = sign(signable_request, &params)?.into_parts();

if let Some(new_headers) = instructions.take_headers() {
for (name, value) in new_headers.into_iter() {
request.headers_mut().insert(
name.expect("AWS SigV4 signing header name must never be None"),
value,
);
}
}

Ok(())
}

fn build_params<'a>(
&'a self,
credentials: &'a Credentials,
) -> Result<SigningParams<'a>, aws_sigv4::signing_params::BuildError> {
let mut settings = SigningSettings::default();
settings.signature_location = SignatureLocation::Headers;
settings.payload_checksum_kind = PayloadChecksumKind::XAmzSha256;
settings.excluded_headers = Some(self.ignored_headers.clone());

let mut builder = SigningParams::builder()
.access_key(credentials.access_key_id())
.secret_key(credentials.secret_access_key())
.service_name(&self.service_name)
.region(&self.region)
.time(self.time_source.now())
.settings(settings);

builder.set_security_token(credentials.session_token());

builder.build()
}
}

impl TryFrom<&SdkConfig> for AwsSigV4 {
type Error = BuildError;

fn try_from(value: &SdkConfig) -> Result<Self, Self::Error> {
Builder::from(value).build()
}
}

impl TryFrom<SdkConfig> for AwsSigV4 {
type Error = BuildError;

fn try_from(value: SdkConfig) -> Result<Self, Self::Error> {
<Self as TryFrom<&SdkConfig>>::try_from(&value)
}
}

#[async_trait]
impl RequestHandler for AwsSigV4 {
async fn handle(
&self,
mut request: Request,
next: RequestPipeline<'_>,
) -> Result<Response, RequestPipelineError> {
self.sign_request(&mut request)
.await
.map_err(|e| RequestPipelineError::Pipeline(e.into()))?;
next.run(request).await
}
}

pub trait TransportBuilderExt {
fn aws_sigv4(self, aws_sigv4: AwsSigV4) -> Self;
}

impl TransportBuilderExt for TransportBuilder {
fn aws_sigv4(self, aws_sigv4: AwsSigV4) -> Self {
self.with_handler(aws_sigv4)
}
}
Loading

0 comments on commit b33e9c8

Please sign in to comment.