Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support SASL SCRAM mechanism #247

Merged
merged 3 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ tokio = { version = "1.19", default-features = false, features = ["io-util", "ne
tokio-rustls = { version = "0.26", optional = true, default-features = false, features = ["logging", "ring", "tls12"] }
tracing = "0.1"
zstd = { version = "0.13", optional = true }
rsasl = { version = "2.1", default-features = false, features = ["config_builder", "provider", "plain", "scram-sha-2"]}

[dev-dependencies]
assert_matches = "1.5"
Expand Down
2 changes: 1 addition & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use error::{Error, Result};

use self::{controller::ControllerClient, partition::UnknownTopicHandling};

pub use crate::connection::SaslConfig;
pub use crate::connection::{Credentials, SaslConfig};

#[derive(Debug, Error)]
pub enum ProduceError {
Expand Down
5 changes: 2 additions & 3 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::{
client::metadata_cache::MetadataCache,
};

pub use self::transport::Credentials;
pub use self::transport::SaslConfig;
pub use self::transport::TlsConfig;

Expand Down Expand Up @@ -164,9 +165,7 @@ impl ConnectionHandler for BrokerRepresentation {
let mut messenger = Messenger::new(BufStream::new(transport), max_message_size, client_id);
messenger.sync_versions().await?;
if let Some(sasl_config) = sasl_config {
messenger
.sasl_handshake(sasl_config.mechanism(), sasl_config.auth_bytes())
.await?;
messenger.do_sasl(sasl_config).await?;
}
Ok(Arc::new(messenger))
}
Expand Down
2 changes: 1 addition & 1 deletion src/connection/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};

mod sasl;
pub use sasl::SaslConfig;
pub use sasl::{Credentials, SaslConfig};

#[cfg(feature = "transport-tls")]
pub type TlsConfig = Option<Arc<rustls::ClientConfig>>;
Expand Down
38 changes: 29 additions & 9 deletions src/connection/transport/sasl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,45 @@ pub enum SaslConfig {
///
/// # References
/// - <https://datatracker.ietf.org/doc/html/rfc4616>
Plain { username: String, password: String },
Plain(Credentials),
/// SASL - SCRAM-SHA-256
///
/// # References
/// - <https://datatracker.ietf.org/doc/html/rfc7677>
ScramSha256(Credentials),
/// SASL - SCRAM-SHA-512
///
/// # References
/// - <https://datatracker.ietf.org/doc/html/draft-melnikov-scram-sha-512-04>
ScramSha512(Credentials),
}

#[derive(Debug, Clone)]
pub struct Credentials {
pub username: String,
pub password: String,
}

impl Credentials {
pub fn new(username: String, password: String) -> Self {
Self { username, password }
}
}

impl SaslConfig {
pub(crate) fn auth_bytes(&self) -> Vec<u8> {
pub(crate) fn credentials(&self) -> Credentials {
match self {
Self::Plain { username, password } => {
let mut auth: Vec<u8> = vec![0];
auth.extend(username.bytes());
auth.push(0);
auth.extend(password.bytes());
auth
}
Self::Plain(credentials) => credentials.clone(),
Self::ScramSha256(credentials) => credentials.clone(),
Self::ScramSha512(credentials) => credentials.clone(),
}
}

pub(crate) fn mechanism(&self) -> &str {
match self {
Self::Plain { .. } => "PLAIN",
Self::ScramSha256 { .. } => "SCRAM-SHA-256",
Self::ScramSha512 { .. } => "SCRAM-SHA-512",
}
}
}
85 changes: 74 additions & 11 deletions src/messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ use std::{

use futures::future::BoxFuture;
use parking_lot::Mutex;
use rsasl::{
config::SASLConfig,
mechname::MechanismNameError,
prelude::{Mechname, SessionError},
};
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
Expand All @@ -23,8 +28,6 @@ use tokio::{
};
use tracing::{debug, info, warn};

use crate::protocol::{api_version::ApiVersionRange, primitives::CompactString};
use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType};
use crate::{
backoff::ErrorOrThrottle,
protocol::{
Expand All @@ -34,12 +37,21 @@ use crate::{
frame::{AsyncMessageRead, AsyncMessageWrite},
messages::{
ReadVersionedError, ReadVersionedType, RequestBody, RequestHeader, ResponseHeader,
SaslAuthenticateRequest, SaslHandshakeRequest, WriteVersionedError, WriteVersionedType,
SaslAuthenticateRequest, SaslAuthenticateResponse, SaslHandshakeRequest,
SaslHandshakeResponse, WriteVersionedError, WriteVersionedType,
},
primitives::{Int16, Int32, NullableString, TaggedFields},
},
throttle::maybe_throttle,
};
use crate::{
client::SaslConfig,
protocol::{api_version::ApiVersionRange, primitives::CompactString},
};
use crate::{
connection::Credentials,
protocol::{messages::ApiVersionsRequest, traits::ReadType},
};

#[derive(Debug)]
struct Response {
Expand Down Expand Up @@ -186,6 +198,15 @@ pub enum SaslError {

#[error("API error: {0}")]
ApiError(#[from] ApiError),

#[error("Invalid sasl mechanism: {0}")]
InvalidSaslMechanism(#[from] MechanismNameError),

#[error("Sasl session error: {0}")]
SaslSessionError(#[from] SessionError),

#[error("unsupported sasl mechanism")]
UnsupportedSaslMechanism,
}

impl<RW> Messenger<RW>
Expand Down Expand Up @@ -531,16 +552,10 @@ where
Err(SyncVersionsError::NoWorkingVersion)
}

pub async fn sasl_handshake(
async fn sasl_authentication(
&self,
mechanism: &str,
auth_bytes: Vec<u8>,
) -> Result<(), SaslError> {
let req = SaslHandshakeRequest::new(mechanism);
let resp = self.request(req).await?;
if let Some(err) = resp.error_code {
return Err(SaslError::ApiError(err));
}
) -> Result<SaslAuthenticateResponse, SaslError> {
let req = SaslAuthenticateRequest::new(auth_bytes);
let resp = self.request(req).await?;
if let Some(err) = resp.error_code {
Expand All @@ -549,6 +564,54 @@ where
}
return Err(SaslError::ApiError(err));
}

Ok(resp)
}

async fn sasl_handshake(&self, mechanism: &str) -> Result<SaslHandshakeResponse, SaslError> {
let req = SaslHandshakeRequest::new(mechanism);
let resp = self.request(req).await?;
if let Some(err) = resp.error_code {
return Err(SaslError::ApiError(err));
}
Ok(resp)
}

pub async fn do_sasl(&self, config: SaslConfig) -> Result<(), SaslError> {
let mechanism = config.mechanism();
let resp = self.sasl_handshake(mechanism).await?;

let Credentials { username, password } = config.credentials();
let config = SASLConfig::with_credentials(None, username, password).unwrap();
let sasl = rsasl::prelude::SASLClient::new(config);
let raw_mechanisms = resp.mechanisms.0.unwrap_or_default();
let mechanisms = raw_mechanisms
.iter()
.map(|mech| Mechname::parse(mech.0.as_bytes()).map_err(SaslError::InvalidSaslMechanism))
.collect::<Result<Vec<_>, SaslError>>()?;
debug!(?mechanisms, "Supported SASL mechanisms");
let prefer_mechanism =
Mechname::parse(mechanism.as_bytes()).map_err(SaslError::InvalidSaslMechanism)?;
if !mechanisms.contains(&prefer_mechanism) {
return Err(SaslError::UnsupportedSaslMechanism);
}
let mut session = sasl
.start_suggested(&[prefer_mechanism])
.map_err(|_| SaslError::UnsupportedSaslMechanism)?;
debug!(?mechanism, "Using SASL Mechanism");
// we step through the auth process, starting on our side with NO data received so far
let mut data_received: Option<Vec<u8>> = None;
loop {
let mut to_sent = Cursor::new(Vec::new());
let state = session.step(data_received.as_deref(), &mut to_sent)?;
if !state.is_running() {
break;
}

let authentication_response = self.sasl_authentication(to_sent.into_inner()).await?;
data_received = Some(authentication_response.auth_bytes.0);
}

Ok(())
}
}
Expand Down
9 changes: 4 additions & 5 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ async fn test_sasl() {
return;
}
ClientBuilder::new(vec![env::var("KAFKA_SASL_CONNECT").unwrap()])
.sasl_config(rskafka::client::SaslConfig::Plain {
username: "admin".to_string(),
password: "admin-secret".to_string(),
})
.sasl_config(rskafka::client::SaslConfig::Plain(
rskafka::client::Credentials::new("admin".to_string(), "admin-secret".to_string()),
))
.build()
.await
.unwrap();
Expand Down Expand Up @@ -425,7 +424,7 @@ async fn test_get_offset() {
// use out-of order timestamps to ensure our "lastest offset" logic works
let record_early = record(b"");
let record_late = Record {
timestamp: record_early.timestamp + chrono::Duration::seconds(1),
timestamp: record_early.timestamp + chrono::Duration::try_seconds(1).unwrap(),
..record_early.clone()
};
let offsets = partition_client
Expand Down
4 changes: 2 additions & 2 deletions tests/produce_consume.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ async fn assert_produce_consume<F1, G1, F2, G2>(

// timestamps for records. We'll reorder the messages though to ts2, ts1, ts3
let ts1 = Utc.timestamp_millis_opt(1337).unwrap();
let ts2 = ts1 + Duration::milliseconds(1);
let ts3 = ts2 + Duration::milliseconds(1);
let ts2 = ts1 + Duration::try_milliseconds(1).unwrap();
let ts3 = ts2 + Duration::try_milliseconds(1).unwrap();

let record_1 = {
let record = Record {
Expand Down