Skip to content

Commit

Permalink
feat(cubesql): Introduce max sessions limit
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr committed Aug 27, 2024
1 parent 03277b0 commit 801e7f3
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
9 changes: 9 additions & 0 deletions rust/cubesql/cubesql/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ pub trait ConfigObj: DIService + Debug {

fn non_streaming_query_max_row_limit(&self) -> i32;

fn max_sessions(&self) -> usize;

fn no_implicit_order(&self) -> bool;
}

Expand All @@ -128,6 +130,7 @@ pub struct ConfigObjImpl {
pub push_down_pull_up_split: bool,
pub stream_mode: bool,
pub non_streaming_query_max_row_limit: i32,
pub max_sessions: usize,
pub no_implicit_order: bool,
}

Expand Down Expand Up @@ -164,6 +167,7 @@ impl ConfigObjImpl {
.unwrap_or(sql_push_down),
stream_mode: env_parse("CUBESQL_STREAM_MODE", false),
non_streaming_query_max_row_limit: env_parse("CUBEJS_DB_QUERY_LIMIT", 50000),
max_sessions: env_parse("CUBEJS_MAX_SESSIONS", 1024),
no_implicit_order: env_parse("CUBESQL_SQL_NO_IMPLICIT_ORDER", false),
}
}
Expand Down Expand Up @@ -227,6 +231,10 @@ impl ConfigObj for ConfigObjImpl {
fn no_implicit_order(&self) -> bool {
self.no_implicit_order
}

fn max_sessions(&self) -> usize {
self.max_sessions
}
}

lazy_static! {
Expand Down Expand Up @@ -262,6 +270,7 @@ impl Config {
push_down_pull_up_split: true,
stream_mode: false,
non_streaming_query_max_row_limit: 50000,
max_sessions: 1024,
no_implicit_order: false,
}),
}
Expand Down
14 changes: 8 additions & 6 deletions rust/cubesql/cubesql/src/sql/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct SessionState {
// connection id, immutable
pub connection_id: u32,
// Can be UUID or anything else. MDX uses UUID
pub extra_id: Option<String>,
pub extra_id: Option<SessionExtraId>,
// secret for this session
pub secret: u32,
// client ip, immutable
Expand Down Expand Up @@ -97,7 +97,7 @@ pub struct SessionState {
impl SessionState {
pub fn new(
connection_id: u32,
extra_id: Option<String>,
extra_id: Option<SessionExtraId>,
client_ip: String,
client_port: u16,
protocol: DatabaseProtocol,
Expand Down Expand Up @@ -394,6 +394,8 @@ impl SessionState {
}
}

pub type SessionExtraId = [u8; 16];

#[derive(Debug)]
pub struct Session {
// Backref
Expand All @@ -412,8 +414,8 @@ pub struct SessionProcessList {
pub database: Option<String>,
}

impl From<&Arc<Session>> for SessionProcessList {
fn from(session: &Arc<Session>) -> Self {
impl From<&Session> for SessionProcessList {
fn from(session: &Session) -> Self {
Self {
id: session.state.connection_id,
host: session.state.client_ip.clone(),
Expand All @@ -439,8 +441,8 @@ pub struct SessionStatActivity {
pub query: Option<String>,
}

impl From<&Arc<Session>> for SessionStatActivity {
fn from(session: &Arc<Session>) -> Self {
impl From<&Session> for SessionStatActivity {
fn from(session: &Session) -> Self {
let query = session.state.current_query();

let application_name = if let Some(v) = session.state.get_variable("application_name") {
Expand Down
19 changes: 13 additions & 6 deletions rust/cubesql/cubesql/src/sql/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ use super::{
server_manager::ServerManager,
session::{Session, SessionState},
};
use crate::compile::DatabaseProtocol;
use crate::{compile::DatabaseProtocol, sql::session::SessionExtraId};

#[derive(Debug)]
struct SessionManagerInner {
sessions: HashMap<u32, Arc<Session>>,
uid_to_session: HashMap<String, Arc<Session>>,
uid_to_session: HashMap<SessionExtraId, Arc<Session>>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -50,7 +50,7 @@ impl SessionManager {
protocol: DatabaseProtocol,
client_addr: String,
client_port: u16,
extra_id: Option<String>,
extra_id: Option<SessionExtraId>,
) -> Result<Arc<Session>, CubeError> {
let connection_id = self.last_id.fetch_add(1, Ordering::SeqCst);

Expand All @@ -71,10 +71,17 @@ impl SessionManager {

let mut guard = self.sessions.write().await;

if guard.sessions.len() > self.server.config_obj.max_sessions() {
return Err(CubeError::user(format!(
"Too many sessions, limit reached: {}",
self.server.config_obj.max_sessions()
)));
}

if let Some(extra_id) = extra_id {
if guard.uid_to_session.contains_key(&extra_id) {
return Err(CubeError::user(format!(
"Session cannot be created, because extra_id: {} already exists",
"Session cannot be created, because extra_id: {:?} already exists",
extra_id
)));
}
Expand All @@ -87,7 +94,7 @@ impl SessionManager {
Ok(session_ref)
}

pub async fn map_sessions<T: for<'a> From<&'a Arc<Session>>>(self: &Arc<Self>) -> Vec<T> {
pub async fn map_sessions<T: for<'a> From<&'a Session>>(self: &Arc<Self>) -> Vec<T> {
let guard = self.sessions.read().await;

guard
Expand All @@ -103,7 +110,7 @@ impl SessionManager {
guard.sessions.get(&connection_id).map(|s| s.clone())
}

pub async fn get_session_by_extra_id(&self, extra_id: String) -> Option<Arc<Session>> {
pub async fn get_session_by_extra_id(&self, extra_id: SessionExtraId) -> Option<Arc<Session>> {
let guard = self.sessions.read().await;
guard.uid_to_session.get(&extra_id).map(|s| s.clone())
}
Expand Down

0 comments on commit 801e7f3

Please sign in to comment.