Skip to content

Commit

Permalink
Add post_connect to sql_pool
Browse files Browse the repository at this point in the history
Allow for a custom post_connect to run for new SQL connection, for
example registering user functions for SQLite connections.
  • Loading branch information
Gusted committed Jan 10, 2025
1 parent 27ff431 commit 2f03da4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
3 changes: 2 additions & 1 deletion docs/web/postprocess/index.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1951,7 +1951,8 @@ let graphiql_expected = {|<div class="spec value" id="val-graphiql">
|}

let sql_pool_expected = {|<div class="spec value" id="val-sql_pool">
<a href="#val-sql_pool" class="anchor"></a><code><span><span class="keyword">val</span> sql_pool : <span>?size:int <span class="arrow">-&gt;</span></span> <span>string <span class="arrow">-&gt;</span></span> <a href="#type-middleware">middleware</a></span></code>
<a href="#val-sql_pool" class="anchor"></a><code><span><span class="keyword">val</span> sql_pool : <span>?size:int <span class="arrow">-&gt;</span></span>
<span>?post_connect:<span>(<span><span>(<span class="keyword">module</span> <span class="xref-unresolved">Caqti_lwt</span>.CONNECTION)</span> <span class="arrow">-&gt;</span></span> <span><span><span>(unit,&nbsp;<span class="xref-unresolved">Caqti_error</span>.t)</span> <span class="xref-unresolved">Stdlib</span>.result</span> <span class="xref-unresolved">Lwt</span>.t</span>)</span> <span class="arrow">-&gt;</span></span> <span>string <span class="arrow">-&gt;</span></span> <a href="#type-middleware">middleware</a></span></code>
</div>
|}

Expand Down
5 changes: 3 additions & 2 deletions src/dream.mli
Original file line number Diff line number Diff line change
Expand Up @@ -1746,11 +1746,12 @@ val graphiql : ?default_query:string -> string -> handler
{{:https://cheatsheetseries.owasp.org/cheatsheets/Database_Security_Cheat_Sheet.html}
OWASP {i Database Security Cheat Sheet}}. *)

val sql_pool : ?size:int -> string -> middleware
val sql_pool : ?size:int -> ?post_connect:((module Caqti_lwt.CONNECTION) -> (unit, Caqti_error.t) result Lwt.t) -> string -> middleware
(** Makes an SQL connection pool available to its inner handler. [?size] is the
maximum number of concurrent connections that the pool will support. The
default value is picked by the driver. Note that for SQLite, [?size] is
capped to [1]. *)
capped to [1]. [post_connect] is an optional callback, which is called for
every new connection that is opened to the database. *)

val sql : request -> (Caqti_lwt.connection -> 'a promise) -> 'a promise
(** Runs the callback with a connection from the SQL pool. See example
Expand Down
15 changes: 12 additions & 3 deletions src/sql/sql.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ let foreign_keys_on =
(Caqti_type.unit ->. Caqti_type.unit) "PRAGMA foreign_keys = ON"
[@ocaml.warning "-3"]

let post_connect (module Db : Caqti_lwt.CONNECTION) =
let standard_post_connect (module Db : Caqti_lwt.CONNECTION) =
match Caqti_driver_info.dialect_tag Db.driver_info with
| `Sqlite -> Db.exec foreign_keys_on ()
| _ -> Lwt.return (Ok ())

let sql_pool ?size uri =
let sql_pool ?size ?post_connect uri =
let pool_cell = ref None in
fun inner_handler request ->

Expand All @@ -49,7 +49,16 @@ let sql_pool ?size uri =
'sqlite' is not a valid scheme; did you mean 'sqlite3'?");
let pool =
let pool_config = Caqti_pool_config.create ?max_size:size () in
Caqti_lwt_unix.connect_pool ~pool_config ~post_connect parsed_uri in
Caqti_lwt_unix.connect_pool ~pool_config ~post_connect:(fun db ->
let%lwt result = standard_post_connect db in
match result with
| Ok () ->
(match post_connect with
| Some f -> f db
| None -> Lwt.return (Ok ()))
| Error e -> Lwt.return (Error e))
parsed_uri
in
match pool with
| Ok pool ->
pool_cell := Some pool;
Expand Down

0 comments on commit 2f03da4

Please sign in to comment.