Skip to content

Commit

Permalink
refactor(cubesql): Make Postgres authentication extensible
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Sep 14, 2024
1 parent e8d81f2 commit fb9ba98
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 108 deletions.
7 changes: 4 additions & 3 deletions rust/cubesql/cubesql/src/compile/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ use crate::{
},
config::{ConfigObj, ConfigObjImpl},
sql::{
compiler_cache::CompilerCacheImpl, dataframe::batches_to_dataframe, AuthContextRef,
AuthenticateResponse, HttpAuthContext, ServerManager, Session, SessionManager,
SqlAuthService,
compiler_cache::CompilerCacheImpl, dataframe::batches_to_dataframe,
pg_auth_service::PostgresAuthServiceDefaultImpl, AuthContextRef, AuthenticateResponse,
HttpAuthContext, ServerManager, Session, SessionManager, SqlAuthService,
},
transport::{
CubeStreamReceiver, LoadRequestMeta, SpanId, SqlGenerator, SqlResponse, SqlTemplates,
Expand Down Expand Up @@ -607,6 +607,7 @@ async fn get_test_session_with_config_and_transport(
let server = Arc::new(ServerManager::new(
get_test_auth(),
test_transport.clone(),
Arc::new(PostgresAuthServiceDefaultImpl::new()),
Arc::new(CompilerCacheImpl::new(config_obj.clone(), test_transport)),
None,
config_obj,
Expand Down
12 changes: 11 additions & 1 deletion rust/cubesql/cubesql/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use crate::{
injection::{DIService, Injector},
processing_loop::{ProcessingLoop, ShutdownMode},
},
sql::{PostgresServer, ServerManager, SessionManager, SqlAuthDefaultImpl, SqlAuthService},
sql::{
pg_auth_service::{PostgresAuthService, PostgresAuthServiceDefaultImpl},
PostgresServer, ServerManager, SessionManager, SqlAuthDefaultImpl, SqlAuthService,
},
transport::{HttpTransport, TransportService},
CubeError,
};
Expand Down Expand Up @@ -302,6 +305,12 @@ impl Config {
})
.await;

self.injector
.register_typed::<dyn PostgresAuthService, _, _, _>(|_| async move {
Arc::new(PostgresAuthServiceDefaultImpl::new())
})
.await;

Check warning on line 312 in rust/cubesql/cubesql/src/config/mod.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/config/mod.rs#L312

Added line #L312 was not covered by tests

self.injector
.register_typed::<dyn CompilerCache, _, _, _>(|i| async move {
let config = i.get_service_typed::<dyn ConfigObj>().await;
Expand All @@ -319,6 +328,7 @@ impl Config {
i.get_service_typed().await,
i.get_service_typed().await,
i.get_service_typed().await,
i.get_service_typed().await,
config.nonce().clone(),
config.clone(),
))
Expand Down
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/sql/postgres/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub(crate) mod extended;
pub mod pg_auth_service;
pub(crate) mod pg_type;
pub(crate) mod service;
pub(crate) mod shim;
Expand Down
106 changes: 106 additions & 0 deletions rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use std::{collections::HashMap, fmt::Debug, sync::Arc};

use async_trait::async_trait;

use crate::{
sql::{AuthContextRef, SqlAuthService},
CubeError,
};

pub use pg_srv::{
protocol::{AuthenticationRequest, FrontendMessage},
MessageTagParser, MessageTagParserDefaultImpl,
};

#[derive(Debug)]
pub enum AuthenticationStatus {
UnexpectedFrontendMessage,
Failed(String),
// User name + auth context
Success(String, AuthContextRef),
}

#[async_trait]
pub trait PostgresAuthService: Sync + Send + Debug {
fn get_auth_method(&self, parameters: &HashMap<String, String>) -> AuthenticationRequest;

async fn authenticate(
&self,
service: Arc<dyn SqlAuthService>,
request: AuthenticationRequest,
secret: FrontendMessage,
parameters: &HashMap<String, String>,
) -> AuthenticationStatus;

fn get_pg_message_tag_parser(&self) -> Arc<dyn MessageTagParser>;
}

#[derive(Debug)]
pub struct PostgresAuthServiceDefaultImpl {
pg_message_tag_parser: Arc<dyn MessageTagParser>,
}

impl PostgresAuthServiceDefaultImpl {
pub fn new() -> Self {
Self {
pg_message_tag_parser: Arc::new(MessageTagParserDefaultImpl::default()),
}
}
}

#[async_trait]
impl PostgresAuthService for PostgresAuthServiceDefaultImpl {
fn get_auth_method(&self, _: &HashMap<String, String>) -> AuthenticationRequest {
AuthenticationRequest::CleartextPassword
}

async fn authenticate(
&self,
service: Arc<dyn SqlAuthService>,
request: AuthenticationRequest,
secret: FrontendMessage,
parameters: &HashMap<String, String>,
) -> AuthenticationStatus {
let FrontendMessage::PasswordMessage(password_message) = secret else {
return AuthenticationStatus::UnexpectedFrontendMessage;

Check warning on line 65 in rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs#L65

Added line #L65 was not covered by tests
};

if !matches!(request, AuthenticationRequest::CleartextPassword) {
return AuthenticationStatus::UnexpectedFrontendMessage;

Check warning on line 69 in rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs#L69

Added line #L69 was not covered by tests
}

let user = parameters.get("user").unwrap().clone();
let authenticate_response = service
.authenticate(Some(user.clone()), Some(password_message.password.clone()))
.await;

Check warning on line 75 in rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs#L75

Added line #L75 was not covered by tests

let auth_fail = || {
AuthenticationStatus::Failed(format!(
"password authentication failed for user \"{}\"",
user
))
};

Check warning on line 82 in rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs#L78-L82

Added lines #L78 - L82 were not covered by tests

let Ok(authenticate_response) = authenticate_response else {
return auth_fail();

Check warning on line 85 in rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs#L85

Added line #L85 was not covered by tests
};

if !authenticate_response.skip_password_check {
let is_password_correct = match authenticate_response.password {
None => false,

Check warning on line 90 in rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs#L90

Added line #L90 was not covered by tests
Some(password) => password == password_message.password,
};
if !is_password_correct {
return auth_fail();

Check warning on line 94 in rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs#L94

Added line #L94 was not covered by tests
}
}

Check warning on line 96 in rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs#L96

Added line #L96 was not covered by tests

AuthenticationStatus::Success(user, authenticate_response.context)
}

fn get_pg_message_tag_parser(&self) -> Arc<dyn MessageTagParser> {
Arc::clone(&self.pg_message_tag_parser)
}
}

crate::di_service!(PostgresAuthServiceDefaultImpl, [PostgresAuthService]);
131 changes: 60 additions & 71 deletions rust/cubesql/cubesql/src/sql/postgres/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
time::SystemTime,
};

use super::extended::PreparedStatement;
use super::{extended::PreparedStatement, pg_auth_service::AuthenticationStatus};
use crate::{
compile::{
convert_statement_to_cube_query,
Expand All @@ -24,8 +24,11 @@ use crate::{
use futures::{pin_mut, FutureExt, StreamExt};
use log::{debug, error, trace};
use pg_srv::{
buffer, protocol,
protocol::{ErrorCode, ErrorResponse, Format, InitialMessage, PortalCompletion},
buffer,
protocol::{
self, AuthenticationRequest, ErrorCode, ErrorResponse, Format, InitialMessage,
PortalCompletion,
},
PgType, PgTypeId, ProtocolError,
};
use sqlparser::ast::{self, CloseCursor, FetchDirection, Query, SetExpr, Statement, Value};
Expand All @@ -46,10 +49,9 @@ pub struct AsyncPostgresShim {
logger: Arc<dyn ContextLogger>,
}

#[derive(PartialEq, Eq)]
pub enum StartupState {
// Initial parameters which client sends in the first message, we use it later in auth method
Success(HashMap<String, String>),
Success(HashMap<String, String>, AuthenticationRequest),
SslRequested,
Denied,
CancelRequest,
Expand Down Expand Up @@ -313,25 +315,23 @@ impl AsyncPostgresShim {
}

pub async fn run(&mut self) -> Result<(), ConnectionError> {
let initial_parameters = match self.process_initial_message().await? {
StartupState::Success(parameters) => parameters,
let (initial_parameters, auth_method) = match self.process_initial_message().await? {
StartupState::Success(parameters, auth_method) => (parameters, auth_method),
StartupState::SslRequested => match self.process_initial_message().await? {
StartupState::Success(parameters) => parameters,
StartupState::Success(parameters, auth_method) => (parameters, auth_method),

Check warning on line 321 in rust/cubesql/cubesql/src/sql/postgres/shim.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/shim.rs#L321

Added line #L321 was not covered by tests
_ => return Ok(()),
},
StartupState::Denied | StartupState::CancelRequest => return Ok(()),
};

match buffer::read_message(&mut self.socket).await? {
protocol::FrontendMessage::PasswordMessage(password_message) => {
if !self
.authenticate(password_message, initial_parameters)
.await?
{
return Ok(());
}
}
_ => return Ok(()),
let message_tag_parser = self.session.server.pg_auth.get_pg_message_tag_parser();
let auth_secret =
buffer::read_message(&mut self.socket, Arc::clone(&message_tag_parser)).await?;
if !self
.authenticate(auth_method, auth_secret, initial_parameters)
.await?

Check warning on line 332 in rust/cubesql/cubesql/src/sql/postgres/shim.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/shim.rs#L332

Added line #L332 was not covered by tests
{
return Ok(());

Check warning on line 334 in rust/cubesql/cubesql/src/sql/postgres/shim.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/shim.rs#L334

Added line #L334 was not covered by tests
}

self.ready().await?;
Expand All @@ -351,7 +351,7 @@ impl AsyncPostgresShim {
true = async { semifast_shutdownable && { semifast_shutdown_interruptor.cancelled().await; true } } => {
return Self::flush_and_write_admin_shutdown_fatal_message(self).await;
}
message_result = buffer::read_message(&mut self.socket) => message_result?
message_result = buffer::read_message(&mut self.socket, Arc::clone(&message_tag_parser)) => message_result?
};

let result = match message {
Expand Down Expand Up @@ -716,73 +716,62 @@ impl AsyncPostgresShim {
return Ok(StartupState::Denied);
}

self.write(protocol::Authentication::new(
protocol::AuthenticationRequest::CleartextPassword,
))
.await?;
let auth_method = self.session.server.pg_auth.get_auth_method(&parameters);
self.write(protocol::Authentication::new(auth_method.clone()))
.await?;

Check warning on line 721 in rust/cubesql/cubesql/src/sql/postgres/shim.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/shim.rs#L721

Added line #L721 was not covered by tests

Ok(StartupState::Success(parameters))
Ok(StartupState::Success(parameters, auth_method))
}

pub async fn authenticate(
&mut self,
password_message: protocol::PasswordMessage,
auth_request: AuthenticationRequest,
auth_secret: protocol::FrontendMessage,
parameters: HashMap<String, String>,
) -> Result<bool, ConnectionError> {
let user = parameters.get("user").unwrap().clone();
let authenticate_response = self
let auth_service = self.session.server.auth.clone();
let auth_status = self
.session
.server
.auth
.authenticate(Some(user.clone()), Some(password_message.password.clone()))
.pg_auth
.authenticate(auth_service, auth_request, auth_secret, &parameters)
.await;
let result = match auth_status {
AuthenticationStatus::UnexpectedFrontendMessage => Err((
"invalid authorization specification".to_string(),
protocol::ErrorCode::InvalidAuthorizationSpecification,
)),
AuthenticationStatus::Failed(err) => Err((err, protocol::ErrorCode::InvalidPassword)),

Check warning on line 744 in rust/cubesql/cubesql/src/sql/postgres/shim.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/shim.rs#L740-L744

Added lines #L740 - L744 were not covered by tests
AuthenticationStatus::Success(user, auth_context) => Ok((user, auth_context)),
};

let mut auth_context: Option<AuthContextRef> = None;
match result {
Err((message, code)) => {
let error_response = protocol::ErrorResponse::fatal(code, message);
buffer::write_message(
&mut self.partial_write_buf,
&mut self.socket,
error_response,
)
.await?;

Check warning on line 756 in rust/cubesql/cubesql/src/sql/postgres/shim.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/shim.rs#L749-L756

Added lines #L749 - L756 were not covered by tests

let auth_success = match authenticate_response {
Ok(authenticate_response) => {
auth_context = Some(authenticate_response.context);
if !authenticate_response.skip_password_check {
match authenticate_response.password {
None => false,
Some(password) => password == password_message.password,
}
} else {
true
}
Ok(false)

Check warning on line 758 in rust/cubesql/cubesql/src/sql/postgres/shim.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/shim.rs#L758

Added line #L758 was not covered by tests
}
_ => false,
};

if !auth_success {
let error_response = protocol::ErrorResponse::fatal(
protocol::ErrorCode::InvalidPassword,
format!("password authentication failed for user \"{}\"", &user),
);
buffer::write_message(
&mut self.partial_write_buf,
&mut self.socket,
error_response,
)
.await?;
Ok((user, auth_context)) => {
let database = parameters
.get("database")
.map(|v| v.clone())
.unwrap_or("db".to_string());
self.session.state.set_database(Some(database));
self.session.state.set_user(Some(user));
self.session.state.set_auth_context(Some(auth_context));

self.write(protocol::Authentication::new(AuthenticationRequest::Ok))
.await?;

Check warning on line 770 in rust/cubesql/cubesql/src/sql/postgres/shim.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/postgres/shim.rs#L770

Added line #L770 was not covered by tests

return Ok(false);
Ok(true)
}
}

let database = parameters
.get("database")
.map(|v| v.clone())
.unwrap_or("db".to_string());
self.session.state.set_database(Some(database));
self.session.state.set_user(Some(user));
self.session.state.set_auth_context(auth_context);

self.write(protocol::Authentication::new(
protocol::AuthenticationRequest::Ok,
))
.await?;

Ok(true)
}

pub async fn ready(&mut self) -> Result<(), ConnectionError> {
Expand Down
4 changes: 4 additions & 0 deletions rust/cubesql/cubesql/src/sql/server_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
sql::{
compiler_cache::CompilerCache,
database_variables::{mysql_default_global_variables, postgres_default_global_variables},
pg_auth_service::PostgresAuthService,
SqlAuthService,
},
transport::TransportService,
Expand Down Expand Up @@ -37,6 +38,7 @@ pub struct ServerManager {
// References to shared things
pub auth: Arc<dyn SqlAuthService>,
pub transport: Arc<dyn TransportService>,
pub pg_auth: Arc<dyn PostgresAuthService>,
// Non references
pub configuration: ServerConfiguration,
pub nonce: Option<Vec<u8>>,
Expand All @@ -52,13 +54,15 @@ impl ServerManager {
pub fn new(
auth: Arc<dyn SqlAuthService>,
transport: Arc<dyn TransportService>,
pg_auth: Arc<dyn PostgresAuthService>,
compiler_cache: Arc<dyn CompilerCache>,
nonce: Option<Vec<u8>>,
config_obj: Arc<dyn ConfigObj>,
) -> Self {
Self {
auth,
transport,
pg_auth,
compiler_cache,
nonce,
config_obj,
Expand Down
Loading

0 comments on commit fb9ba98

Please sign in to comment.