Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add request.spec to db-root-spec #1794

Merged
merged 3 commits into from
May 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ library
PostgREST.Config
PostgREST.Config.Database
PostgREST.Config.JSPath
PostgREST.Config.PgVersion
PostgREST.Config.Proxy
PostgREST.ContentType
PostgREST.DbStructure
PostgREST.DbStructure.Identifiers
PostgREST.DbStructure.PgVersion
PostgREST.DbStructure.Proc
PostgREST.DbStructure.Relationship
PostgREST.DbStructure.Table
Expand Down
10 changes: 6 additions & 4 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ import qualified PostgREST.Request.DbRequestBuilder as ReqBuilder
import PostgREST.AppState (AppState)
import PostgREST.Config (AppConfig (..),
LogLevel (..))
import PostgREST.Config.PgVersion (PgVersion (..))
import PostgREST.ContentType (ContentType (..))
import PostgREST.DbStructure (DbStructure (..),
tablePKCols)
import PostgREST.DbStructure.Identifiers (FieldName,
QualifiedIdentifier (..),
Schema)
import PostgREST.DbStructure.PgVersion (PgVersion (..))
import PostgREST.DbStructure.Proc (ProcDescription (..),
ProcVolatility (..))
import PostgREST.DbStructure.Table (Table (..))
Expand Down Expand Up @@ -141,11 +141,12 @@ postgrest logLev appState connWorker =
conf <- AppState.getConfig appState
maybeDbStructure <- AppState.getDbStructure appState
pgVer <- AppState.getPgVersion appState
jsonDbS <- AppState.getJsonDbS appState

let
eitherResponse :: IO (Either Error Wai.Response)
eitherResponse =
runExceptT $ postgrestResponse conf maybeDbStructure pgVer (AppState.getPool appState) time req
runExceptT $ postgrestResponse conf maybeDbStructure jsonDbS pgVer (AppState.getPool appState) time req

response <- either Error.errorResponseFor identity <$> eitherResponse

Expand All @@ -159,12 +160,13 @@ postgrest logLev appState connWorker =
postgrestResponse
:: AppConfig
-> Maybe DbStructure
-> ByteString
-> PgVersion
-> SQL.Pool
-> UTCTime
-> Wai.Request
-> Handler IO Wai.Response
postgrestResponse conf maybeDbStructure pgVer pool time req = do
postgrestResponse conf maybeDbStructure jsonDbS pgVer pool time req = do
body <- lift $ Wai.strictRequestBody req

dbStructure <-
Expand All @@ -187,7 +189,7 @@ postgrestResponse conf maybeDbStructure pgVer pool time req = do

runDbHandler pool (txMode apiRequest) jwtClaims .
Middleware.optionalRollback conf apiRequest $
Middleware.runPgLocals conf jwtClaims handleReq apiRequest
Middleware.runPgLocals conf jwtClaims handleReq apiRequest jsonDbS

runDbHandler :: SQL.Pool -> SQL.Mode -> Auth.JWTClaims -> DbHandler a -> Handler IO a
runDbHandler pool mode jwtClaims handler = do
Expand Down
18 changes: 14 additions & 4 deletions src/PostgREST/AppState.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module PostgREST.AppState
, getConfig
, getDbStructure
, getIsWorkerOn
, getJsonDbS
, getMainThreadId
, getPgVersion
, getPool
Expand All @@ -14,6 +15,7 @@ module PostgREST.AppState
, putConfig
, putDbStructure
, putIsWorkerOn
, putJsonDbS
, putPgVersion
, releasePool
, signalListener
Expand All @@ -28,10 +30,9 @@ import Data.IORef (IORef, atomicWriteIORef, newIORef,
readIORef)
import Data.Time.Clock (UTCTime, getCurrentTime)

import PostgREST.Config (AppConfig (..))
import PostgREST.DbStructure (DbStructure)
import PostgREST.DbStructure.PgVersion (PgVersion (..),
minimumPgVersion)
import PostgREST.Config (AppConfig (..))
import PostgREST.Config.PgVersion (PgVersion (..), minimumPgVersion)
import PostgREST.DbStructure (DbStructure)

import Protolude hiding (toS)
import Protolude.Conv (toS)
Expand All @@ -42,6 +43,8 @@ data AppState = AppState
, statePgVersion :: IORef PgVersion
-- | No schema cache at the start. Will be filled in by the connectionWorker
, stateDbStructure :: IORef (Maybe DbStructure)
-- | Cached DbStructure in json
, stateJsonDbS :: IORef ByteString
-- | Helper ref to make sure just one connectionWorker can run at a time
, stateIsWorkerOn :: IORef Bool
-- | Binary semaphore used to sync the listener(NOTIFY reload) with the connectionWorker.
Expand All @@ -63,6 +66,7 @@ initWithPool newPool conf =
-- assume we're in a supported version when starting, this will be corrected on a later step
<$> newIORef minimumPgVersion
<*> newIORef Nothing
<*> newIORef mempty
<*> newIORef False
<*> newEmptyMVar
<*> newIORef conf
Expand Down Expand Up @@ -92,6 +96,12 @@ putDbStructure :: AppState -> DbStructure -> IO ()
putDbStructure appState structure =
atomicWriteIORef (stateDbStructure appState) $ Just structure

getJsonDbS :: AppState -> IO ByteString
getJsonDbS = readIORef . stateJsonDbS

putJsonDbS :: AppState -> ByteString -> IO ()
putJsonDbS appState = atomicWriteIORef (stateJsonDbS appState)

getIsWorkerOn :: AppState -> IO Bool
getIsWorkerOn = readIORef . stateIsWorkerOn

Expand Down
4 changes: 2 additions & 2 deletions src/PostgREST/CLI.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import Text.Heredoc (str)

import PostgREST.AppState (AppState)
import PostgREST.Config (AppConfig (..))
import PostgREST.DbStructure (getDbStructure)
import PostgREST.DbStructure (queryDbStructure)
import PostgREST.Version (prettyVersion)
import PostgREST.Workers (reReadConfig)

Expand Down Expand Up @@ -56,7 +56,7 @@ dumpSchema appState = do
result <-
P.use (AppState.getPool appState) $
HT.transaction HT.ReadCommitted HT.Read $
getDbStructure
queryDbStructure
(toList configDbSchemas)
configDbExtraSearchPath
configDbPreparedStatements
Expand Down
24 changes: 13 additions & 11 deletions src/PostgREST/Config.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ import Numeric (readOct, showOct)
import System.Environment (getEnvironment)
import System.Posix.Types (FileMode)

import PostgREST.Config.JSPath (JSPath, JSPathExp (..), pRoleClaimKey)
import PostgREST.Config.Proxy (Proxy (..), isMalformedProxyUri,
toURI)
import PostgREST.Config.JSPath (JSPath, JSPathExp (..),
pRoleClaimKey)
import PostgREST.Config.Proxy (Proxy (..),
isMalformedProxyUri, toURI)
import PostgREST.DbStructure.Identifiers (QualifiedIdentifier, toQi)

import Protolude hiding (Proxy, toList, toS)
import Protolude.Conv (toS)
Expand All @@ -68,9 +70,9 @@ data AppConfig = AppConfig
, configDbMaxRows :: Maybe Integer
, configDbPoolSize :: Int
, configDbPoolTimeout :: NominalDiffTime
, configDbPreRequest :: Maybe Text
, configDbPreRequest :: Maybe QualifiedIdentifier
, configDbPreparedStatements :: Bool
, configDbRootSpec :: Maybe Text
, configDbRootSpec :: Maybe QualifiedIdentifier
, configDbSchemas :: NonEmpty Text
, configDbConfig :: Bool
, configDbTxAllowOverride :: Bool
Expand Down Expand Up @@ -113,9 +115,9 @@ toText conf =
,("db-max-rows", maybe "\"\"" show . configDbMaxRows)
,("db-pool", show . configDbPoolSize)
,("db-pool-timeout", show . floor . configDbPoolTimeout)
,("db-pre-request", q . fromMaybe mempty . configDbPreRequest)
,("db-pre-request", q . maybe mempty show . configDbPreRequest)
,("db-prepared-statements", T.toLower . show . configDbPreparedStatements)
,("db-root-spec", q . fromMaybe mempty . configDbRootSpec)
,("db-root-spec", q . maybe mempty show . configDbRootSpec)
,("db-schemas", q . T.intercalate "," . toList . configDbSchemas)
,("db-config", q . T.toLower . show . configDbConfig)
,("db-tx-end", q . showTxEnd)
Expand Down Expand Up @@ -197,11 +199,11 @@ parser optPath env dbSettings =
(optInt "max-rows")
<*> (fromMaybe 10 <$> optInt "db-pool")
<*> (fromIntegral . fromMaybe 10 <$> optInt "db-pool-timeout")
<*> optWithAlias (optString "db-pre-request")
(optString "pre-request")
<*> (fmap toQi <$> optWithAlias (optString "db-pre-request")
(optString "pre-request"))
<*> (fromMaybe True <$> optBool "db-prepared-statements")
<*> optWithAlias (optString "db-root-spec")
(optString "root-spec")
<*> (fmap toQi <$> optWithAlias (optString "db-root-spec")
(optString "root-spec"))
<*> (fromList . splitOnCommas <$> reqWithAlias (optValue "db-schemas")
(optValue "db-schema")
"missing key: either db-schemas or db-schema must be set")
Expand Down
15 changes: 12 additions & 3 deletions src/PostgREST/Config/Database.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
{-# LANGUAGE QuasiQuotes #-}

module PostgREST.Config.Database
( loadDbSettings
( queryDbSettings
, queryPgVersion
) where

import PostgREST.Config.PgVersion (PgVersion (..))

import qualified Hasql.Decoders as HD
import qualified Hasql.Encoders as HE
import qualified Hasql.Pool as P
import qualified Hasql.Session as H
import qualified Hasql.Statement as H
import qualified Hasql.Transaction as HT
import qualified Hasql.Transaction.Sessions as HT
Expand All @@ -16,9 +20,14 @@ import Text.InterpolatedString.Perl6 (q)

import Protolude hiding (hPutStrLn)

queryPgVersion :: H.Session PgVersion
queryPgVersion = H.statement mempty $ H.Statement sql HE.noParams versionRow False
where
sql = "SELECT current_setting('server_version_num')::integer, current_setting('server_version')"
versionRow = HD.singleRow $ PgVersion <$> column HD.int4 <*> column HD.text

loadDbSettings :: P.Pool -> IO [(Text, Text)]
loadDbSettings pool = do
queryDbSettings :: P.Pool -> IO [(Text, Text)]
queryDbSettings pool = do
result <-
P.use pool . HT.transaction HT.ReadCommitted HT.Read $
HT.statement mempty dbSettingsStatement
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
module PostgREST.DbStructure.PgVersion
module PostgREST.Config.PgVersion
( PgVersion(..)
, minimumPgVersion
, pgVersion95
Expand Down
1 change: 0 additions & 1 deletion src/PostgREST/Config/Proxy.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

{-|
Module : PostgREST.Private.ProxyUri
Description : Proxy Uri validator
Expand Down
15 changes: 3 additions & 12 deletions src/PostgREST/DbStructure.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ These queries are executed once at startup or when PostgREST is reloaded.

module PostgREST.DbStructure
( DbStructure(..)
, getDbStructure
, queryDbStructure
, accessibleTables
, accessibleProcs
, schemaDescription
, getPgVersion
, tableCols
, tablePKCols
) where
Expand All @@ -34,7 +33,6 @@ import qualified Data.HashMap.Strict as M
import qualified Data.List as L
import qualified Hasql.Decoders as HD
import qualified Hasql.Encoders as HE
import qualified Hasql.Session as H
import qualified Hasql.Statement as H
import qualified Hasql.Transaction as HT

Expand All @@ -45,7 +43,6 @@ import Text.InterpolatedString.Perl6 (q)

import PostgREST.DbStructure.Identifiers (QualifiedIdentifier (..),
Schema, TableName)
import PostgREST.DbStructure.PgVersion (PgVersion (..))
import PostgREST.DbStructure.Proc (PgArg (..), PgType (..),
ProcDescription (..),
ProcVolatility (..),
Expand Down Expand Up @@ -85,8 +82,8 @@ type ViewColumn = Column
-- | A SQL query that can be executed independently
type SqlQuery = ByteString

getDbStructure :: [Schema] -> [Schema] -> Bool -> HT.Transaction DbStructure
getDbStructure schemas extraSearchPath prepared = do
queryDbStructure :: [Schema] -> [Schema] -> Bool -> HT.Transaction DbStructure
queryDbStructure schemas extraSearchPath prepared = do
HT.sql "set local schema ''" -- This voids the search path. The following queries need this for getting the fully qualified name(schema.name) of every db object
tabs <- HT.statement mempty $ allTables prepared
cols <- HT.statement schemas $ allColumns tabs prepared
Expand Down Expand Up @@ -937,12 +934,6 @@ pfkSourceColumns cols =
join pks_fks using (resorigtbl, resorigcol)
order by view_schema, view_name, view_column_name; |]

getPgVersion :: H.Session PgVersion
getPgVersion = H.statement mempty $ H.Statement sql HE.noParams versionRow False
where
sql = "SELECT current_setting('server_version_num')::integer, current_setting('server_version')"
versionRow = HD.singleRow $ PgVersion <$> column HD.int4 <*> column HD.text

param :: HE.Value a -> HE.Params a
param = HE.param . HE.nonNullable

Expand Down
14 changes: 14 additions & 0 deletions src/PostgREST/DbStructure/Identifiers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ module PostgREST.DbStructure.Identifiers
, Schema
, TableName
, FieldName
, toQi
) where

import qualified Data.Aeson as JSON
import qualified Data.Text as T
import qualified GHC.Show

import Protolude

Expand All @@ -23,6 +26,17 @@ data QualifiedIdentifier = QualifiedIdentifier

instance Hashable QualifiedIdentifier

instance Show QualifiedIdentifier where
show (QualifiedIdentifier s i) =
(if T.null s then mempty else toS s <> ".") <> toS i

-- TODO: Handle a case where the QI comes like this: "my.fav.schema"."my.identifier"
-- Right now it only handles the schema.identifier case
toQi :: Text -> QualifiedIdentifier
toQi txt = case T.drop 1 <$> T.breakOn "." txt of
(i, "") -> QualifiedIdentifier mempty i
(s, i) -> QualifiedIdentifier s i

type Schema = Text
type TableName = Text
type FieldName = Text
34 changes: 18 additions & 16 deletions src/PostgREST/Middleware.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ import System.Log.FastLogger (toLogStr)
import PostgREST.Config (AppConfig (..), LogLevel (..))
import PostgREST.Error (Error, errorResponseFor)
import PostgREST.GucHeader (addHeadersIfNotIncluded)
import PostgREST.Query.SqlFragment (intercalateSnippet,
unknownLiteral)
import PostgREST.Request.ApiRequest (ApiRequest (..))
import PostgREST.Query.SqlFragment (fromQi, intercalateSnippet,
unknownEncoder)
import PostgREST.Request.ApiRequest (ApiRequest (..), Target (..))

import PostgREST.Request.Preferences

Expand All @@ -54,33 +54,35 @@ import Protolude.Conv (toS)
-- | Runs local(transaction scoped) GUCs for every request, plus the pre-request function
runPgLocals :: AppConfig -> M.HashMap Text JSON.Value ->
(ApiRequest -> ExceptT Error H.Transaction Wai.Response) ->
ApiRequest -> ExceptT Error H.Transaction Wai.Response
runPgLocals conf claims app req = do
ApiRequest -> ByteString -> ExceptT Error H.Transaction Wai.Response
runPgLocals conf claims app req jsonDbS = do
lift $ H.statement mempty $ H.dynamicallyParameterized
("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql))
("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql ++ specSql))
HD.noResult (configDbPreparedStatements conf)
lift $ traverse_ H.sql preReqSql
app req
where
methodSql = setConfigLocal mempty ("request.method", toS $ iMethod req)
pathSql = setConfigLocal mempty ("request.path", toS $ iPath req)
methodSql = setConfigLocal mempty ("request.method", iMethod req)
pathSql = setConfigLocal mempty ("request.path", iPath req)
headersSql = setConfigLocal "request.header." <$> iHeaders req
cookiesSql = setConfigLocal "request.cookie." <$> iCookies req
claimsWithRole =
let anon = JSON.String . toS $ configDbAnonRole conf in -- role claim defaults to anon if not specified in jwt
M.union claims (M.singleton "role" anon)
claimsSql = setConfigLocal "request.jwt.claim." <$> [(c,unquoted v) | (c,v) <- M.toList claimsWithRole]
roleSql = maybeToList $ (\x -> setConfigLocal mempty ("role", unquoted x)) <$> M.lookup "role" claimsWithRole
appSettingsSql = setConfigLocal mempty <$> configAppSettings conf
claimsSql = setConfigLocal "request.jwt.claim." <$> [(toS c, toS $ unquoted v) | (c,v) <- M.toList claimsWithRole]
roleSql = maybeToList $ (\x -> setConfigLocal mempty ("role", toS $ unquoted x)) <$> M.lookup "role" claimsWithRole
appSettingsSql = setConfigLocal mempty <$> (join bimap toS <$> configAppSettings conf)
searchPathSql =
let schemas = T.intercalate ", " (iSchema req : configDbExtraSearchPath conf) in
setConfigLocal mempty ("search_path", schemas)
preReqSql = (\f -> "select " <> toS f <> "();") <$> configDbPreRequest conf

setConfigLocal mempty ("search_path", toS schemas)
preReqSql = (\f -> "select " <> fromQi f <> "();") <$> configDbPreRequest conf
specSql = case iTarget req of
TargetProc{tpIsRootSpec=True} -> [setConfigLocal mempty ("request.spec", jsonDbS)]
_ -> mempty
-- | Do a pg set_config(setting, value, true) call. This is equivalent to a SET LOCAL.
setConfigLocal :: Text -> (Text, Text) -> H.Snippet
setConfigLocal :: ByteString -> (ByteString, ByteString) -> H.Snippet
setConfigLocal prefix (k, v) =
"set_config(" <> unknownLiteral (prefix <> k) <> ", " <> unknownLiteral v <> ", true)"
"set_config(" <> unknownEncoder (prefix <> k) <> ", " <> unknownEncoder v <> ", true)"

-- | Log in apache format. Only requests that have a status greater than minStatus are logged.
-- | There's no way to filter logs in the apache format on wai-extra: https://hackage.haskell.org/package/wai-extra-3.0.29.2/docs/Network-Wai-Middleware-RequestLogger.html#t:OutputFormat.
Expand Down
Loading