From 078c6ec08c69da1482604e4688e2d19ec7c32a2f Mon Sep 17 00:00:00 2001 From: steve-chavez Date: Thu, 22 Jun 2023 19:51:24 -0500 Subject: [PATCH] refactor: remove SqlFragment and use SQL.Snippet --- src/PostgREST/Plan.hs | 2 +- src/PostgREST/Query.hs | 11 ++- src/PostgREST/Query/QueryBuilder.hs | 90 ++++++++++++------------ src/PostgREST/Query/SqlFragment.hs | 102 ++++++++++++++-------------- src/PostgREST/Query/Statements.hs | 30 ++++---- 5 files changed, 113 insertions(+), 122 deletions(-) diff --git a/src/PostgREST/Plan.hs b/src/PostgREST/Plan.hs index 658e7bfb5a..b5429083fb 100644 --- a/src/PostgREST/Plan.hs +++ b/src/PostgREST/Plan.hs @@ -271,7 +271,7 @@ addRels schema action allRels parentNode (Node rPlan@ReadPlan{relName,relHint,re Node <$> newReadPlan <*> (updateForest . hush $ Node <$> newReadPlan <*> pure forest) Nothing -> -- root case let - newFrom = QualifiedIdentifier mempty $ decodeUtf8 sourceCTEName + newFrom = QualifiedIdentifier mempty sourceCTEName newAlias = Just (qiName $ from rPlan) newReadPlan = case action of -- the CTE for mutations/rpc is used as WITH sourceCTEName .. SELECT .. FROM sourceCTEName as alias, diff --git a/src/PostgREST/Query.hs b/src/PostgREST/Query.hs index c0f2097b59..10694758c5 100644 --- a/src/PostgREST/Query.hs +++ b/src/PostgREST/Query.hs @@ -24,8 +24,6 @@ import qualified Data.Text.Encoding as T import qualified Hasql.Decoders as HD import qualified Hasql.DynamicStatements.Snippet as SQL (Snippet) import qualified Hasql.DynamicStatements.Statement as SQL -import qualified Hasql.Encoders as HE -import qualified Hasql.Statement as SQL import qualified Hasql.Transaction as SQL import qualified PostgREST.Error as Error @@ -51,8 +49,8 @@ import PostgREST.Plan (CallReadPlan (..), MutateReadPlan (..), WrappedReadPlan (..)) import PostgREST.Plan.MutatePlan (MutatePlan (..)) -import PostgREST.Query.SqlFragment (fromQi, intercalateSnippet, - pgFmtIdentList, +import PostgREST.Query.SqlFragment (escapeIdentList, fromQi, + intercalateSnippet, setConfigLocal, setConfigLocalJson) import PostgREST.Query.Statements (ResultSet (..)) @@ -254,7 +252,7 @@ setPgLocals AppConfig{..} claims role roleSettings req actualPgVersion = lift $ roleSettingsSql = setConfigLocal mempty <$> roleSettings appSettingsSql = setConfigLocal mempty <$> (join bimap toUtf8 <$> configAppSettings) searchPathSql = - let schemas = pgFmtIdentList (iSchema req : configDbExtraSearchPath) in + let schemas = escapeIdentList (iSchema req : configDbExtraSearchPath) in setConfigLocal mempty ("search_path", schemas) usesLegacyGucs = configDbUseLegacyGucs && actualPgVersion < pgVersion140 @@ -269,8 +267,7 @@ setPgLocals AppConfig{..} claims role roleSettings req actualPgVersion = lift $ runPreReq :: AppConfig -> DbHandler () runPreReq conf = lift $ traverse_ (SQL.statement mempty . stmt) (configDbPreRequest conf) where - stmt req = SQL.Statement + stmt req = SQL.dynamicallyParameterized ("select " <> fromQi req <> "()") - HE.noParams HD.noResult (configDbPreparedStatements conf) diff --git a/src/PostgREST/Query/QueryBuilder.hs b/src/PostgREST/Query/QueryBuilder.hs index 1ede09a4b0..17d5f3ff96 100644 --- a/src/PostgREST/Query/QueryBuilder.hs +++ b/src/PostgREST/Query/QueryBuilder.hs @@ -66,45 +66,43 @@ getSelectsJoins rr@(Node ReadPlan{select, relName, relToParent=Just rel, relAggA aliasOrName = pgFmtIdent $ fromMaybe relName relAlias aggAlias = pgFmtIdent relAggAlias correlatedSubquery sub al cond = - (if relJoinType == Just JTInner then "INNER" else "LEFT") <> " JOIN LATERAL ( " <> sub <> " ) AS " <> SQL.sql al <> " ON " <> cond + (if relJoinType == Just JTInner then "INNER" else "LEFT") <> " JOIN LATERAL ( " <> sub <> " ) AS " <> al <> " ON " <> cond (sel, joi) = if relIsToOne rel then ( if relIsSpread - then SQL.sql aggAlias <> ".*" - else SQL.sql ("row_to_json(" <> aggAlias <> ".*) AS " <> aliasOrName) + then aggAlias <> ".*" + else "row_to_json(" <> aggAlias <> ".*) AS " <> aliasOrName , correlatedSubquery subquery aggAlias "TRUE") else - ( SQL.sql $ "COALESCE( " <> aggAlias <> "." <> aggAlias <> ", '[]') AS " <> aliasOrName + ( "COALESCE( " <> aggAlias <> "." <> aggAlias <> ", '[]') AS " <> aliasOrName , correlatedSubquery ( - "SELECT json_agg(" <> SQL.sql aggAlias <> ") AS " <> SQL.sql aggAlias <> - "FROM (" <> subquery <> " ) AS " <> SQL.sql aggAlias - ) aggAlias $ if relJoinType == Just JTInner then SQL.sql aggAlias <> " IS NOT NULL" else "TRUE") + "SELECT json_agg(" <> aggAlias <> ") AS " <> aggAlias <> + "FROM (" <> subquery <> " ) AS " <> aggAlias + ) aggAlias $ if relJoinType == Just JTInner then aggAlias <> " IS NOT NULL" else "TRUE") in (if null select && null forest then selects else sel:selects, joi:joins) mutatePlanToQuery :: MutatePlan -> SQL.Snippet mutatePlanToQuery (Insert mainQi iCols body onConflct putConditions returnings _ applyDefaults) = - "INSERT INTO " <> SQL.sql (fromQi mainQi) <> SQL.sql (if null iCols then " " else "(" <> cols <> ") ") <> + "INSERT INTO " <> fromQi mainQi <> (if null iCols then " " else "(" <> cols <> ") ") <> fromJsonBodyF body iCols True False applyDefaults <> -- Only used for PUT (if null putConditions then mempty else "WHERE " <> intercalateSnippet " AND " (pgFmtLogicTree (QualifiedIdentifier mempty "pgrst_body") <$> putConditions)) <> - SQL.sql (BS.unwords [ - maybe mempty (\(oncDo, oncCols) -> - if null oncCols then - mempty - else - " ON CONFLICT(" <> BS.intercalate ", " (pgFmtIdent <$> oncCols) <> ") " <> case oncDo of - IgnoreDuplicates -> - "DO NOTHING" - MergeDuplicates -> - if null iCols - then "DO NOTHING" - else "DO UPDATE SET " <> BS.intercalate ", " ((pgFmtIdent . tfName) <> const " = EXCLUDED." <> (pgFmtIdent . tfName) <$> iCols) - ) onConflct, - returningF mainQi returnings - ]) + maybe mempty (\(oncDo, oncCols) -> + if null oncCols then + mempty + else + " ON CONFLICT(" <> intercalateSnippet ", " (pgFmtIdent <$> oncCols) <> ") " <> case oncDo of + IgnoreDuplicates -> + "DO NOTHING" + MergeDuplicates -> + if null iCols + then "DO NOTHING" + else "DO UPDATE SET " <> intercalateSnippet ", " ((pgFmtIdent . tfName) <> const " = EXCLUDED." <> (pgFmtIdent . tfName) <$> iCols) + ) onConflct <> " " <> + returningF mainQi returnings where - cols = BS.intercalate ", " $ pgFmtIdent . tfName <$> iCols + cols = intercalateSnippet ", " $ pgFmtIdent . tfName <$> iCols -- An update without a limit is always filtered with a WHERE mutatePlanToQuery (Update mainQi uCols body logicForest range ordts returnings applyDefaults) @@ -112,54 +110,54 @@ mutatePlanToQuery (Update mainQi uCols body logicForest range ordts returnings a -- if there are no columns we cannot do UPDATE table SET {empty}, it'd be invalid syntax -- selecting an empty resultset from mainQi gives us the column names to prevent errors when using &select= -- the select has to be based on "returnings" to make computed overloaded functions not throw - SQL.sql $ "SELECT " <> emptyBodyReturnedColumns <> " FROM " <> fromQi mainQi <> " WHERE false" + "SELECT " <> emptyBodyReturnedColumns <> " FROM " <> fromQi mainQi <> " WHERE false" | range == allRange = - "UPDATE " <> mainTbl <> " SET " <> SQL.sql nonRangeCols <> " " <> + "UPDATE " <> mainTbl <> " SET " <> nonRangeCols <> " " <> fromJsonBodyF body uCols False False applyDefaults <> whereLogic <> " " <> - SQL.sql (returningF mainQi returnings) + returningF mainQi returnings | otherwise = "WITH " <> "pgrst_update_body AS (" <> fromJsonBodyF body uCols True True applyDefaults <> "), " <> "pgrst_affected_rows AS (" <> - "SELECT " <> SQL.sql rangeIdF <> " FROM " <> mainTbl <> + "SELECT " <> rangeIdF <> " FROM " <> mainTbl <> whereLogic <> " " <> orderF mainQi ordts <> " " <> limitOffsetF range <> ") " <> - "UPDATE " <> mainTbl <> " SET " <> SQL.sql rangeCols <> + "UPDATE " <> mainTbl <> " SET " <> rangeCols <> "FROM pgrst_affected_rows " <> - "WHERE " <> SQL.sql whereRangeIdF <> " " <> - SQL.sql (returningF mainQi returnings) + "WHERE " <> whereRangeIdF <> " " <> + returningF mainQi returnings where whereLogic = if null logicForest then mempty else " WHERE " <> intercalateSnippet " AND " (pgFmtLogicTree mainQi <$> logicForest) - mainTbl = SQL.sql (fromQi mainQi) - emptyBodyReturnedColumns = if null returnings then "NULL" else BS.intercalate ", " (pgFmtColumn (QualifiedIdentifier mempty $ qiName mainQi) <$> returnings) - nonRangeCols = BS.intercalate ", " (pgFmtIdent . tfName <> const " = " <> pgFmtColumn (QualifiedIdentifier mempty "pgrst_body") . tfName <$> uCols) - rangeCols = BS.intercalate ", " ((\col -> pgFmtIdent (tfName col) <> " = (SELECT " <> pgFmtIdent (tfName col) <> " FROM pgrst_update_body) ") <$> uCols) + mainTbl = fromQi mainQi + emptyBodyReturnedColumns = if null returnings then "NULL" else intercalateSnippet ", " (pgFmtColumn (QualifiedIdentifier mempty $ qiName mainQi) <$> returnings) + nonRangeCols = intercalateSnippet ", " (pgFmtIdent . tfName <> const " = " <> pgFmtColumn (QualifiedIdentifier mempty "pgrst_body") . tfName <$> uCols) + rangeCols = intercalateSnippet ", " ((\col -> pgFmtIdent (tfName col) <> " = (SELECT " <> pgFmtIdent (tfName col) <> " FROM pgrst_update_body) ") <$> uCols) (whereRangeIdF, rangeIdF) = mutRangeF mainQi (fst . otTerm <$> ordts) mutatePlanToQuery (Delete mainQi logicForest range ordts returnings) | range == allRange = - "DELETE FROM " <> SQL.sql (fromQi mainQi) <> " " <> + "DELETE FROM " <> fromQi mainQi <> " " <> whereLogic <> " " <> - SQL.sql (returningF mainQi returnings) + returningF mainQi returnings | otherwise = "WITH " <> "pgrst_affected_rows AS (" <> - "SELECT " <> SQL.sql rangeIdF <> " FROM " <> SQL.sql (fromQi mainQi) <> + "SELECT " <> rangeIdF <> " FROM " <> fromQi mainQi <> whereLogic <> " " <> orderF mainQi ordts <> " " <> limitOffsetF range <> ") " <> - "DELETE FROM " <> SQL.sql (fromQi mainQi) <> " " <> + "DELETE FROM " <> fromQi mainQi <> " " <> "USING pgrst_affected_rows " <> - "WHERE " <> SQL.sql whereRangeIdF <> " " <> - SQL.sql (returningF mainQi returnings) + "WHERE " <> whereRangeIdF <> " " <> + returningF mainQi returnings where whereLogic = if null logicForest then mempty else " WHERE " <> intercalateSnippet " AND " (pgFmtLogicTree mainQi <$> logicForest) @@ -177,17 +175,17 @@ callPlanToQuery (FunctionCall qi params args returnsScalar returnsSetOfScalar re "LATERAL " <> callIt (fmtParams prms) callIt :: SQL.Snippet -> SQL.Snippet - callIt argument | pgVer < pgVersion130 && pgVer >= pgVersion110 && returnsCompositeAlias = "(SELECT (" <> SQL.sql (fromQi qi) <> "(" <> argument <> ")).*) pgrst_call" - | otherwise = SQL.sql (fromQi qi) <> "(" <> argument <> ") pgrst_call" + callIt argument | pgVer < pgVersion130 && pgVer >= pgVersion110 && returnsCompositeAlias = "(SELECT (" <> fromQi qi <> "(" <> argument <> ")).*) pgrst_call" + | otherwise = fromQi qi <> "(" <> argument <> ") pgrst_call" fmtParams :: [RoutineParam] -> SQL.Snippet - fmtParams prms = SQL.sql $ BS.intercalate ", " + fmtParams prms = intercalateSnippet ", " ((\a -> (if ppVar a then "VARIADIC " else mempty) <> pgFmtIdent (ppName a) <> " := pgrst_body." <> pgFmtIdent (ppName a)) <$> prms) returnedColumns :: SQL.Snippet returnedColumns | null returnings = "*" - | otherwise = SQL.sql $ BS.intercalate ", " (pgFmtColumn (QualifiedIdentifier mempty "pgrst_call") <$> returnings) + | otherwise = intercalateSnippet ", " (pgFmtColumn (QualifiedIdentifier mempty "pgrst_call") <$> returnings) -- | SQL query meant for COUNTing the root node of the Tree. -- It only takes WHERE into account and doesn't include LIMIT/OFFSET because it would reduce the COUNT. @@ -229,7 +227,7 @@ getQualifiedIdentifier rel mainQi tblAlias = case rel of -- FROM clause plus implicit joins fromF :: Maybe Relationship -> QualifiedIdentifier -> Maybe Alias -> SQL.Snippet -fromF rel mainQi tblAlias = SQL.sql $ "FROM " <> +fromF rel mainQi tblAlias = "FROM " <> (case rel of Just ComputedRelationship{relFunction,relTable} -> fromQi relFunction <> "(" <> pgFmtIdent (qiName relTable) <> ")" _ -> fromQi mainQi) <> diff --git a/src/PostgREST/Query/SqlFragment.hs b/src/PostgREST/Query/SqlFragment.hs index 3834eec229..f30099ae7a 100644 --- a/src/PostgREST/Query/SqlFragment.hs +++ b/src/PostgREST/Query/SqlFragment.hs @@ -4,12 +4,9 @@ {-| Module : PostgREST.Query.SqlFragment Description : Helper functions for PostgREST.QueryBuilder. - -Any function that outputs a SqlFragment should be in this module. -} module PostgREST.Query.SqlFragment ( noLocationF - , SqlFragment , asBinaryF , asCsvF , asGeoJsonF @@ -24,7 +21,6 @@ module PostgREST.Query.SqlFragment , orderF , pgFmtColumn , pgFmtIdent - , pgFmtIdentList , pgFmtJoinCondition , pgFmtLogicTree , pgFmtOrderTerm @@ -34,12 +30,15 @@ module PostgREST.Query.SqlFragment , responseStatusF , returningF , singleParameter + , sourceCTE , sourceCTEName , unknownEncoder , intercalateSnippet , explainF , setConfigLocal , setConfigLocalJson + , escapeIdent + , escapeIdentList ) where import qualified Data.Aeson as JSON @@ -87,17 +86,16 @@ import PostgREST.SchemaCache.Routine (Routine (..), import Protolude hiding (cast) +sourceCTEName :: Text +sourceCTEName = "pgrst_source" --- | A part of a SQL query that cannot be executed independently -type SqlFragment = ByteString +sourceCTE :: SQL.Snippet +sourceCTE = "pgrst_source" -noLocationF :: SqlFragment +noLocationF :: SQL.Snippet noLocationF = "array[]::text[]" -sourceCTEName :: SqlFragment -sourceCTEName = "pgrst_source" - -simpleOperator :: SimpleOperator -> SqlFragment +simpleOperator :: SimpleOperator -> SQL.Snippet simpleOperator = \case OpNotEqual -> "<>" OpContains -> "@>" @@ -109,7 +107,7 @@ simpleOperator = \case OpNotExtendsLeft -> "&>" OpAdjacent -> "-|-" -quantOperator :: QuantOperator -> SqlFragment +quantOperator :: QuantOperator -> SQL.Snippet quantOperator = \case OpEqual -> "=" OpGreaterThanEqual -> ">=" @@ -121,7 +119,7 @@ quantOperator = \case OpMatch -> "~" OpIMatch -> "~*" -ftsOperator :: FtsOperator -> SqlFragment +ftsOperator :: FtsOperator -> SQL.Snippet ftsOperator = \case FilterFts -> "@@ to_tsquery" FilterFtsPlain -> "@@ plainto_tsquery" @@ -149,8 +147,11 @@ pgBuildArrayLiteral vals = "{" <> T.intercalate "," (escaped <$> vals) <> "}" -- TODO: refactor by following https://github.com/PostgREST/postgrest/pull/1631#issuecomment-711070833 -pgFmtIdent :: Text -> SqlFragment -pgFmtIdent x = encodeUtf8 $ "\"" <> T.replace "\"" "\"\"" (trimNullChars x) <> "\"" +pgFmtIdent :: Text -> SQL.Snippet +pgFmtIdent x = SQL.sql $ escapeIdent x + +escapeIdent :: Text -> ByteString +escapeIdent x = encodeUtf8 $ "\"" <> T.replace "\"" "\"\"" (trimNullChars x) <> "\"" -- Only use it if the input comes from the database itself, like on `jsonb_build_object('column_from_a_table', val)..` pgFmtLit :: Text -> Text @@ -168,12 +169,12 @@ trimNullChars = T.takeWhile (/= '\x0') -- | -- Format a list of identifiers and separate them by commas. -- --- >>> pgFmtIdentList ["schema_1", "schema_2", "SPECIAL \"@/\\#~_-"] +-- >>> escapeIdentList ["schema_1", "schema_2", "SPECIAL \"@/\\#~_-"] -- "\"schema_1\", \"schema_2\", \"SPECIAL \"\"@/\\#~_-\"" -pgFmtIdentList :: [Text] -> SqlFragment -pgFmtIdentList schemas = BS.intercalate ", " $ pgFmtIdent <$> schemas +escapeIdentList :: [Text] -> ByteString +escapeIdentList schemas = BS.intercalate ", " $ escapeIdent <$> schemas -asCsvF :: SqlFragment +asCsvF :: SQL.Snippet asCsvF = asCsvHeaderF <> " || '\n' || " <> asCsvBodyF where asCsvHeaderF = @@ -181,20 +182,20 @@ asCsvF = asCsvHeaderF <> " || '\n' || " <> asCsvBodyF " FROM (" <> " SELECT json_object_keys(r)::text as k" <> " FROM ( " <> - " SELECT row_to_json(hh) as r from " <> sourceCTEName <> " as hh limit 1" <> + " SELECT row_to_json(hh) as r from " <> sourceCTE <> " as hh limit 1" <> " ) s" <> " ) a" <> ")" asCsvBodyF = "coalesce(string_agg(substring(_postgrest_t::text, 2, length(_postgrest_t::text) - 2), '\n'), '')" -asJsonSingleF :: Maybe Routine -> SqlFragment +asJsonSingleF :: Maybe Routine -> SQL.Snippet asJsonSingleF rout | returnsScalar = "coalesce(json_agg(_postgrest_t.pgrst_scalar)->0, 'null')" | otherwise = "coalesce(json_agg(_postgrest_t)->0, 'null')" where returnsScalar = maybe False funcReturnsScalar rout -asJsonF :: Maybe Routine -> SqlFragment +asJsonF :: Maybe Routine -> SQL.Snippet asJsonF rout | returnsSingleComposite = "coalesce(json_agg(_postgrest_t)->0, 'null')" | returnsScalar = "coalesce(json_agg(_postgrest_t.pgrst_scalar)->0, 'null')" @@ -205,16 +206,16 @@ asJsonF rout Just r -> (funcReturnsSingleComposite r, funcReturnsScalar r, funcReturnsSetOfScalar r) Nothing -> (False, False, False) -asXmlF :: FieldName -> SqlFragment +asXmlF :: FieldName -> SQL.Snippet asXmlF fieldName = "coalesce(xmlagg(_postgrest_t." <> pgFmtIdent fieldName <> "), '')" -asGeoJsonF :: SqlFragment +asGeoJsonF :: SQL.Snippet asGeoJsonF = "json_build_object('type', 'FeatureCollection', 'features', coalesce(json_agg(ST_AsGeoJSON(_postgrest_t)::json), '[]'))" -asBinaryF :: FieldName -> SqlFragment +asBinaryF :: FieldName -> SQL.Snippet asBinaryF fieldName = "coalesce(string_agg(_postgrest_t." <> pgFmtIdent fieldName <> ", ''), '')" -locationF :: [Text] -> SqlFragment +locationF :: [Text] -> SQL.Snippet locationF pKeys = [qc|( WITH data AS (SELECT row_to_json(_) AS row FROM {sourceCTEName} AS _ LIMIT 1) SELECT array_agg(json_data.key || '=' || coalesce('eq.' || json_data.value, 'is.null')) @@ -224,33 +225,32 @@ locationF pKeys = [qc|( where fmtPKeys = T.intercalate "','" pKeys -fromQi :: QualifiedIdentifier -> SqlFragment +fromQi :: QualifiedIdentifier -> SQL.Snippet fromQi t = (if T.null s then mempty else pgFmtIdent s <> ".") <> pgFmtIdent n where n = qiName t s = qiSchema t -pgFmtColumn :: QualifiedIdentifier -> Text -> SqlFragment +pgFmtColumn :: QualifiedIdentifier -> Text -> SQL.Snippet pgFmtColumn table "*" = fromQi table <> ".*" pgFmtColumn table c = fromQi table <> "." <> pgFmtIdent c pgFmtField :: QualifiedIdentifier -> Field -> SQL.Snippet -pgFmtField table (c, []) = SQL.sql (pgFmtColumn table c) +pgFmtField table (c, []) = pgFmtColumn table c -- Using to_jsonb instead of to_json to avoid missing operator errors when filtering: -- "operator does not exist: json = unknown" -pgFmtField table (c, jp) = SQL.sql ("to_jsonb(" <> pgFmtColumn table c <> ")") <> pgFmtJsonPath jp +pgFmtField table (c, jp) = "to_jsonb(" <> pgFmtColumn table c <> ")" <> pgFmtJsonPath jp pgFmtSelectItem :: QualifiedIdentifier -> (Field, Maybe Cast, Maybe Alias) -> SQL.Snippet -pgFmtSelectItem table (f@(fName, jp), Nothing, alias) = pgFmtField table f <> SQL.sql (pgFmtAs fName jp alias) +pgFmtSelectItem table (f@(fName, jp), Nothing, alias) = pgFmtField table f <> pgFmtAs fName jp alias -- Ideally we'd quote the cast with "pgFmtIdent cast". However, that would invalidate common casts such as "int", "bigint", etc. -- Try doing: `select 1::"bigint"` - it'll err, using "int8" will work though. There's some parser magic that pg does that's invalidated when quoting. -- Not quoting should be fine, we validate the input on Parsers. -pgFmtSelectItem table (f@(fName, jp), Just cast, alias) = "CAST (" <> pgFmtField table f <> " AS " <> SQL.sql (encodeUtf8 cast) <> " )" <> SQL.sql (pgFmtAs fName jp alias) +pgFmtSelectItem table (f@(fName, jp), Just cast, alias) = "CAST (" <> pgFmtField table f <> " AS " <> SQL.sql (encodeUtf8 cast) <> " )" <> pgFmtAs fName jp alias -- TODO: At this stage there shouldn't be a Maybe since ApiRequest should ensure that an INSERT/UPDATE has a body fromJsonBodyF :: Maybe LBS.ByteString -> [TypedField] -> Bool -> Bool -> Bool -> SQL.Snippet fromJsonBodyF body fields includeSelect includeLimitOne includeDefaults = - SQL.sql (if includeSelect then "SELECT " <> parsedCols <> " " else mempty) <> "FROM (SELECT " <> jsonPlaceHolder <> " AS json_data) pgrst_payload, " <> -- convert a json object into a json array, this way we can use json_to_recordset for all json payloads @@ -266,12 +266,12 @@ fromJsonBodyF body fields includeSelect includeLimitOne includeDefaults = -- because it can't extract records with no columns (there's no valid syntax for the `AS (colName colType,...)` -- part). But we still need to ensure as many rows are created as there are array elements. then SQL.sql $ jsonArrayElementsF <> "(" <> finalBodyF <> ") _ " - else SQL.sql $ jsonToRecordsetF <> "(" <> finalBodyF <> ") AS _(" <> typedCols <> ") " <> if includeLimitOne then "LIMIT 1" else mempty + else jsonToRecordsetF <> "(" <> SQL.sql finalBodyF <> ") AS _(" <> typedCols <> ") " <> if includeLimitOne then "LIMIT 1" else mempty ) <> ") pgrst_body " where - parsedCols = BS.intercalate ", " $ fromQi . QualifiedIdentifier "pgrst_body" . tfName <$> fields - typedCols = BS.intercalate ", " $ pgFmtIdent . tfName <> const " " <> encodeUtf8 . tfIRType <$> fields + parsedCols = intercalateSnippet ", " $ fromQi . QualifiedIdentifier "pgrst_body" . tfName <$> fields + typedCols = intercalateSnippet ", " $ pgFmtIdent . tfName <> const " " <> SQL.sql . encodeUtf8 . tfIRType <$> fields defsJsonb = SQL.sql $ BS.intercalate "," fieldsWDefaults fieldsWDefaults = mapMaybe (\case TypedField{tfName=nam, tfDefault=Just def} -> Just $ encodeUtf8 (pgFmtLit nam <> ", " <> def) @@ -302,12 +302,12 @@ pgFmtOrderTerm qi ot = pgFmtFilter :: QualifiedIdentifier -> Filter -> SQL.Snippet -pgFmtFilter _ (FilterNullEmbed hasNot fld) = SQL.sql (pgFmtIdent fld) <> " IS " <> (if hasNot then "NOT" else mempty) <> " NULL" +pgFmtFilter _ (FilterNullEmbed hasNot fld) = pgFmtIdent fld <> " IS " <> (if hasNot then "NOT" else mempty) <> " NULL" pgFmtFilter _ (Filter _ (NoOpExpr _)) = mempty -- TODO unreachable because NoOpExpr is filtered on QueryParams pgFmtFilter table (Filter fld (OpExpr hasNot oper)) = notOp <> " " <> pgFmtField table fld <> case oper of - Op op val -> " " <> SQL.sql (simpleOperator op) <> " " <> unknownLiteral val + Op op val -> " " <> simpleOperator op <> " " <> unknownLiteral val - OpQuant op quant val -> " " <> SQL.sql (quantOperator op) <> " " <> case op of + OpQuant op quant val -> " " <> quantOperator op <> " " <> case op of OpLike -> fmtQuant quant $ unknownLiteral (T.map star val) OpILike -> fmtQuant quant $ unknownLiteral (T.map star val) _ -> fmtQuant quant $ unknownLiteral val @@ -331,7 +331,7 @@ pgFmtFilter table (Filter fld (OpExpr hasNot oper)) = notOp <> " " <> pgFmtField [""] -> "= ANY('{}') " _ -> "= ANY (" <> unknownLiteral (pgBuildArrayLiteral vals) <> ") " - Fts op lang val -> " " <> SQL.sql (ftsOperator op) <> "(" <> ftsLang lang <> unknownLiteral val <> ") " + Fts op lang val -> " " <> ftsOperator op <> "(" <> ftsLang lang <> unknownLiteral val <> ") " where ftsLang = maybe mempty (\l -> unknownLiteral l <> ", ") notOp = if hasNot then "NOT" else mempty @@ -343,7 +343,7 @@ pgFmtFilter table (Filter fld (OpExpr hasNot oper)) = notOp <> " " <> pgFmtField pgFmtJoinCondition :: JoinCondition -> SQL.Snippet pgFmtJoinCondition (JoinCondition (qi1, col1) (qi2, col2)) = - SQL.sql $ pgFmtColumn qi1 col1 <> " = " <> pgFmtColumn qi2 col2 + pgFmtColumn qi1 col1 <> " = " <> pgFmtColumn qi2 col2 pgFmtLogicTree :: QualifiedIdentifier -> LogicTree -> SQL.Snippet pgFmtLogicTree qi (Expr hasNot op forest) = SQL.sql notOp <> " (" <> intercalateSnippet (opSql op) (pgFmtLogicTree qi <$> forest) <> ")" @@ -363,7 +363,7 @@ pgFmtJsonPath = \case pgFmtJsonOperand (JKey k) = unknownLiteral k pgFmtJsonOperand (JIdx i) = unknownLiteral i <> "::int" -pgFmtAs :: FieldName -> JsonPath -> Maybe Alias -> SqlFragment +pgFmtAs :: FieldName -> JsonPath -> Maybe Alias -> SQL.Snippet pgFmtAs _ [] Nothing = mempty pgFmtAs fName jp Nothing = case jOp <$> lastMay jp of Just (JKey key) -> " AS " <> pgFmtIdent key @@ -375,7 +375,7 @@ pgFmtAs fName jp Nothing = case jOp <$> lastMay jp of Nothing -> mempty pgFmtAs _ _ (Just alias) = " AS " <> pgFmtIdent alias -countF :: SQL.Snippet -> Bool -> (SQL.Snippet, SqlFragment) +countF :: SQL.Snippet -> Bool -> (SQL.Snippet, SQL.Snippet) countF countQuery shouldCount = if shouldCount then ( @@ -385,11 +385,11 @@ countF countQuery shouldCount = mempty , "null::bigint") -returningF :: QualifiedIdentifier -> [FieldName] -> SqlFragment +returningF :: QualifiedIdentifier -> [FieldName] -> SQL.Snippet returningF qi returnings = if null returnings then "RETURNING 1" -- For mutation cases where there's no ?select, we return 1 to know how many rows were modified - else "RETURNING " <> BS.intercalate ", " (pgFmtColumn qi <$> returnings) + else "RETURNING " <> intercalateSnippet ", " (pgFmtColumn qi <$> returnings) limitOffsetF :: NonnegRange -> SQL.Snippet limitOffsetF range = @@ -398,22 +398,22 @@ limitOffsetF range = limit = maybe "ALL" (\l -> unknownEncoder (BS.pack $ show l)) $ rangeLimit range offset = unknownEncoder (BS.pack . show $ rangeOffset range) -responseHeadersF :: SqlFragment +responseHeadersF :: SQL.Snippet responseHeadersF = currentSettingF "response.headers" -responseStatusF :: SqlFragment +responseStatusF :: SQL.Snippet responseStatusF = currentSettingF "response.status" -currentSettingF :: SqlFragment -> SqlFragment +currentSettingF :: SQL.Snippet -> SQL.Snippet currentSettingF setting = -- nullif is used because of https://gist.github.com/steve-chavez/8d7033ea5655096903f3b52f8ed09a15 "nullif(current_setting('" <> setting <> "', true), '')" -mutRangeF :: QualifiedIdentifier -> [FieldName] -> (SqlFragment, SqlFragment) +mutRangeF :: QualifiedIdentifier -> [FieldName] -> (SQL.Snippet, SQL.Snippet) mutRangeF mainQi rangeId = ( - BS.intercalate " AND " $ (\col -> pgFmtColumn mainQi col <> " = " <> pgFmtColumn (QualifiedIdentifier mempty "pgrst_affected_rows") col) <$> rangeId - , BS.intercalate ", " (pgFmtColumn mainQi <$> rangeId) + intercalateSnippet " AND " $ (\col -> pgFmtColumn mainQi col <> " = " <> pgFmtColumn (QualifiedIdentifier mempty "pgrst_affected_rows") col) <$> rangeId + , intercalateSnippet ", " (pgFmtColumn mainQi <$> rangeId) ) orderF :: QualifiedIdentifier -> [OrderTerm] -> SQL.Snippet diff --git a/src/PostgREST/Query/Statements.hs b/src/PostgREST/Query/Statements.hs index 24840f46c7..a1ad4eb0f8 100644 --- a/src/PostgREST/Query/Statements.hs +++ b/src/PostgREST/Query/Statements.hs @@ -61,25 +61,23 @@ prepareWrite selectQuery mutateQuery isInsert mt rep pKeys = SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt where snippet = - "WITH " <> SQL.sql sourceCTEName <> " AS (" <> mutateQuery <> ") " <> - SQL.sql ( + "WITH " <> sourceCTE <> " AS (" <> mutateQuery <> ") " <> "SELECT " <> "'' AS total_result_set, " <> "pg_catalog.count(_postgrest_t) AS page_total, " <> locF <> " AS header, " <> bodyF <> " AS body, " <> responseHeadersF <> " AS response_headers, " <> - responseStatusF <> " AS response_status " - ) <> + responseStatusF <> " AS response_status " <> "FROM (" <> selectF <> ") _postgrest_t" locF = if isInsert && rep == HeadersOnly - then BS.unwords [ - "CASE WHEN pg_catalog.count(_postgrest_t) = 1", - "THEN coalesce(" <> locationF pKeys <> ", " <> noLocationF <> ")", - "ELSE " <> noLocationF, - "END"] + then + "CASE WHEN pg_catalog.count(_postgrest_t) = 1 " <> + "THEN coalesce(" <> locationF pKeys <> ", " <> noLocationF <> ") " <> + "ELSE " <> noLocationF <> " " <> + "END" else noLocationF bodyF @@ -91,7 +89,7 @@ prepareWrite selectQuery mutateQuery isInsert mt rep pKeys = selectF -- prevent using any of the column names in ?select= when no response is returned from the CTE - | rep /= Full = SQL.sql ("SELECT * FROM " <> sourceCTEName) + | rep /= Full = "SELECT * FROM " <> sourceCTE | otherwise = selectQuery decodeIt :: HD.Result ResultSet @@ -104,16 +102,15 @@ prepareRead selectQuery countQuery countTotal mt binaryField = SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt where snippet = - "WITH " <> - SQL.sql sourceCTEName <> " AS ( " <> selectQuery <> " ) " <> + "WITH " <> sourceCTE <> " AS ( " <> selectQuery <> " ) " <> countCTEF <> " " <> - SQL.sql ("SELECT " <> + "SELECT " <> countResultF <> " AS total_result_set, " <> "pg_catalog.count(_postgrest_t) AS page_total, " <> bodyF <> " AS body, " <> responseHeadersF <> " AS response_headers, " <> responseStatusF <> " AS response_status " <> - "FROM ( SELECT * FROM " <> sourceCTEName <> " ) _postgrest_t") + "FROM ( SELECT * FROM " <> sourceCTE <> " ) _postgrest_t" (countCTEF, countResultF) = countF countQuery countTotal @@ -137,15 +134,14 @@ prepareCall rout callProcQuery selectQuery countQuery countTotal mt binaryField SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt where snippet = - "WITH " <> SQL.sql sourceCTEName <> " AS (" <> callProcQuery <> ") " <> + "WITH " <> sourceCTE <> " AS (" <> callProcQuery <> ") " <> countCTEF <> - SQL.sql ( "SELECT " <> countResultF <> " AS total_result_set, " <> "pg_catalog.count(_postgrest_t) AS page_total, " <> bodyF <> " AS body, " <> responseHeadersF <> " AS response_headers, " <> - responseStatusF <> " AS response_status ") <> + responseStatusF <> " AS response_status " <> "FROM (" <> selectQuery <> ") _postgrest_t" (countCTEF, countResultF) = countF countQuery countTotal