Skip to content

Commit

Permalink
refactor: runPgLocals to Query.hs
Browse files Browse the repository at this point in the history
* unmmiddleware runPgLocals
  • Loading branch information
steve-chavez committed Sep 26, 2022
1 parent 0da56ab commit 21f7222
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 83 deletions.
5 changes: 3 additions & 2 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,9 @@ postgrestResponse appState conf@AppConfig{..} maybeDbStructure jsonDbS pgVer Aut
pure $ Response.infoResponse (iTarget apiRequest) dbStructure
else
runDbHandler appState (Query.txMode apiRequest) (Just authRole /= configDbAnonRole) configDbPreparedStatements .
Middleware.optionalRollback conf apiRequest $
Middleware.runPgLocals conf authClaims authRole (handleRequest . ctx) apiRequest jsonDbS pgVer
Middleware.optionalRollback conf apiRequest $ do
Query.runPgLocals conf authClaims authRole apiRequest jsonDbS pgVer
handleRequest (ctx apiRequest)

runDbHandler :: AppState.AppState -> SQL.Mode -> Bool -> Bool -> DbHandler b -> Handler IO b
runDbHandler appState mode authenticated prepared handler = do
Expand Down
80 changes: 4 additions & 76 deletions src/PostgREST/Middleware.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,79 +5,22 @@ Description : Sets CORS policy. Also the PostgreSQL GUCs, role, search_path and
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE RecordWildCards #-}
module PostgREST.Middleware
( runPgLocals
, optionalRollback
( optionalRollback
) where

import qualified Data.Aeson as JSON
import qualified Data.Aeson.Key as K
import qualified Data.Aeson.KeyMap as KM
import qualified Data.ByteString.Lazy.Char8 as LBS
import qualified Data.HashMap.Strict as HM
import qualified Data.Text.Encoding as T
import qualified Hasql.Decoders as HD
import qualified Hasql.DynamicStatements.Snippet as SQL hiding (sql)
import qualified Hasql.DynamicStatements.Statement as SQL
import qualified Hasql.Transaction as SQL
import qualified Network.Wai as Wai
import qualified Hasql.Transaction as SQL
import qualified Network.Wai as Wai


import Control.Arrow ((***))

import Data.Scientific (FPFormat (..), formatScientific, isInteger)

import PostgREST.Config (AppConfig (..))
import PostgREST.Config.PgVersion (PgVersion (..), pgVersion140)
import PostgREST.Error (Error, errorResponseFor)
import PostgREST.GucHeader (addHeadersIfNotIncluded)
import PostgREST.Query.SqlFragment (fromQi, intercalateSnippet,
pgFmtIdentList, unknownEncoder)
import PostgREST.Request.ApiRequest (ApiRequest (..), Target (..))
import PostgREST.Request.ApiRequest (ApiRequest (..))

import PostgREST.Request.Preferences

import Protolude

-- | Runs local(transaction scoped) GUCs for every request, plus the pre-request function
runPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> Text ->
(ApiRequest -> ExceptT Error SQL.Transaction Wai.Response) ->
ApiRequest -> ByteString -> PgVersion -> ExceptT Error SQL.Transaction Wai.Response
runPgLocals conf claims role app req jsonDbS actualPgVersion = do
lift $ SQL.statement mempty $ SQL.dynamicallyParameterized
("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql ++ specSql))
HD.noResult (configDbPreparedStatements conf)
lift $ traverse_ SQL.sql preReqSql
app req
where
methodSql = setConfigLocal mempty ("request.method", iMethod req)
pathSql = setConfigLocal mempty ("request.path", iPath req)
headersSql = if usesLegacyGucs
then setConfigLocal "request.header." <$> iHeaders req
else setConfigLocalJson "request.headers" (iHeaders req)
cookiesSql = if usesLegacyGucs
then setConfigLocal "request.cookie." <$> iCookies req
else setConfigLocalJson "request.cookies" (iCookies req)
claimsSql = if usesLegacyGucs
then setConfigLocal "request.jwt.claim." <$> [(toUtf8 $ K.toText c, toUtf8 $ unquoted v) | (c,v) <- KM.toList claims]
else [setConfigLocal mempty ("request.jwt.claims", LBS.toStrict $ JSON.encode claims)]
roleSql = [setConfigLocal mempty ("role", toUtf8 role)]
appSettingsSql = setConfigLocal mempty <$> (join bimap toUtf8 <$> configAppSettings conf)
searchPathSql =
let schemas = pgFmtIdentList (iSchema req : configDbExtraSearchPath conf) in
setConfigLocal mempty ("search_path", schemas)
preReqSql = (\f -> "select " <> fromQi f <> "();") <$> configDbPreRequest conf
specSql = case iTarget req of
TargetProc{tpIsRootSpec=True} -> [setConfigLocal mempty ("request.spec", jsonDbS)]
_ -> mempty
usesLegacyGucs = configDbUseLegacyGucs conf && actualPgVersion < pgVersion140

unquoted :: JSON.Value -> Text
unquoted (JSON.String t) = t
unquoted (JSON.Number n) =
toS $ formatScientific Fixed (if isInteger n then Just 0 else Nothing) n
unquoted (JSON.Bool b) = show b
unquoted v = T.decodeUtf8 . LBS.toStrict $ JSON.encode v

-- | Set a transaction to eventually roll back if requested and set respective
-- headers on the response.
optionalRollback
Expand Down Expand Up @@ -105,18 +48,3 @@ optionalRollback AppConfig{..} ApiRequest{..} transaction = do
[toAppliedHeader Rollback]
| otherwise =
identity

-- | Do a pg set_config(setting, value, true) call. This is equivalent to a SET LOCAL.
setConfigLocal :: ByteString -> (ByteString, ByteString) -> SQL.Snippet
setConfigLocal prefix (k, v) =
"set_config(" <> unknownEncoder (prefix <> k) <> ", " <> unknownEncoder v <> ", true)"

-- | Starting from PostgreSQL v14, some characters are not allowed for config names (mostly affecting headers with "-").
-- | A JSON format string is used to avoid this problem. See https://github.com/PostgREST/postgrest/issues/1857
setConfigLocalJson :: ByteString -> [(ByteString, ByteString)] -> [SQL.Snippet]
setConfigLocalJson prefix keyVals = [setConfigLocal mempty (prefix, gucJsonVal keyVals)]
where
gucJsonVal :: [(ByteString, ByteString)] -> ByteString
gucJsonVal = LBS.toStrict . JSON.encode . HM.fromList . arrayByteStringToText
arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)]
arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal
62 changes: 57 additions & 5 deletions src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,21 @@ module PostgREST.Query
, singleUpsertQuery
, txMode
, updateQuery
, runPgLocals
, DbHandler
) where

import qualified Data.HashMap.Strict as HM
import qualified Hasql.DynamicStatements.Snippet as SQL (Snippet)
import qualified Hasql.Transaction as SQL
import qualified Hasql.Transaction.Sessions as SQL
import qualified Data.Aeson as JSON
import qualified Data.Aeson.Key as K
import qualified Data.Aeson.KeyMap as KM
import qualified Data.ByteString.Lazy.Char8 as LBS
import qualified Data.HashMap.Strict as HM
import qualified Data.Text.Encoding as T
import qualified Hasql.Decoders as HD
import qualified Hasql.DynamicStatements.Snippet as SQL (Snippet)
import qualified Hasql.DynamicStatements.Statement as SQL
import qualified Hasql.Transaction as SQL
import qualified Hasql.Transaction.Sessions as SQL

import qualified PostgREST.DbStructure as DbStructure
import qualified PostgREST.DbStructure.Proc as Proc
Expand All @@ -25,9 +33,12 @@ import qualified PostgREST.RangeQuery as RangeQuery
import qualified PostgREST.Request.MutateQuery as MutateRequest
import qualified PostgREST.Request.Types as ApiRequestTypes

import Data.Scientific (FPFormat (..), formatScientific, isInteger)

import PostgREST.Config (AppConfig (..),
OpenAPIMode (..))
import PostgREST.Config.PgVersion (PgVersion (..))
import PostgREST.Config.PgVersion (PgVersion (..),
pgVersion140)
import PostgREST.DbStructure (DbStructure (..))
import PostgREST.DbStructure.Identifiers (FieldName,
QualifiedIdentifier (..),
Expand All @@ -38,6 +49,10 @@ import PostgREST.DbStructure.Proc (ProcDescription (..),
import PostgREST.DbStructure.Table (TablesMap)
import PostgREST.Error (Error)
import PostgREST.MediaType (MediaType (..))
import PostgREST.Query.SqlFragment (fromQi, intercalateSnippet,
pgFmtIdentList,
setConfigLocal,
setConfigLocalJson)
import PostgREST.Query.Statements (ResultSet (..))
import PostgREST.Request.ApiRequest (Action (..),
ApiRequest (..),
Expand Down Expand Up @@ -214,3 +229,40 @@ failsChangesOffLimits (Just maxChanges) RSStandard{rsQueryTotal=queryTotal} =
lift SQL.condemn
throwError $ Error.OffLimitsChangesError queryTotal maxChanges

-- | Runs local(transaction scoped) GUCs for every request, plus the pre-request function
runPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> Text ->
ApiRequest -> ByteString -> PgVersion -> DbHandler ()
runPgLocals conf claims role req jsonDbS actualPgVersion = do
lift $ SQL.statement mempty $ SQL.dynamicallyParameterized
("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql ++ specSql))
HD.noResult (configDbPreparedStatements conf)
lift $ traverse_ SQL.sql preReqSql
where
methodSql = setConfigLocal mempty ("request.method", iMethod req)
pathSql = setConfigLocal mempty ("request.path", iPath req)
headersSql = if usesLegacyGucs
then setConfigLocal "request.header." <$> iHeaders req
else setConfigLocalJson "request.headers" (iHeaders req)
cookiesSql = if usesLegacyGucs
then setConfigLocal "request.cookie." <$> iCookies req
else setConfigLocalJson "request.cookies" (iCookies req)
claimsSql = if usesLegacyGucs
then setConfigLocal "request.jwt.claim." <$> [(toUtf8 $ K.toText c, toUtf8 $ unquoted v) | (c,v) <- KM.toList claims]
else [setConfigLocal mempty ("request.jwt.claims", LBS.toStrict $ JSON.encode claims)]
roleSql = [setConfigLocal mempty ("role", toUtf8 role)]
appSettingsSql = setConfigLocal mempty <$> (join bimap toUtf8 <$> configAppSettings conf)
searchPathSql =
let schemas = pgFmtIdentList (iSchema req : configDbExtraSearchPath conf) in
setConfigLocal mempty ("search_path", schemas)
preReqSql = (\f -> "select " <> fromQi f <> "();") <$> configDbPreRequest conf
specSql = case iTarget req of
TargetProc{tpIsRootSpec=True} -> [setConfigLocal mempty ("request.spec", jsonDbS)]
_ -> mempty
usesLegacyGucs = configDbUseLegacyGucs conf && actualPgVersion < pgVersion140

unquoted :: JSON.Value -> Text
unquoted (JSON.String t) = t
unquoted (JSON.Number n) =
toS $ formatScientific Fixed (if isInteger n then Just 0 else Nothing) n
unquoted (JSON.Bool b) = show b
unquoted v = T.decodeUtf8 . LBS.toStrict $ JSON.encode v
22 changes: 22 additions & 0 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,21 @@ module PostgREST.Query.SqlFragment
, unknownEncoder
, intercalateSnippet
, explainF
, setConfigLocal
, setConfigLocalJson
) where

import qualified Data.Aeson as JSON
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Hasql.DynamicStatements.Snippet as SQL
import qualified Hasql.Encoders as HE

import Control.Arrow ((***))

import Data.Foldable (foldr1)
import Text.InterpolatedString.Perl6 (qc)

Expand Down Expand Up @@ -386,3 +393,18 @@ explainF fmt opts snip =

fmtPlanFmt PlanJSON = "FORMAT JSON"
fmtPlanFmt PlanText = "FORMAT TEXT"

-- | Do a pg set_config(setting, value, true) call. This is equivalent to a SET LOCAL.
setConfigLocal :: ByteString -> (ByteString, ByteString) -> SQL.Snippet
setConfigLocal prefix (k, v) =
"set_config(" <> unknownEncoder (prefix <> k) <> ", " <> unknownEncoder v <> ", true)"

-- | Starting from PostgreSQL v14, some characters are not allowed for config names (mostly affecting headers with "-").
-- | A JSON format string is used to avoid this problem. See https://github.com/PostgREST/postgrest/issues/1857
setConfigLocalJson :: ByteString -> [(ByteString, ByteString)] -> [SQL.Snippet]
setConfigLocalJson prefix keyVals = [setConfigLocal mempty (prefix, gucJsonVal keyVals)]
where
gucJsonVal :: [(ByteString, ByteString)] -> ByteString
gucJsonVal = LBS.toStrict . JSON.encode . HM.fromList . arrayByteStringToText
arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)]
arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal

0 comments on commit 21f7222

Please sign in to comment.