Skip to content

Commit

Permalink
chore(cubesql): SessionManager - support extra_id for Session (cube-j…
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr authored Aug 22, 2024
1 parent 0f6701f commit 861f13e
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 77 deletions.
4 changes: 2 additions & 2 deletions packages/cubejs-backend-native/src/node_export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ async fn handle_sql_query(
};

let session = session_manager
.create_session(DatabaseProtocol::PostgreSQL, host, port)
.await;
.create_session(DatabaseProtocol::PostgreSQL, host, port, None)
.await?;

session
.state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl TableProvider for InfoSchemaProcesslistProvider {
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
let mut builder = InformationSchemaProcesslistBuilder::new();

for process_list in self.sessions.process_list().await {
for process_list in self.sessions.map_sessions::<SessionProcessList>().await {
builder.add_row(process_list);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ impl PgStatActivityBuilder {
self.oid.append_value(session.oid).unwrap();
self.datname.append_option(session.datname).unwrap();
self.pid.append_value(session.pid).unwrap();
self.leader_pid.append_null().unwrap();
self.usesysid.append_null().unwrap();
self.leader_pid.append_option(session.leader_pid).unwrap();
self.usesysid.append_option(session.usesysid).unwrap();
self.usename.append_option(session.usename).unwrap();
self.application_name
.append_option(session.application_name)
Expand Down Expand Up @@ -205,7 +205,7 @@ impl TableProvider for PgCatalogStatActivityProvider {
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
let sessions = self.sessions.stat_activity().await;
let sessions = self.sessions.map_sessions::<SessionStatActivity>().await;
let mut builder = PgStatActivityBuilder::new(sessions.len());

for session in sessions {
Expand Down
5 changes: 3 additions & 2 deletions rust/cubesql/cubesql/src/compile/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,9 @@ async fn get_test_session_with_config_and_transport(
};
let session_manager = Arc::new(SessionManager::new(server.clone()));
let session = session_manager
.create_session(protocol, "127.0.0.1".to_string(), 1234)
.await;
.create_session(protocol, "127.0.0.1".to_string(), 1234, None)
.await
.unwrap();

// Populate like shims
session.state.set_database(Some(db_name.to_string()));
Expand Down
19 changes: 13 additions & 6 deletions rust/cubesql/cubesql/src/sql/postgres/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use tokio::{
};
use tokio_util::sync::CancellationToken;

use super::shim::AsyncPostgresShim;
use crate::{
compile::DatabaseProtocol,
config::processing_loop::{ProcessingLoop, ShutdownMode},
Expand All @@ -15,8 +16,6 @@ use crate::{
CubeError,
};

use super::shim::AsyncPostgresShim;

pub struct PostgresServer {
// options
address: String,
Expand Down Expand Up @@ -98,10 +97,18 @@ impl ProcessingLoop for PostgresServer {
}
};

let session = self
let session = match self
.session_manager
.create_session(DatabaseProtocol::PostgreSQL, client_addr, client_port)
.await;
.create_session(DatabaseProtocol::PostgreSQL, client_addr, client_port, None)
.await
{
Ok(r) => r,
Err(err) => {
error!("Session creation error: {}", err);
continue;
}
};

let logger = Arc::new(SessionLogger::new(session.state.clone()));

trace!("[pg] New connection {}", session.state.connection_id);
Expand Down Expand Up @@ -147,7 +154,7 @@ impl ProcessingLoop for PostgresServer {

// Close the listening socket (so we _visibly_ stop accepting incoming connections) before
// we wait for the outstanding connection tasks finish.
std::mem::drop(listener);
drop(listener);

// Now that we've had the stop signal, wait for outstanding connection tasks to finish
// cleanly.
Expand Down
88 changes: 47 additions & 41 deletions rust/cubesql/cubesql/src/sql/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ pub enum QueryState {
pub struct SessionState {
// connection id, immutable
pub connection_id: u32,
// Can be UUID or anything else. MDX uses UUID
pub extra_id: Option<String>,
// secret for this session
pub secret: u32,
// client ip, immutable
Expand Down Expand Up @@ -95,6 +97,7 @@ pub struct SessionState {
impl SessionState {
pub fn new(
connection_id: u32,
extra_id: Option<String>,
client_ip: String,
client_port: u16,
protocol: DatabaseProtocol,
Expand All @@ -106,6 +109,7 @@ impl SessionState {

Self {
connection_id,
extra_id,
secret: rng.gen(),
client_ip,
client_port,
Expand Down Expand Up @@ -399,46 +403,7 @@ pub struct Session {
pub state: Arc<SessionState>,
}

impl Session {
// For PostgreSQL
pub fn to_stat_activity(self: &Arc<Self>) -> SessionStatActivity {
let query = self.state.current_query();

let application_name = if let Some(v) = self.state.get_variable("application_name") {
match v.value {
ScalarValue::Utf8(r) => r,
_ => None,
}
} else {
None
};

SessionStatActivity {
oid: self.state.connection_id,
datname: self.state.database(),
pid: self.state.connection_id,
leader_pid: None,
usesysid: 0,
usename: self.state.user(),
application_name,
client_addr: self.state.client_ip.clone(),
client_hostname: None,
client_port: self.state.client_port.clone(),
query,
}
}

// For MySQL
pub fn to_process_list(self: &Arc<Self>) -> SessionProcessList {
SessionProcessList {
id: self.state.connection_id,
host: self.state.client_ip.clone(),
user: self.state.user(),
database: self.state.database(),
}
}
}

/// Specific representation of session for MySQL
#[derive(Debug)]
pub struct SessionProcessList {
pub id: u32,
Expand All @@ -447,17 +412,58 @@ pub struct SessionProcessList {
pub database: Option<String>,
}

impl From<&Arc<Session>> for SessionProcessList {
fn from(session: &Arc<Session>) -> Self {
Self {
id: session.state.connection_id,
host: session.state.client_ip.clone(),
user: session.state.user(),
database: session.state.database(),
}
}
}

/// Specific representation of session for PostgreSQL
#[derive(Debug)]
pub struct SessionStatActivity {
pub oid: u32,
pub datname: Option<String>,
pub pid: u32,
pub leader_pid: Option<u32>,
pub usesysid: u32,
pub usesysid: Option<u32>,
pub usename: Option<String>,
pub application_name: Option<String>,
pub client_addr: String,
pub client_hostname: Option<String>,
pub client_port: u16,
pub query: Option<String>,
}

impl From<&Arc<Session>> for SessionStatActivity {
fn from(session: &Arc<Session>) -> Self {
let query = session.state.current_query();

let application_name = if let Some(v) = session.state.get_variable("application_name") {
match v.value {
ScalarValue::Utf8(r) => r,
_ => None,
}
} else {
None
};

Self {
oid: session.state.connection_id,
datname: session.state.database(),
pid: session.state.connection_id,
leader_pid: None,
usesysid: None,
usename: session.state.user(),
application_name,
client_addr: session.state.client_ip.clone(),
client_hostname: None,
client_port: session.state.client_port.clone(),
query,
}
}
}
65 changes: 43 additions & 22 deletions rust/cubesql/cubesql/src/sql/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@ use std::{

use super::{
server_manager::ServerManager,
session::{Session, SessionProcessList, SessionStatActivity, SessionState},
session::{Session, SessionState},
};
use crate::compile::DatabaseProtocol;

#[derive(Debug)]
struct SessionManagerInner {
sessions: HashMap<u32, Arc<Session>>,
uid_to_session: HashMap<String, Arc<Session>>,
}

#[derive(Debug)]
pub struct SessionManager {
// Sessions
last_id: AtomicU32,
sessions: RWLockAsync<HashMap<u32, Arc<Session>>>,
sessions: RWLockAsync<SessionManagerInner>,
pub temp_table_size: AtomicUsize,
// Backref
pub server: Arc<ServerManager>,
Expand All @@ -30,7 +36,10 @@ impl SessionManager {
pub fn new(server: Arc<ServerManager>) -> Self {
Self {
last_id: AtomicU32::new(1),
sessions: RWLockAsync::new(HashMap::new()),
sessions: RWLockAsync::new(SessionManagerInner {
sessions: HashMap::new(),
uid_to_session: HashMap::new(),
}),
temp_table_size: AtomicUsize::new(0),
server,
}
Expand All @@ -41,60 +50,72 @@ impl SessionManager {
protocol: DatabaseProtocol,
client_addr: String,
client_port: u16,
) -> Arc<Session> {
extra_id: Option<String>,
) -> Result<Arc<Session>, CubeError> {
let connection_id = self.last_id.fetch_add(1, Ordering::SeqCst);

let sess = Session {
let session_ref = Arc::new(Session {
session_manager: self.clone(),
server: self.server.clone(),
state: Arc::new(SessionState::new(
connection_id,
extra_id.clone(),
client_addr,
client_port,
protocol,
None,
Duration::from_secs(self.server.config_obj.auth_expire_secs()),
Arc::downgrade(self),
)),
};

let session_ref = Arc::new(sess);
});

let mut guard = self.sessions.write().await;

guard.insert(connection_id, session_ref.clone());
if let Some(extra_id) = extra_id {
if guard.uid_to_session.contains_key(&extra_id) {
return Err(CubeError::user(format!(
"Session cannot be created, because extra_id: {} already exists",
extra_id
)));
}

session_ref
}
guard.uid_to_session.insert(extra_id, session_ref.clone());
}

pub async fn stat_activity(self: &Arc<Self>) -> Vec<SessionStatActivity> {
let guard = self.sessions.read().await;
guard.sessions.insert(connection_id, session_ref.clone());

guard
.values()
.map(Session::to_stat_activity)
.collect::<Vec<SessionStatActivity>>()
Ok(session_ref)
}

pub async fn process_list(self: &Arc<Self>) -> Vec<SessionProcessList> {
pub async fn map_sessions<T: for<'a> From<&'a Arc<Session>>>(self: &Arc<Self>) -> Vec<T> {
let guard = self.sessions.read().await;

guard
.sessions
.values()
.map(Session::to_process_list)
.collect::<Vec<SessionProcessList>>()
.map(|session| T::from(session))
.collect::<Vec<T>>()
}

pub async fn get_session(&self, connection_id: u32) -> Option<Arc<Session>> {
let guard = self.sessions.read().await;

guard.get(&connection_id).map(|s| s.clone())
guard.sessions.get(&connection_id).map(|s| s.clone())
}

pub async fn get_session_by_extra_id(&self, extra_id: String) -> Option<Arc<Session>> {
let guard = self.sessions.read().await;
guard.uid_to_session.get(&extra_id).map(|s| s.clone())
}

pub async fn drop_session(&self, connection_id: u32) {
let mut guard = self.sessions.write().await;

if let Some(connection) = guard.remove(&connection_id) {
if let Some(connection) = guard.sessions.remove(&connection_id) {
if let Some(extra_id) = &connection.state.extra_id {
guard.uid_to_session.remove(extra_id);
}

self.temp_table_size.fetch_sub(
connection.state.temp_tables().physical_size(),
Ordering::SeqCst,
Expand Down

0 comments on commit 861f13e

Please sign in to comment.