From b3d01bd15816ef3e5361892448616a02bf256f0f Mon Sep 17 00:00:00 2001 From: steve-chavez Date: Wed, 21 Sep 2022 18:40:13 -0500 Subject: [PATCH] refactor: runPgLocals to Query.hs * unmmiddleware runPgLocals --- src/PostgREST/App.hs | 5 +- src/PostgREST/Middleware.hs | 80 ++---------------------------- src/PostgREST/Query.hs | 64 +++++++++++++++++++++--- src/PostgREST/Query/SqlFragment.hs | 22 ++++++++ 4 files changed, 86 insertions(+), 85 deletions(-) diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs index 78bb47836e..8ff6cc1eaf 100644 --- a/src/PostgREST/App.hs +++ b/src/PostgREST/App.hs @@ -186,8 +186,9 @@ postgrestResponse appState conf@AppConfig{..} maybeDbStructure jsonDbS pgVer Aut pure $ Response.infoResponse (iTarget apiRequest) dbStructure else runDbHandler appState (Query.txMode apiRequest) (Just authRole /= configDbAnonRole) configDbPreparedStatements . - Middleware.optionalRollback conf apiRequest $ - Middleware.runPgLocals conf authClaims authRole (handleRequest . ctx) apiRequest jsonDbS pgVer + Middleware.optionalRollback conf apiRequest $ do + Query.runPgLocals conf authClaims authRole apiRequest jsonDbS pgVer + handleRequest (ctx apiRequest) runDbHandler :: AppState.AppState -> SQL.Mode -> Bool -> Bool -> DbHandler b -> Handler IO b runDbHandler appState mode authenticated prepared handler = do diff --git a/src/PostgREST/Middleware.hs b/src/PostgREST/Middleware.hs index 09981aab7a..858ea74ec9 100644 --- a/src/PostgREST/Middleware.hs +++ b/src/PostgREST/Middleware.hs @@ -5,79 +5,22 @@ Description : Sets CORS policy. Also the PostgreSQL GUCs, role, search_path and {-# LANGUAGE BlockArguments #-} {-# LANGUAGE RecordWildCards #-} module PostgREST.Middleware - ( runPgLocals - , optionalRollback + ( optionalRollback ) where -import qualified Data.Aeson as JSON -import qualified Data.Aeson.Key as K -import qualified Data.Aeson.KeyMap as KM -import qualified Data.ByteString.Lazy.Char8 as LBS -import qualified Data.HashMap.Strict as HM -import qualified Data.Text.Encoding as T -import qualified Hasql.Decoders as HD -import qualified Hasql.DynamicStatements.Snippet as SQL hiding (sql) -import qualified Hasql.DynamicStatements.Statement as SQL -import qualified Hasql.Transaction as SQL -import qualified Network.Wai as Wai +import qualified Hasql.Transaction as SQL +import qualified Network.Wai as Wai -import Control.Arrow ((***)) - -import Data.Scientific (FPFormat (..), formatScientific, isInteger) - import PostgREST.Config (AppConfig (..)) -import PostgREST.Config.PgVersion (PgVersion (..), pgVersion140) import PostgREST.Error (Error, errorResponseFor) import PostgREST.GucHeader (addHeadersIfNotIncluded) -import PostgREST.Query.SqlFragment (fromQi, intercalateSnippet, - pgFmtIdentList, unknownEncoder) -import PostgREST.Request.ApiRequest (ApiRequest (..), Target (..)) +import PostgREST.Request.ApiRequest (ApiRequest (..)) import PostgREST.Request.Preferences import Protolude --- | Runs local(transaction scoped) GUCs for every request, plus the pre-request function -runPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> Text -> - (ApiRequest -> ExceptT Error SQL.Transaction Wai.Response) -> - ApiRequest -> ByteString -> PgVersion -> ExceptT Error SQL.Transaction Wai.Response -runPgLocals conf claims role app req jsonDbS actualPgVersion = do - lift $ SQL.statement mempty $ SQL.dynamicallyParameterized - ("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql ++ specSql)) - HD.noResult (configDbPreparedStatements conf) - lift $ traverse_ SQL.sql preReqSql - app req - where - methodSql = setConfigLocal mempty ("request.method", iMethod req) - pathSql = setConfigLocal mempty ("request.path", iPath req) - headersSql = if usesLegacyGucs - then setConfigLocal "request.header." <$> iHeaders req - else setConfigLocalJson "request.headers" (iHeaders req) - cookiesSql = if usesLegacyGucs - then setConfigLocal "request.cookie." <$> iCookies req - else setConfigLocalJson "request.cookies" (iCookies req) - claimsSql = if usesLegacyGucs - then setConfigLocal "request.jwt.claim." <$> [(toUtf8 $ K.toText c, toUtf8 $ unquoted v) | (c,v) <- KM.toList claims] - else [setConfigLocal mempty ("request.jwt.claims", LBS.toStrict $ JSON.encode claims)] - roleSql = [setConfigLocal mempty ("role", toUtf8 role)] - appSettingsSql = setConfigLocal mempty <$> (join bimap toUtf8 <$> configAppSettings conf) - searchPathSql = - let schemas = pgFmtIdentList (iSchema req : configDbExtraSearchPath conf) in - setConfigLocal mempty ("search_path", schemas) - preReqSql = (\f -> "select " <> fromQi f <> "();") <$> configDbPreRequest conf - specSql = case iTarget req of - TargetProc{tpIsRootSpec=True} -> [setConfigLocal mempty ("request.spec", jsonDbS)] - _ -> mempty - usesLegacyGucs = configDbUseLegacyGucs conf && actualPgVersion < pgVersion140 - - unquoted :: JSON.Value -> Text - unquoted (JSON.String t) = t - unquoted (JSON.Number n) = - toS $ formatScientific Fixed (if isInteger n then Just 0 else Nothing) n - unquoted (JSON.Bool b) = show b - unquoted v = T.decodeUtf8 . LBS.toStrict $ JSON.encode v - -- | Set a transaction to eventually roll back if requested and set respective -- headers on the response. optionalRollback @@ -105,18 +48,3 @@ optionalRollback AppConfig{..} ApiRequest{..} transaction = do [toAppliedHeader Rollback] | otherwise = identity - --- | Do a pg set_config(setting, value, true) call. This is equivalent to a SET LOCAL. -setConfigLocal :: ByteString -> (ByteString, ByteString) -> SQL.Snippet -setConfigLocal prefix (k, v) = - "set_config(" <> unknownEncoder (prefix <> k) <> ", " <> unknownEncoder v <> ", true)" - --- | Starting from PostgreSQL v14, some characters are not allowed for config names (mostly affecting headers with "-"). --- | A JSON format string is used to avoid this problem. See https://github.com/PostgREST/postgrest/issues/1857 -setConfigLocalJson :: ByteString -> [(ByteString, ByteString)] -> [SQL.Snippet] -setConfigLocalJson prefix keyVals = [setConfigLocal mempty (prefix, gucJsonVal keyVals)] - where - gucJsonVal :: [(ByteString, ByteString)] -> ByteString - gucJsonVal = LBS.toStrict . JSON.encode . HM.fromList . arrayByteStringToText - arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)] - arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal diff --git a/src/PostgREST/Query.hs b/src/PostgREST/Query.hs index 232451ccb0..d3344ec14a 100644 --- a/src/PostgREST/Query.hs +++ b/src/PostgREST/Query.hs @@ -8,13 +8,21 @@ module PostgREST.Query , singleUpsertQuery , txMode , updateQuery + , runPgLocals , DbHandler ) where -import qualified Data.HashMap.Strict as HM -import qualified Hasql.DynamicStatements.Snippet as SQL (Snippet) -import qualified Hasql.Transaction as SQL -import qualified Hasql.Transaction.Sessions as SQL +import qualified Data.Aeson as JSON +import qualified Data.Aeson.Key as K +import qualified Data.Aeson.KeyMap as KM +import qualified Data.ByteString.Lazy.Char8 as LBS +import qualified Data.HashMap.Strict as HM +import qualified Data.Text.Encoding as T +import qualified Hasql.Decoders as HD +import qualified Hasql.DynamicStatements.Snippet as SQL (Snippet) +import qualified Hasql.DynamicStatements.Statement as SQL +import qualified Hasql.Transaction as SQL +import qualified Hasql.Transaction.Sessions as SQL import qualified PostgREST.DbStructure as DbStructure import qualified PostgREST.DbStructure.Proc as Proc @@ -25,9 +33,12 @@ import qualified PostgREST.RangeQuery as RangeQuery import qualified PostgREST.Request.MutateQuery as MutateRequest import qualified PostgREST.Request.Types as ApiRequestTypes +import Data.Scientific (FPFormat (..), formatScientific, isInteger) + import PostgREST.Config (AppConfig (..), OpenAPIMode (..)) -import PostgREST.Config.PgVersion (PgVersion (..)) +import PostgREST.Config.PgVersion (PgVersion (..), + pgVersion140) import PostgREST.DbStructure (DbStructure (..)) import PostgREST.DbStructure.Identifiers (FieldName, QualifiedIdentifier (..), @@ -38,6 +49,10 @@ import PostgREST.DbStructure.Proc (ProcDescription (..), import PostgREST.DbStructure.Table (TablesMap) import PostgREST.Error (Error) import PostgREST.MediaType (MediaType (..)) +import PostgREST.Query.SqlFragment (fromQi, intercalateSnippet, + pgFmtIdentList, + setConfigLocal, + setConfigLocalJson) import PostgREST.Query.Statements (ResultSet (..)) import PostgREST.Request.ApiRequest (Action (..), ApiRequest (..), @@ -169,8 +184,6 @@ txMode ApiRequest{..} = case (iAction, iTarget) of (ActionRead _, _) -> SQL.Read - (ActionInfo, _) -> - SQL.Read (ActionInspect _, _) -> SQL.Read (ActionInvoke InvGet, _) -> @@ -214,3 +227,40 @@ failsChangesOffLimits (Just maxChanges) RSStandard{rsQueryTotal=queryTotal} = lift SQL.condemn throwError $ Error.OffLimitsChangesError queryTotal maxChanges +-- | Runs local(transaction scoped) GUCs for every request, plus the pre-request function +runPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> Text -> + ApiRequest -> ByteString -> PgVersion -> DbHandler () +runPgLocals conf claims role req jsonDbS actualPgVersion = do + lift $ SQL.statement mempty $ SQL.dynamicallyParameterized + ("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql ++ specSql)) + HD.noResult (configDbPreparedStatements conf) + lift $ traverse_ SQL.sql preReqSql + where + methodSql = setConfigLocal mempty ("request.method", iMethod req) + pathSql = setConfigLocal mempty ("request.path", iPath req) + headersSql = if usesLegacyGucs + then setConfigLocal "request.header." <$> iHeaders req + else setConfigLocalJson "request.headers" (iHeaders req) + cookiesSql = if usesLegacyGucs + then setConfigLocal "request.cookie." <$> iCookies req + else setConfigLocalJson "request.cookies" (iCookies req) + claimsSql = if usesLegacyGucs + then setConfigLocal "request.jwt.claim." <$> [(toUtf8 $ K.toText c, toUtf8 $ unquoted v) | (c,v) <- KM.toList claims] + else [setConfigLocal mempty ("request.jwt.claims", LBS.toStrict $ JSON.encode claims)] + roleSql = [setConfigLocal mempty ("role", toUtf8 role)] + appSettingsSql = setConfigLocal mempty <$> (join bimap toUtf8 <$> configAppSettings conf) + searchPathSql = + let schemas = pgFmtIdentList (iSchema req : configDbExtraSearchPath conf) in + setConfigLocal mempty ("search_path", schemas) + preReqSql = (\f -> "select " <> fromQi f <> "();") <$> configDbPreRequest conf + specSql = case iTarget req of + TargetProc{tpIsRootSpec=True} -> [setConfigLocal mempty ("request.spec", jsonDbS)] + _ -> mempty + usesLegacyGucs = configDbUseLegacyGucs conf && actualPgVersion < pgVersion140 + + unquoted :: JSON.Value -> Text + unquoted (JSON.String t) = t + unquoted (JSON.Number n) = + toS $ formatScientific Fixed (if isInteger n then Just 0 else Nothing) n + unquoted (JSON.Bool b) = show b + unquoted v = T.decodeUtf8 . LBS.toStrict $ JSON.encode v diff --git a/src/PostgREST/Query/SqlFragment.hs b/src/PostgREST/Query/SqlFragment.hs index db79408684..6f6c9389d9 100644 --- a/src/PostgREST/Query/SqlFragment.hs +++ b/src/PostgREST/Query/SqlFragment.hs @@ -38,14 +38,21 @@ module PostgREST.Query.SqlFragment , unknownEncoder , intercalateSnippet , explainF + , setConfigLocal + , setConfigLocalJson ) where +import qualified Data.Aeson as JSON import qualified Data.ByteString.Char8 as BS import qualified Data.ByteString.Lazy as LBS +import qualified Data.HashMap.Strict as HM import qualified Data.Text as T +import qualified Data.Text.Encoding as T import qualified Hasql.DynamicStatements.Snippet as SQL import qualified Hasql.Encoders as HE +import Control.Arrow ((***)) + import Data.Foldable (foldr1) import Text.InterpolatedString.Perl6 (qc) @@ -386,3 +393,18 @@ explainF fmt opts snip = fmtPlanFmt PlanJSON = "FORMAT JSON" fmtPlanFmt PlanText = "FORMAT TEXT" + +-- | Do a pg set_config(setting, value, true) call. This is equivalent to a SET LOCAL. +setConfigLocal :: ByteString -> (ByteString, ByteString) -> SQL.Snippet +setConfigLocal prefix (k, v) = + "set_config(" <> unknownEncoder (prefix <> k) <> ", " <> unknownEncoder v <> ", true)" + +-- | Starting from PostgreSQL v14, some characters are not allowed for config names (mostly affecting headers with "-"). +-- | A JSON format string is used to avoid this problem. See https://github.com/PostgREST/postgrest/issues/1857 +setConfigLocalJson :: ByteString -> [(ByteString, ByteString)] -> [SQL.Snippet] +setConfigLocalJson prefix keyVals = [setConfigLocal mempty (prefix, gucJsonVal keyVals)] + where + gucJsonVal :: [(ByteString, ByteString)] -> ByteString + gucJsonVal = LBS.toStrict . JSON.encode . HM.fromList . arrayByteStringToText + arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)] + arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal