Skip to content

Commit

Permalink
feat: support SASL SCRAM-SHA-256 and SCRAM-SHA-512
Browse files Browse the repository at this point in the history
  • Loading branch information
WenyXu committed Aug 11, 2024
1 parent a331d09 commit 2665a03
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 32 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ tokio = { version = "1.19", default-features = false, features = ["io-util", "ne
tokio-rustls = { version = "0.26", optional = true, default-features = false, features = ["logging", "ring", "tls12"] }
tracing = "0.1"
zstd = { version = "0.13", optional = true }
rsasl = { version = "2.0", default-features = false, features = ["config_builder", "provider", "login", "plain", "scram-sha-1", "scram-sha-2"]}

[dev-dependencies]
assert_matches = "1.5"
Expand Down
2 changes: 1 addition & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use error::{Error, Result};

use self::{controller::ControllerClient, partition::UnknownTopicHandling};

pub use crate::connection::SaslConfig;
pub use crate::connection::{Credentials, SaslConfig};

#[derive(Debug, Error)]
pub enum ProduceError {
Expand Down
5 changes: 2 additions & 3 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::{
client::metadata_cache::MetadataCache,
};

pub use self::transport::Credentials;
pub use self::transport::SaslConfig;
pub use self::transport::TlsConfig;

Expand Down Expand Up @@ -164,9 +165,7 @@ impl ConnectionHandler for BrokerRepresentation {
let mut messenger = Messenger::new(BufStream::new(transport), max_message_size, client_id);
messenger.sync_versions().await?;
if let Some(sasl_config) = sasl_config {
messenger
.sasl_handshake(sasl_config.mechanism(), sasl_config.auth_bytes())
.await?;
messenger.do_sasl(sasl_config).await?;
}
Ok(Arc::new(messenger))
}
Expand Down
2 changes: 1 addition & 1 deletion src/connection/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};

mod sasl;
pub use sasl::SaslConfig;
pub use sasl::{Credentials, SaslConfig};

#[cfg(feature = "transport-tls")]
pub type TlsConfig = Option<Arc<rustls::ClientConfig>>;
Expand Down
38 changes: 29 additions & 9 deletions src/connection/transport/sasl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,45 @@ pub enum SaslConfig {
///
/// # References
/// - <https://datatracker.ietf.org/doc/html/rfc4616>
Plain { username: String, password: String },
Plain(Credentials),
/// SASL - SCRAM-SHA-256
///
/// # References
/// - <https://datatracker.ietf.org/doc/html/rfc7677>
ScramSha256(Credentials),
/// SASL - SCRAM-SHA-512
///
/// # References
/// - <https://datatracker.ietf.org/doc/html/rfc5802>
ScramSha512(Credentials),
}

#[derive(Debug, Clone)]
pub struct Credentials {
pub username: String,
pub password: String,
}

impl Credentials {
pub fn new(username: String, password: String) -> Self {
Self { username, password }
}
}

impl SaslConfig {
pub(crate) fn auth_bytes(&self) -> Vec<u8> {
pub(crate) fn credentials(&self) -> Credentials {
match self {
Self::Plain { username, password } => {
let mut auth: Vec<u8> = vec![0];
auth.extend(username.bytes());
auth.push(0);
auth.extend(password.bytes());
auth
}
Self::Plain(credentials) => credentials.clone(),
Self::ScramSha256(credentials) => credentials.clone(),
Self::ScramSha512(credentials) => credentials.clone(),
}
}

pub(crate) fn mechanism(&self) -> &str {
match self {
Self::Plain { .. } => "PLAIN",
Self::ScramSha256 { .. } => "SCRAM-SHA-256",
Self::ScramSha512 { .. } => "SCRAM-SHA-512",
}
}
}
80 changes: 69 additions & 11 deletions src/messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::{

use futures::future::BoxFuture;
use parking_lot::Mutex;
use rsasl::{config::SASLConfig, mechname::MechanismNameError, prelude::Mechname};
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
Expand All @@ -23,8 +24,6 @@ use tokio::{
};
use tracing::{debug, info, warn};

use crate::protocol::{api_version::ApiVersionRange, primitives::CompactString};
use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType};
use crate::{
backoff::ErrorOrThrottle,
protocol::{
Expand All @@ -34,12 +33,21 @@ use crate::{
frame::{AsyncMessageRead, AsyncMessageWrite},
messages::{
ReadVersionedError, ReadVersionedType, RequestBody, RequestHeader, ResponseHeader,
SaslAuthenticateRequest, SaslHandshakeRequest, WriteVersionedError, WriteVersionedType,
SaslAuthenticateRequest, SaslAuthenticateResponse, SaslHandshakeRequest,
SaslHandshakeResponse, WriteVersionedError, WriteVersionedType,
},
primitives::{Int16, Int32, NullableString, TaggedFields},
},
throttle::maybe_throttle,
};
use crate::{
client::SaslConfig,
protocol::{api_version::ApiVersionRange, primitives::CompactString},
};
use crate::{
connection::Credentials,
protocol::{messages::ApiVersionsRequest, traits::ReadType},
};

#[derive(Debug)]
struct Response {
Expand Down Expand Up @@ -186,6 +194,9 @@ pub enum SaslError {

#[error("API error: {0}")]
ApiError(#[from] ApiError),

#[error("Invalid sasl mechanism: {0}")]
InvalidSaslMechanism(#[from] MechanismNameError),
}

impl<RW> Messenger<RW>
Expand Down Expand Up @@ -531,16 +542,10 @@ where
Err(SyncVersionsError::NoWorkingVersion)
}

pub async fn sasl_handshake(
async fn sasl_authentication(
&self,
mechanism: &str,
auth_bytes: Vec<u8>,
) -> Result<(), SaslError> {
let req = SaslHandshakeRequest::new(mechanism);
let resp = self.request(req).await?;
if let Some(err) = resp.error_code {
return Err(SaslError::ApiError(err));
}
) -> Result<SaslAuthenticateResponse, SaslError> {
let req = SaslAuthenticateRequest::new(auth_bytes);
let resp = self.request(req).await?;
if let Some(err) = resp.error_code {
Expand All @@ -549,6 +554,59 @@ where
}
return Err(SaslError::ApiError(err));
}

Ok(resp)
}

async fn sasl_handshake(&self, mechanism: &str) -> Result<SaslHandshakeResponse, SaslError> {
let req = SaslHandshakeRequest::new(mechanism);
let resp = self.request(req).await?;
if let Some(err) = resp.error_code {
return Err(SaslError::ApiError(err));
}
Ok(resp)
}

pub async fn do_sasl(&self, config: SaslConfig) -> Result<(), SaslError> {
let mechanism = config.mechanism();
let resp = self.sasl_handshake(mechanism).await?;

let Credentials { username, password } = config.credentials();
let config = SASLConfig::with_credentials(None, username, password).unwrap();
let sasl = rsasl::prelude::SASLClient::new(config);
let raw_mechanisms = resp.mechanisms.0.unwrap_or_default();
let mechanisms = raw_mechanisms
.iter()
.map(|mech| {
debug!("{:?}", mech);
Mechname::parse(mech.0.as_bytes()).map_err(SaslError::InvalidSaslMechanism)
})
.collect::<Result<Vec<_>, SaslError>>()?;
debug!("Supported mechanisms {:?}", mechanisms);
let mut session = sasl.start_suggested(&mechanisms).unwrap();
let selected_mechanism = session.get_mechname();
debug!("Using {:?} for the SASL Mechanism", selected_mechanism);
let mut data: Option<Vec<u8>> = None;

// Stepping the authentication exchange to completion.
while {
let mut out = Cursor::new(Vec::new());
// The each call to step writes the generated auth data into the provided writer.
// Normally this data would then have to be sent to the other party, but this goes
// beyond the scope of this example.
let state = session
.step(data.as_deref(), &mut out)
.expect("step errored!");

data = Some(out.into_inner());

// Returns `true` if step needs to be called again with another batch of data.
state.is_running()
} {
let authentication_response = self.sasl_authentication(data.unwrap().to_vec()).await?;
data = Some(authentication_response.auth_bytes.0);
}

Ok(())
}
}
Expand Down
9 changes: 4 additions & 5 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ async fn test_sasl() {
return;
}
ClientBuilder::new(vec![env::var("KAFKA_SASL_CONNECT").unwrap()])
.sasl_config(rskafka::client::SaslConfig::Plain {
username: "admin".to_string(),
password: "admin-secret".to_string(),
})
.sasl_config(rskafka::client::SaslConfig::Plain(
rskafka::client::Credentials::new("admin".to_string(), "admin-secret".to_string()),
))
.build()
.await
.unwrap();
Expand Down Expand Up @@ -425,7 +424,7 @@ async fn test_get_offset() {
// use out-of order timestamps to ensure our "lastest offset" logic works
let record_early = record(b"");
let record_late = Record {
timestamp: record_early.timestamp + chrono::Duration::seconds(1),
timestamp: record_early.timestamp + chrono::Duration::try_seconds(1).unwrap(),
..record_early.clone()
};
let offsets = partition_client
Expand Down
4 changes: 2 additions & 2 deletions tests/produce_consume.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ async fn assert_produce_consume<F1, G1, F2, G2>(

// timestamps for records. We'll reorder the messages though to ts2, ts1, ts3
let ts1 = Utc.timestamp_millis_opt(1337).unwrap();
let ts2 = ts1 + Duration::milliseconds(1);
let ts3 = ts2 + Duration::milliseconds(1);
let ts2 = ts1 + Duration::try_milliseconds(1).unwrap();
let ts3 = ts2 + Duration::try_milliseconds(1).unwrap();

let record_1 = {
let record = Record {
Expand Down

0 comments on commit 2665a03

Please sign in to comment.