diff --git a/CHANGELOG.md b/CHANGELOG.md index b54256e552a..87eb986d946 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Added - #2887, Add Preference `max-affected` to limit affected resources - @taimoorzaeem + - #3061, Apply all function settings as transaction-scoped settings - @taimoorzaeem ### Fixed diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs index 3c3b187707e..98805fb7b57 100644 --- a/src/PostgREST/App.hs +++ b/src/PostgREST/App.hs @@ -43,19 +43,21 @@ import qualified PostgREST.Query as Query import qualified PostgREST.Response as Response import qualified PostgREST.Unix as Unix (installSignalHandlers) -import PostgREST.ApiRequest (Action (..), ApiRequest (..), - Mutation (..), Target (..)) -import PostgREST.AppState (AppState) -import PostgREST.Auth (AuthResult (..)) -import PostgREST.Config (AppConfig (..)) -import PostgREST.Config.PgVersion (PgVersion (..)) -import PostgREST.Error (Error) -import PostgREST.Query (DbHandler) -import PostgREST.Response.Performance (ServerTiming (..), - serverTimingHeader) -import PostgREST.SchemaCache (SchemaCache (..)) -import PostgREST.SchemaCache.Routine (Routine (..)) -import PostgREST.Version (docsVersion, prettyVersion) +import PostgREST.ApiRequest (Action (..), + ApiRequest (..), + Mutation (..), Target (..)) +import PostgREST.AppState (AppState) +import PostgREST.Auth (AuthResult (..)) +import PostgREST.Config (AppConfig (..)) +import PostgREST.Config.PgVersion (PgVersion (..)) +import PostgREST.Error (Error) +import PostgREST.Query (DbHandler) +import PostgREST.Response.Performance (ServerTiming (..), + serverTimingHeader) +import PostgREST.SchemaCache (SchemaCache (..)) +import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier (..)) +import PostgREST.SchemaCache.Routine (Routine (..)) +import PostgREST.Version (docsVersion, prettyVersion) import qualified Data.ByteString.Char8 as BS import qualified Data.List as L @@ -170,43 +172,44 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A case (iAction, iTarget) of (ActionRead headersOnly, TargetIdent identifier) -> do (planTime', wrPlan) <- withTiming $ liftEither $ Plan.wrappedReadPlan identifier conf sCache apiReq - (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq + (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.wrTxMode wrPlan) mempty $ Query.readQuery wrPlan conf apiReq (respTime', pgrst) <- withTiming $ liftEither $ Response.readResponse wrPlan headersOnly identifier apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionMutate MutationCreate, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationCreate apiReq identifier conf sCache - (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf + (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.createQuery mrPlan apiReq conf (respTime', pgrst) <- withTiming $ liftEither $ Response.createResponse identifier mrPlan apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionMutate MutationUpdate, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationUpdate apiReq identifier conf sCache - (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf + (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.updateQuery mrPlan apiReq conf (respTime', pgrst) <- withTiming $ liftEither $ Response.updateResponse mrPlan apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionMutate MutationSingleUpsert, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationSingleUpsert apiReq identifier conf sCache - (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf + (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.singleUpsertQuery mrPlan apiReq conf (respTime', pgrst) <- withTiming $ liftEither $ Response.singleUpsertResponse mrPlan apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionMutate MutationDelete, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationDelete apiReq identifier conf sCache - (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf + (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.deleteQuery mrPlan apiReq conf (respTime', pgrst) <- withTiming $ liftEither $ Response.deleteResponse mrPlan apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst - (ActionInvoke invMethod, TargetProc identifier _) -> do + (ActionInvoke invMethod, TargetProc identifier@(QualifiedIdentifier _ proname) _) -> do + let setting = [(y,z) | (x,y,z) <- funcSettings, x == encodeUtf8 proname] (planTime', cPlan) <- withTiming $ liftEither $ Plan.callReadPlan identifier conf sCache apiReq invMethod - (txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdTimeout $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer + (txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (Plan.crTxMode cPlan) setting $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer (respTime', pgrst) <- withTiming $ liftEither $ Response.invokeResponse cPlan invMethod (Plan.crProc cPlan) apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionInspect headersOnly, TargetDefaultSpec tSchema) -> do (planTime', iPlan) <- withTiming $ liftEither $ Plan.inspectPlan apiReq - (txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema + (txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl (Plan.ipTxmode iPlan) mempty $ Query.openApiQuery sCache pgVer conf tSchema (respTime', pgrst) <- withTiming $ liftEither $ Response.openApiResponse (T.decodeUtf8 prettyVersion, docsVersion) headersOnly oaiResult conf sCache iSchema iNegotiatedByProfile return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst @@ -230,9 +233,10 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A where roleSettings = fromMaybe mempty (HM.lookup authRole $ configRoleSettings conf) roleIsoLvl = HM.findWithDefault SQL.ReadCommitted authRole $ configRoleIsoLvl conf - runQuery isoLvl timeout mode query = + funcSettings = dbFuncSettings sCache + runQuery isoLvl mode funcSet query = runDbHandler appState conf isoLvl mode authenticated prepared $ do - Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq timeout + Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) funcSet apiReq Query.runPreReq conf query diff --git a/src/PostgREST/Config/Database.hs b/src/PostgREST/Config/Database.hs index 8180975fe71..5d5c59fe72c 100644 --- a/src/PostgREST/Config/Database.hs +++ b/src/PostgREST/Config/Database.hs @@ -8,6 +8,7 @@ module PostgREST.Config.Database , RoleSettings , RoleIsolationLvl , TimezoneNames + , FuncSettings , toIsolationLevel ) where @@ -31,6 +32,7 @@ import Protolude type RoleSettings = (HM.HashMap ByteString (HM.HashMap ByteString ByteString)) type RoleIsolationLvl = HM.HashMap ByteString SQL.IsolationLevel type TimezoneNames = Set ByteString -- cache timezone names for prefer timezone= +type FuncSettings = [(ByteString,ByteString,ByteString)] toIsolationLevel :: (Eq a, IsString a) => a -> SQL.IsolationLevel toIsolationLevel a = case a of diff --git a/src/PostgREST/Query.hs b/src/PostgREST/Query.hs index 7242dce1407..c47164b4fc2 100644 --- a/src/PostgREST/Query.hs +++ b/src/PostgREST/Query.hs @@ -247,12 +247,12 @@ optionalRollback AppConfig{..} ApiRequest{iPreferences=Preferences{..}} = do -- | Set transaction scoped settings setPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> BS.ByteString -> [(ByteString, ByteString)] -> - ApiRequest -> Maybe Text -> DbHandler () -setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $ + [(ByteString,ByteString)] -> ApiRequest -> DbHandler () +setPgLocals AppConfig{..} claims role roleSettings funcSetting ApiRequest{..} = lift $ SQL.statement mempty $ SQL.dynamicallyParameterized -- To ensure `GRANT SET ON PARAMETER TO authenticator` works, the role settings must be set before the impersonated role. -- Otherwise the GRANT SET would have to be applied to the impersonated role. See https://github.com/PostgREST/postgrest/issues/3045 - ("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ timeoutSql ++ appSettingsSql)) + ("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ funcSettingSql ++ appSettingsSql)) HD.noResult configDbPreparedStatements where methodSql = setConfigWithConstantName ("request.method", iMethod) @@ -264,7 +264,7 @@ setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $ roleSettingsSql = setConfigWithDynamicName <$> roleSettings appSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> configAppSettings) timezoneSql = maybe mempty (\(PreferTimezone tz) -> [setConfigWithConstantName ("timezone", tz)]) $ preferTimezone iPreferences - timeoutSql = maybe mempty ((\t -> [setConfigWithConstantName ("statement_timeout", t)]) . encodeUtf8) tout + funcSettingSql = setConfigWithDynamicName <$> funcSetting searchPathSql = let schemas = escapeIdentList (iSchema : configDbExtraSearchPath) in setConfigWithConstantName ("search_path", schemas) diff --git a/src/PostgREST/SchemaCache.hs b/src/PostgREST/SchemaCache.hs index 65d0416d37c..2b1e91a3cfa 100644 --- a/src/PostgREST/SchemaCache.hs +++ b/src/PostgREST/SchemaCache.hs @@ -43,12 +43,14 @@ import Contravariant.Extras (contrazip2) import Text.InterpolatedString.Perl6 (q) import PostgREST.Config (AppConfig (..)) -import PostgREST.Config.Database (TimezoneNames, +import PostgREST.Config.Database (FuncSettings, + TimezoneNames, pgVersionStatement, toIsolationLevel) import PostgREST.Config.PgVersion (PgVersion, pgVersion100, pgVersion110, - pgVersion120) + pgVersion120, + pgVersion150) import PostgREST.SchemaCache.Identifiers (AccessSet, FieldName, QualifiedIdentifier (..), RelIdentifier (..), @@ -74,7 +76,6 @@ import qualified PostgREST.MediaType as MediaType import Protolude - data SchemaCache = SchemaCache { dbTables :: TablesMap , dbRelationships :: RelationshipsMap @@ -82,16 +83,18 @@ data SchemaCache = SchemaCache , dbRepresentations :: RepresentationsMap , dbMediaHandlers :: MediaHandlerMap , dbTimezones :: TimezoneNames + , dbFuncSettings :: FuncSettings } instance JSON.ToJSON SchemaCache where - toJSON (SchemaCache tabs rels routs reps _ _) = JSON.object [ + toJSON (SchemaCache tabs rels routs reps _ _ _) = JSON.object [ "dbTables" .= JSON.toJSON tabs , "dbRelationships" .= JSON.toJSON rels , "dbRoutines" .= JSON.toJSON routs , "dbRepresentations" .= JSON.toJSON reps , "dbMediaHandlers" .= JSON.emptyArray , "dbTimezones" .= JSON.emptyArray + , "dbFuncSettings" .= JSON.emptyArray ] -- | A view foreign key or primary key dependency detected on its source table @@ -145,6 +148,7 @@ querySchemaCache AppConfig{..} = do reps <- SQL.statement schemas $ dataRepresentations prepared mHdlers <- SQL.statement schemas $ mediaHandlers pgVer prepared tzones <- SQL.statement mempty $ timezones prepared + funSets <- SQL.statement mempty $ funcSettings pgVer prepared _ <- let sleepCall = SQL.Statement "select pg_sleep($1)" (param HE.int4) HD.noResult prepared in whenJust configInternalSCSleep (`SQL.statement` sleepCall) -- only used for testing @@ -159,6 +163,7 @@ querySchemaCache AppConfig{..} = do , dbRepresentations = reps , dbMediaHandlers = HM.union mHdlers initialMediaHandlers -- the custom handlers will override the initial ones , dbTimezones = tzones + , dbFuncSettings = funSets } where schemas = toList configDbSchemas @@ -195,6 +200,7 @@ removeInternal schemas dbStruct = , dbRepresentations = dbRepresentations dbStruct -- no need to filter, not directly exposed through the API , dbMediaHandlers = dbMediaHandlers dbStruct , dbTimezones = dbTimezones dbStruct + , dbFuncSettings = dbFuncSettings dbStruct } where hasInternalJunction ComputedRelationship{} = False @@ -297,7 +303,6 @@ decodeFuncs = <*> (parseVolatility <$> column HD.char) <*> column HD.bool <*> nullableColumn (toIsolationLevel <$> HD.text) - <*> nullableColumn HD.text addKey :: Routine -> (QualifiedIdentifier, Routine) addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd) @@ -431,8 +436,7 @@ funcsSqlQuery pgVer = [q| bt.oid <> bt.base as rettype_is_composite_alias, p.provolatile, p.provariadic > 0 as hasvariadic, - lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level, - lower((regexp_split_to_array((regexp_split_to_array(timeout_config, '='))[2], ','))[1]) AS statement_timeout + lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level FROM pg_proc p LEFT JOIN arguments a ON a.oid = p.oid JOIN pg_namespace pn ON pn.oid = p.pronamespace @@ -442,7 +446,6 @@ funcsSqlQuery pgVer = [q| LEFT JOIN pg_class comp ON comp.oid = t.typrelid LEFT JOIN pg_description as d ON d.objoid = p.oid LEFT JOIN LATERAL unnest(proconfig) iso_config ON iso_config like 'default_transaction_isolation%' - LEFT JOIN LATERAL unnest(proconfig) timeout_config ON timeout_config like 'statement_timeout%' WHERE t.oid <> 'trigger'::regtype AND COALESCE(a.callable, true) |] <> (if pgVer >= pgVersion110 then "AND prokind = 'f'" else "AND NOT (proisagg OR proiswindow)") @@ -1203,6 +1206,34 @@ timezones = SQL.Statement sql HE.noParams decodeTimezones decodeTimezones :: HD.Result TimezoneNames decodeTimezones = S.fromList . map encodeUtf8 <$> HD.rowList (column HD.text) +funcSettings :: PgVersion -> Bool -> SQL.Statement () FuncSettings +funcSettings pgVer = SQL.Statement sql HE.noParams rows + where + sql = [q| + WITH + func_setting AS ( + SELECT p.proname, unnest(p.proconfig) AS setting + FROM pg_proc p + ), + kv_settings AS ( + SELECT + proname, + substr(setting, 1, strpos(setting, '=') - 1) as key, + lower(substr(setting, strpos(setting, '=') + 1)) as value + FROM func_setting + ) + SELECT + proname, kv.key AS key, kv.value AS value + FROM kv_settings kv + JOIN pg_settings ps ON ps.name = kv.key |] <> + (if pgVer >= pgVersion150 + then "and (ps.context = 'user' or has_parameter_privilege(current_user::regrole::oid, ps.name, 'set'));" + else "and ps.context = 'user';") + + rows :: HD.Result FuncSettings + rows = HD.rowList $ (,,) <$> (encodeUtf8 <$> column HD.text) <*> (encodeUtf8 <$> column HD.text) <*> (encodeUtf8 <$> column HD.text) + + param :: HE.Value a -> HE.Params a param = HE.param . HE.nonNullable diff --git a/src/PostgREST/SchemaCache/Routine.hs b/src/PostgREST/SchemaCache/Routine.hs index e84d722b9a3..a1b91ab9224 100644 --- a/src/PostgREST/SchemaCache/Routine.hs +++ b/src/PostgREST/SchemaCache/Routine.hs @@ -58,12 +58,11 @@ data Routine = Function , pdVolatility :: FuncVolatility , pdHasVariadic :: Bool , pdIsoLvl :: Maybe SQL.IsolationLevel - , pdTimeout :: Maybe Text } deriving (Eq, Show, Generic) -- need to define JSON manually bc SQL.IsolationLevel doesn't have a JSON instance(and we can't define one for that type without getting a compiler error) instance JSON.ToJSON Routine where - toJSON (Function sch nam desc params ret vol hasVar _ tout) = JSON.object + toJSON (Function sch nam desc params ret vol hasVar _) = JSON.object [ "pdSchema" .= sch , "pdName" .= nam @@ -72,7 +71,6 @@ instance JSON.ToJSON Routine where , "pdReturnType" .= JSON.toJSON ret , "pdVolatility" .= JSON.toJSON vol , "pdHasVariadic" .= JSON.toJSON hasVar - , "pdTimeout" .= tout ] data RoutineParam = RoutineParam @@ -86,10 +84,10 @@ data RoutineParam = RoutineParam -- Order by least number of params in the case of overloaded functions instance Ord Routine where - Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 tout1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 tout2 + Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 | schema1 == schema2 && name1 == name2 && length prms1 < length prms2 = LT | schema2 == schema2 && name1 == name2 && length prms1 > length prms2 = GT - | otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, tout1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, tout2) + | otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2) -- | A map of all procs, all of which can be overloaded(one entry will have more than one Routine). -- | It uses a HashMap for a faster lookup. diff --git a/test/io/fixtures.sql b/test/io/fixtures.sql index bd4593ab80c..f3a392dc17f 100644 --- a/test/io/fixtures.sql +++ b/test/io/fixtures.sql @@ -198,3 +198,9 @@ $$ language sql set statement_timeout = '4s'; create function get_postgres_version() returns int as $$ select current_setting('server_version_num')::int; $$ language sql; + +GRANT SET ON PARAMETER log_min_duration_sample TO postgrest_test_anonymous; + +create or replace function log_min_duration_test() returns text as $$ + select current_setting('log_min_duration_sample',false); +$$ language sql set log_min_duration_sample = '5s'; diff --git a/test/io/test_io.py b/test/io/test_io.py index e09a4720cc1..295a532c41e 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1327,28 +1327,6 @@ def test_no_preflight_request_with_CORS_config_should_not_return_header(defaulte assert "Access-Control-Allow-Origin" not in response.headers -def test_fail_with_3_sec_statement_and_1_sec_statement_timeout(defaultenv): - "statement that takes three seconds to execute should fail with one second timeout" - - with run(env=defaultenv) as postgrest: - response = postgrest.session.post("/rpc/one_sec_timeout") - - assert response.status_code == 500 - assert ( - response.text - == '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}' - ) - - -def test_passes_with_3_sec_statement_and_4_sec_statement_timeout(defaultenv): - "statement that takes three seconds to execute should succeed with four second timeout" - - with run(env=defaultenv) as postgrest: - response = postgrest.session.post("/rpc/four_sec_timeout") - - assert response.status_code == 204 - - @pytest.mark.parametrize("level", ["crit", "error", "warn", "info"]) def test_db_error_logging_to_stderr(level, defaultenv, metapostgrest): "verify that DB errors are logged to stderr" @@ -1375,3 +1353,34 @@ def test_db_error_logging_to_stderr(level, defaultenv, metapostgrest): else: assert " 500 " in output[0] assert "canceling statement due to statement timeout" in output[1] + + +def test_function_setting_statement_timeout_fails(defaultenv): + "statement that takes three seconds to execute should fail with one second timeout" + + with run(env=defaultenv) as postgrest: + response = postgrest.session.post("/rpc/one_sec_timeout") + + assert response.status_code == 500 + assert ( + response.text + == '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}' + ) + + +def test_function_setting_statement_timeout_passes(defaultenv): + "statement that takes three seconds to execute should succeed with four second timeout" + + with run(env=defaultenv) as postgrest: + response = postgrest.session.post("/rpc/four_sec_timeout") + + assert response.status_code == 204 + + +def test_function_setting_log_min_duration_sample(defaultenv): + "check function setting log_min_duration_sample is applied" + + with run(env=defaultenv) as postgrest: + response = postgrest.session.post("/rpc/log_min_duration_test") + + assert response.text == '"5s"'