Skip to content

Commit

Permalink
feat: sample use of custom connection (working)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShubhranshuSanjeev committed Feb 2, 2025
1 parent 7a53eea commit 007698a
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 33 deletions.
19 changes: 6 additions & 13 deletions crates/context_aware_config/src/api/default_config/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use diesel::{Connection, ExpressionMethods, QueryDsl, RunQueryDsl, SelectableHel
use jsonschema::{Draft, JSONSchema, ValidationError};
use serde_json::Value;
use service_utils::{
db::types::ConnectionImpl,
helpers::{parse_config_tags, validation_err_to_str},
service::types::{AppHeader, AppState, CustomHeaders, DbConnection, SchemaName},
};
Expand Down Expand Up @@ -317,38 +318,30 @@ fn fetch_default_key(

#[get("")]
async fn get(
db_conn: DbConnection,
mut db_conn: ConnectionImpl,
filters: Query<PaginationParams>,
schema_name: SchemaName,
) -> superposition::Result<Json<PaginatedResponse<DefaultConfig>>> {
let DbConnection(mut conn) = db_conn;

if let Some(true) = filters.all {
let result: Vec<DefaultConfig> = dsl::default_configs
.schema_name(&schema_name)
.get_results(&mut conn)?;
let result: Vec<DefaultConfig> =
dsl::default_configs.get_results(&mut db_conn)?;
return Ok(Json(PaginatedResponse {
total_pages: 1,
total_items: result.len() as i64,
data: result,
}));
}

let n_default_configs: i64 = dsl::default_configs
.count()
.schema_name(&schema_name)
.get_result(&mut conn)?;
let n_default_configs: i64 = dsl::default_configs.count().get_result(&mut db_conn)?;
let limit = filters.count.unwrap_or(10);
let mut builder = dsl::default_configs
.order(dsl::created_at.desc())
.limit(limit)
.schema_name(&schema_name)
.into_boxed();
if let Some(page) = filters.page {
let offset = (page - 1) * limit;
builder = builder.offset(offset);
}
let result: Vec<DefaultConfig> = builder.load(&mut conn)?;
let result: Vec<DefaultConfig> = builder.load(&mut db_conn)?;
let total_pages = (n_default_configs as f64 / limit as f64).ceil() as i64;
Ok(Json(PaginatedResponse {
total_pages,
Expand Down
111 changes: 94 additions & 17 deletions crates/service_utils/src/db/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use actix_web::web::Data;
use actix_web::{FromRequest, HttpMessage};
use derive_more::{Deref, DerefMut};
use diesel::connection::{
AnsiTransactionManager, Connection, ConnectionSealed, DefaultLoadingMode,
LoadConnection, SimpleConnection, TransactionManager,
Expand All @@ -7,14 +10,16 @@ use diesel::r2d2::{ConnectionManager, PooledConnection};
use diesel::PgConnection;
use diesel::RunQueryDsl;

pub struct SuperTransactionManager;
use crate::service::types::{AppState, SchemaName};

impl TransactionManager<SuperConnection> for SuperTransactionManager {
pub struct TransactionManagerImpl;

impl TransactionManager<ConnectionImpl> for TransactionManagerImpl {
type TransactionStateData = <AnsiTransactionManager as TransactionManager<
PgConnection,
>>::TransactionStateData;

fn begin_transaction(conn: &mut SuperConnection) -> diesel::prelude::QueryResult<()> {
fn begin_transaction(conn: &mut ConnectionImpl) -> diesel::prelude::QueryResult<()> {
AnsiTransactionManager::begin_transaction(&mut *conn.conn)?;
let result = diesel::sql_query("SELECT set_config('search_path', $1, true)")
.bind::<diesel::sql_types::Text, _>(&conn.namespace)
Expand All @@ -24,55 +29,80 @@ impl TransactionManager<SuperConnection> for SuperTransactionManager {
}

fn rollback_transaction(
conn: &mut SuperConnection,
conn: &mut ConnectionImpl,
) -> diesel::prelude::QueryResult<()> {
AnsiTransactionManager::rollback_transaction(&mut *conn.conn)
}

fn commit_transaction(
conn: &mut SuperConnection,
) -> diesel::prelude::QueryResult<()> {
fn commit_transaction(conn: &mut ConnectionImpl) -> diesel::prelude::QueryResult<()> {
AnsiTransactionManager::commit_transaction(&mut *conn.conn)
}

fn transaction_manager_status_mut(
conn: &mut SuperConnection,
conn: &mut ConnectionImpl,
) -> &mut diesel::connection::TransactionManagerStatus {
AnsiTransactionManager::transaction_manager_status_mut(&mut *conn.conn)
}
}

pub struct SuperConnection {
pub struct ConnectionImpl {
namespace: String,
conn: PooledConnection<ConnectionManager<PgConnection>>,
}

impl SuperConnection {
impl ConnectionImpl {
pub fn new(
namespace: String,
mut conn: PooledConnection<ConnectionManager<PgConnection>>,
) -> Self {
conn.set_prepared_statement_cache_size(diesel::connection::CacheSize::Disabled);
SuperConnection { namespace, conn }
ConnectionImpl { namespace, conn }
}

pub fn set_namespace(&mut self, namespace: String) {
self.namespace = namespace;
}

pub fn from_request_override(
req: &actix_web::HttpRequest,
schema_name: String,
) -> Result<Self, actix_web::Error> {
let app_state = match req.app_data::<Data<AppState>>() {
Some(state) => state,
None => {
log::info!(
"DbConnection-FromRequest: Unable to get app_data from request"
);
return Err(actix_web::error::ErrorInternalServerError(""));
}
};

match app_state.db_pool.get() {
Ok(conn) => Ok(ConnectionImpl::new(schema_name, conn)),
Err(e) => {
log::info!("Unable to get db connection from pool, error: {e}");
Err(actix_web::error::ErrorInternalServerError(""))
}
}
}
}

impl ConnectionSealed for SuperConnection {}
impl ConnectionSealed for ConnectionImpl {}

impl SimpleConnection for SuperConnection {
impl SimpleConnection for ConnectionImpl {
fn batch_execute(&mut self, query: &str) -> diesel::prelude::QueryResult<()> {
self.conn.batch_execute(query)
}
}

impl Connection for SuperConnection {
impl Connection for ConnectionImpl {
type Backend = Pg;
type TransactionManager = SuperTransactionManager;
type TransactionManager = TransactionManagerImpl;

// NOTE: this function will never be used, so namespace here doesn't matter
fn establish(database_url: &str) -> diesel::prelude::ConnectionResult<Self> {
let conn = PooledConnection::establish(database_url)?;
Ok(SuperConnection {
Ok(ConnectionImpl {
namespace: String::new(),
conn,
})
Expand Down Expand Up @@ -112,7 +142,7 @@ impl Connection for SuperConnection {
}
}

impl LoadConnection<DefaultLoadingMode> for SuperConnection {
impl LoadConnection<DefaultLoadingMode> for ConnectionImpl {
type Cursor<'conn, 'query> =
<PgConnection as LoadConnection<DefaultLoadingMode>>::Cursor<'conn, 'query>;
type Row<'conn, 'query> =
Expand Down Expand Up @@ -140,3 +170,50 @@ impl LoadConnection<DefaultLoadingMode> for SuperConnection {
)
}
}

impl FromRequest for ConnectionImpl {
type Error = actix_web::Error;
type Future = std::future::Ready<Result<ConnectionImpl, Self::Error>>;

fn from_request(
req: &actix_web::HttpRequest,
_: &mut actix_web::dev::Payload,
) -> Self::Future {
let schema_name = req.extensions().get::<SchemaName>().cloned().unwrap().0;
std::future::ready(ConnectionImpl::from_request_override(req, schema_name))
}
}

#[derive(Deref, DerefMut)]
pub struct PublicConnection(pub ConnectionImpl);
impl FromRequest for PublicConnection {
type Error = actix_web::Error;
type Future = std::future::Ready<Result<PublicConnection, Self::Error>>;

fn from_request(
req: &actix_web::HttpRequest,
_: &mut actix_web::dev::Payload,
) -> Self::Future {
std::future::ready(
ConnectionImpl::from_request_override(req, String::from("public"))
.map(|conn| PublicConnection(conn)),
)
}
}

#[derive(Deref, DerefMut)]
pub struct SuperpositionConnection(pub ConnectionImpl);
impl FromRequest for SuperpositionConnection {
type Error = actix_web::Error;
type Future = std::future::Ready<Result<SuperpositionConnection, Self::Error>>;

fn from_request(
req: &actix_web::HttpRequest,
_: &mut actix_web::dev::Payload,
) -> Self::Future {
std::future::ready(
ConnectionImpl::from_request_override(req, String::from("superposition"))
.map(|conn| SuperpositionConnection(conn)),
)
}
}
3 changes: 0 additions & 3 deletions crates/service_utils/src/service/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ pub struct DbConnection(pub PooledConnection<ConnectionManager<PgConnection>>);
impl FromRequest for DbConnection {
type Error = Error;
type Future = Ready<Result<DbConnection, Self::Error>>;

fn from_request(
req: &actix_web::HttpRequest,
_: &mut actix_web::dev::Payload,
Expand All @@ -185,15 +184,13 @@ impl FromRequest for DbConnection {
return ready(Err(error::ErrorInternalServerError("")));
}
};

let result = match app_state.db_pool.get() {
Ok(conn) => Ok(DbConnection(conn)),
Err(e) => {
log::info!("Unable to get db connection from pool, error: {e}");
Err(error::ErrorInternalServerError(""))
}
};

ready(result)
}
}
Expand Down

0 comments on commit 007698a

Please sign in to comment.