From 2e6e4f68b435ec162b49e898dc911e610c6fb213 Mon Sep 17 00:00:00 2001 From: Remo Rechkemmer <59358383+monacoremo@users.noreply.github.com> Date: Tue, 23 Feb 2021 22:41:48 +0100 Subject: [PATCH] refactor: App.hs and related changes (#1725) * Use ExceptT to avoid 'staircasing' case analysis in App.hs * Split large function in App.hs into individual handler functions * Adapt API of Auth.hs, OpenApi.hs etc. to simplify the use of those modules in App.hs * Split optional rollback functionality into Middleware * Unify SimpleError and ApiRequestError into one Error type, so it can be used across modules --- postgrest.cabal | 1 + src/PostgREST/App.hs | 953 +++++++++++++++++------------- src/PostgREST/Auth.hs | 164 ++--- src/PostgREST/Config.hs | 55 +- src/PostgREST/DbRequestBuilder.hs | 12 +- src/PostgREST/Error.hs | 24 +- src/PostgREST/Middleware.hs | 54 +- src/PostgREST/OpenAPI.hs | 61 +- src/PostgREST/Statements.hs | 8 +- test/SpecHelper.hs | 3 +- test/io-tests/test_io.py | 2 + 11 files changed, 752 insertions(+), 585 deletions(-) diff --git a/postgrest.cabal b/postgrest.cabal index a6768845b77..752f9483eb4 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -81,6 +81,7 @@ library , jose >= 0.8.1 && < 0.9 , lens >= 4.14 && < 4.20 , lens-aeson >= 1.0.1 && < 1.2 + , mtl >= 2.2.2 && < 2.3 , network-uri >= 2.6.1 && < 2.8 , optparse-applicative >= 0.13 && < 0.17 , parsec >= 3.1.11 && < 3.2 diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs index adb26b7c32b..1ab873bd883 100644 --- a/src/PostgREST/App.hs +++ b/src/PostgREST/App.hs @@ -9,430 +9,547 @@ Some of its functionality includes: - Producing HTTP Headers according to RFCs. - Content Negotiation -} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE ScopedTypeVariables #-} - -module PostgREST.App ( - postgrest -) where - -import qualified Data.ByteString.Char8 as BS -import qualified Data.HashMap.Strict as M -import qualified Data.List as L (union) -import qualified Data.Set as S -import qualified Hasql.Pool as P -import qualified Hasql.Transaction as H -import qualified Hasql.Transaction as HT -import qualified Hasql.Transaction.Sessions as HT - -import Data.IORef (IORef, readIORef) -import Data.Time.Clock (UTCTime) -import Network.HTTP.Types.URI (renderSimpleQuery) - -import Control.Applicative -import Data.Maybe -import Network.HTTP.Types.Header -import Network.HTTP.Types.Status -import Network.Wai - -import PostgREST.ApiRequest (Action (..), ApiRequest (..), - InvokeMethod (..), Target (..), - mutuallyAgreeable, userApiRequest) -import PostgREST.Auth (attemptJwtClaims, containsRole, - jwtClaims) -import PostgREST.Config (AppConfig (..)) -import PostgREST.DbRequestBuilder (mutateRequest, readRequest, - returningCols) -import PostgREST.DbStructure -import PostgREST.Error (PgError (..), SimpleError (..), - errorResponseFor, singularityError) -import PostgREST.Middleware -import PostgREST.OpenAPI -import PostgREST.QueryBuilder (limitedQuery, mutateRequestToQuery, - readRequestToCountQuery, - readRequestToQuery, - requestToCallProcQuery) -import PostgREST.RangeQuery (allRange, contentRangeH, - rangeStatusHeader) -import PostgREST.Statements (callProcStatement, - createExplainStatement, - createReadStatement, - createWriteStatement) +{-# LANGUAGE RecordWildCards #-} +module PostgREST.App (postgrest) where + +import Control.Monad.Except (liftEither) +import Data.Either.Combinators (mapLeft) +import Data.IORef (IORef, readIORef) +import Data.List (union) +import Data.Time.Clock (UTCTime) + +import qualified Data.ByteString.Char8 as BS8 +import qualified Data.ByteString.Lazy as LBS +import qualified Data.Set as Set +import qualified Hasql.DynamicStatements.Snippet as SQL +import qualified Hasql.Pool as SQL +import qualified Hasql.Transaction as SQL +import qualified Hasql.Transaction.Sessions as SQL +import qualified Network.HTTP.Types.Header as HTTP +import qualified Network.HTTP.Types.Status as HTTP +import qualified Network.HTTP.Types.URI as HTTP +import qualified Network.Wai as Wai + +import qualified PostgREST.ApiRequest as ApiRequest +import qualified PostgREST.Auth as Auth +import qualified PostgREST.DbRequestBuilder as ReqBuilder +import qualified PostgREST.DbStructure as DbStructure +import qualified PostgREST.Error as Error +import qualified PostgREST.Middleware as Middleware +import qualified PostgREST.OpenAPI as OpenAPI +import qualified PostgREST.QueryBuilder as QueryBuilder +import qualified PostgREST.RangeQuery as RangeQuery +import qualified PostgREST.Statements as Statements + +import PostgREST.ApiRequest (Action (..), ApiRequest (..), + InvokeMethod (..), Target (..)) +import PostgREST.Config (AppConfig (..)) +import PostgREST.Error (Error) + import PostgREST.Types -import Protolude hiding (Proxy, intercalate, toS) -import Protolude.Conv (toS) -postgrest :: LogLevel -> IORef AppConfig -> IORef (Maybe DbStructure) -> P.Pool -> IO UTCTime -> IO () -> Application +import Protolude hiding (Handler, toS) +import Protolude.Conv (toS) + + +data RequestContext = RequestContext + { ctxConfig :: AppConfig + , ctxDbStructure :: DbStructure + , ctxApiRequest :: ApiRequest + , ctxContentType :: ContentType + } + +type Handler = ExceptT Error + +type DbHandler = Handler SQL.Transaction + + +-- | PostgREST application +postgrest + :: LogLevel + -> IORef AppConfig + -> IORef (Maybe DbStructure) + -> SQL.Pool + -> IO UTCTime + -> IO () -- ^ Lauch connection worker in a separate thread + -> Wai.Application postgrest logLev refConf refDbStructure pool getTime connWorker = - pgrstMiddleware logLev $ \ req respond -> do - time <- getTime - body <- strictRequestBody req - maybeDbStructure <- readIORef refDbStructure - conf <- readIORef refConf + Middleware.pgrstMiddleware logLev $ + \req respond -> do + time <- getTime + conf <- readIORef refConf + maybeDbStructure <- readIORef refDbStructure + + let + eitherResponse :: IO (Either Error Wai.Response) + eitherResponse = + runExceptT $ postgrestResponse conf maybeDbStructure pool time req + + response <- either Error.errorResponseFor identity <$> eitherResponse + + -- Launch the connWorker when the connection is down. The postgrest + -- function can respond successfully (with a stale schema cache) before + -- the connWorker is done. + when (Wai.responseStatus response == HTTP.status503) connWorker + + respond response + +postgrestResponse + :: AppConfig + -> Maybe DbStructure + -> SQL.Pool + -> UTCTime + -> Wai.Request + -> Handler IO Wai.Response +postgrestResponse conf@AppConfig{..} maybeDbStructure pool time req = do + body <- lift $ Wai.strictRequestBody req + + dbStructure <- case maybeDbStructure of - Nothing -> respond . errorResponseFor $ ConnectionLostError - Just dbStructure -> do - response <- do - let apiReq = userApiRequest (configDbSchemas conf) (configDbRootSpec conf) dbStructure req body - case apiReq of - Left err -> return . errorResponseFor $ err - Right apiRequest -> do - -- The jwt must be checked before touching the db. - attempt <- attemptJwtClaims (configJWKS conf) (configJwtAudience conf) (toS $ iJWT apiRequest) time (configJwtRoleClaimKey conf) - case jwtClaims attempt of - Left errJwt -> return . errorResponseFor $ errJwt - Right claims -> do - let - authed = containsRole claims - shouldCommit = configDbTxAllowOverride conf && iPreferTransaction apiRequest == Just Commit - shouldRollback = configDbTxAllowOverride conf && iPreferTransaction apiRequest == Just Rollback - preferenceApplied - | shouldCommit = addHeadersIfNotIncluded [(hPreferenceApplied, BS.pack (show Commit))] - | shouldRollback = addHeadersIfNotIncluded [(hPreferenceApplied, BS.pack (show Rollback))] - | otherwise = identity - handleReq = do - when (shouldRollback || (configDbTxRollbackAll conf && not shouldCommit)) HT.condemn - mapResponseHeaders preferenceApplied <$> runPgLocals conf claims (app dbStructure conf) apiRequest - dbResp <- P.use pool $ HT.transaction HT.ReadCommitted (txMode apiRequest) handleReq - return $ either (errorResponseFor . PgError authed) identity dbResp - -- Launch the connWorker when the connection is down. The postgrest function can respond successfully(with a stale schema cache) before the connWorker is done. - when (responseStatus response == status503) connWorker - respond response - -txMode :: ApiRequest -> HT.Mode -txMode apiRequest = - case (iAction apiRequest, iTarget apiRequest) of - (ActionRead _ , _) -> HT.Read - (ActionInfo , _) -> HT.Read - (ActionInspect _ , _) -> HT.Read - (ActionInvoke InvGet , _) -> HT.Read - (ActionInvoke InvHead, _) -> HT.Read - (ActionInvoke InvPost, TargetProc ProcDescription{pdVolatility=Stable} _) -> HT.Read - (ActionInvoke InvPost, TargetProc ProcDescription{pdVolatility=Immutable} _) -> HT.Read - _ -> HT.Write - -app :: DbStructure -> AppConfig -> ApiRequest -> H.Transaction Response -app dbStructure conf apiRequest = - let rawContentTypes = (decodeContentType <$> configRawMediaTypes conf) `L.union` [ CTOctetStream, CTTextPlain ] in - case responseContentTypeOrError (iAccepts apiRequest) rawContentTypes (iAction apiRequest) (iTarget apiRequest) of - Left errorResponse -> return errorResponse - Right contentType -> - case (iAction apiRequest, iTarget apiRequest) of - - (ActionRead headersOnly, TargetIdent (QualifiedIdentifier tSchema tName)) -> - case readSqlParts tSchema tName of - Left errorResponse -> return errorResponse - Right (q, cq, bField, _) -> do - let cQuery = if estimatedCount - then limitedQuery cq ((+ 1) <$> maxRows) -- LIMIT maxRows + 1 so we can determine below that maxRows was surpassed - else cq - stm = createReadStatement q cQuery (contentType == CTSingularJSON) shouldCount - (contentType == CTTextCSV) bField pgVer prepared - explStm = createExplainStatement cq prepared - row <- H.statement mempty stm - let (tableTotal, queryTotal, _ , body, gucHeaders, gucStatus) = row - gucs = (,) <$> gucHeaders <*> gucStatus - case gucs of - Left err -> return $ errorResponseFor err - Right (ghdrs, gstatus) -> do - total <- if | plannedCount -> H.statement mempty explStm - | estimatedCount -> if tableTotal > (fromIntegral <$> maxRows) - then do estTotal <- H.statement mempty explStm - pure $ if estTotal > tableTotal then estTotal else tableTotal - else pure tableTotal - | otherwise -> pure tableTotal - let (rangeStatus, contentRange) = rangeStatusHeader topLevelRange queryTotal total - status = fromMaybe rangeStatus gstatus - headers = addHeadersIfNotIncluded (catMaybes [ - Just $ toHeader contentType, Just contentRange, - Just $ contentLocationH tName (iCanonicalQS apiRequest), profileH]) - (unwrapGucHeader <$> ghdrs) - rBody = if headersOnly then mempty else toS body - return $ - if contentType == CTSingularJSON && queryTotal /= 1 - then errorResponseFor . singularityError $ queryTotal - else responseLBS status headers rBody - - (ActionCreate, TargetIdent (QualifiedIdentifier tSchema tName)) -> - case mutateSqlParts tSchema tName of - Left errorResponse -> return errorResponse - Right (sq, mq) -> do - let pkCols = tablePKCols dbStructure tSchema tName - stm = createWriteStatement sq mq - (contentType == CTSingularJSON) True - (contentType == CTTextCSV) (iPreferRepresentation apiRequest) pkCols pgVer prepared - row <- H.statement mempty stm - let (_, queryTotal, fields, body, gucHeaders, gucStatus) = row - gucs = (,) <$> gucHeaders <*> gucStatus - case gucs of - Left err -> return $ errorResponseFor err - Right (ghdrs, gstatus) -> do - let - (ctHeaders, rBody) = if iPreferRepresentation apiRequest == Full - then ([Just $ toHeader contentType, profileH], toS body) - else ([], mempty) - status = fromMaybe status201 gstatus - headers = addHeadersIfNotIncluded (catMaybes ([ - if null fields - then Nothing - else Just $ locationH tName fields - , Just $ contentRangeH 1 0 $ if shouldCount then Just queryTotal else Nothing - , if null pkCols && isNothing (iOnConflict apiRequest) - then Nothing - else (\x -> ("Preference-Applied", BS.pack (show x))) <$> iPreferResolution apiRequest - ] ++ ctHeaders)) (unwrapGucHeader <$> ghdrs) - if contentType == CTSingularJSON && queryTotal /= 1 - then do - HT.condemn - return . errorResponseFor . singularityError $ queryTotal - else - return $ responseLBS status headers rBody - - (ActionUpdate, TargetIdent (QualifiedIdentifier tSchema tName)) -> - case mutateSqlParts tSchema tName of - Left errorResponse -> return errorResponse - Right (sq, mq) -> do - row <- H.statement mempty $ - createWriteStatement sq mq - (contentType == CTSingularJSON) False (contentType == CTTextCSV) - (iPreferRepresentation apiRequest) mempty pgVer prepared - let (_, queryTotal, _, body, gucHeaders, gucStatus) = row - gucs = (,) <$> gucHeaders <*> gucStatus - case gucs of - Left err -> return $ errorResponseFor err - Right (ghdrs, gstatus) -> do - let - updateIsNoOp = S.null (iColumns apiRequest) - defStatus | queryTotal == 0 && not updateIsNoOp = status404 - | iPreferRepresentation apiRequest == Full = status200 - | otherwise = status204 - status = fromMaybe defStatus gstatus - contentRangeHeader = contentRangeH 0 (queryTotal - 1) $ if shouldCount then Just queryTotal else Nothing - (ctHeaders, rBody) = if iPreferRepresentation apiRequest == Full - then ([Just $ toHeader contentType, profileH], toS body) - else ([], mempty) - headers = addHeadersIfNotIncluded (catMaybes ctHeaders ++ [contentRangeHeader]) (unwrapGucHeader <$> ghdrs) - if contentType == CTSingularJSON && queryTotal /= 1 - then do - HT.condemn - return . errorResponseFor . singularityError $ queryTotal - else - return $ responseLBS status headers rBody - - (ActionSingleUpsert, TargetIdent (QualifiedIdentifier tSchema tName)) -> - case mutateSqlParts tSchema tName of - Left errorResponse -> return errorResponse - Right (sq, mq) -> - if topLevelRange /= allRange - then return . errorResponseFor $ PutRangeNotAllowedError - else do - row <- H.statement mempty $ - createWriteStatement sq mq (contentType == CTSingularJSON) False - (contentType == CTTextCSV) (iPreferRepresentation apiRequest) mempty pgVer prepared - let (_, queryTotal, _, body, gucHeaders, gucStatus) = row - gucs = (,) <$> gucHeaders <*> gucStatus - case gucs of - Left err -> return $ errorResponseFor err - Right (ghdrs, gstatus) -> do - let headers = addHeadersIfNotIncluded (catMaybes [Just $ toHeader contentType, profileH]) (unwrapGucHeader <$> ghdrs) - (defStatus, rBody) = if iPreferRepresentation apiRequest == Full then (status200, toS body) else (status204, mempty) - status = fromMaybe defStatus gstatus - -- Makes sure the querystring pk matches the payload pk - -- e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted, PUT /items?id=eq.14 { "id" : 2, .. } is rejected - -- If this condition is not satisfied then nothing is inserted, check the WHERE for INSERT in QueryBuilder.hs to see how it's done - if queryTotal /= 1 - then do - HT.condemn - return . errorResponseFor $ PutMatchingPkError - else - return $ responseLBS status headers rBody - - (ActionDelete, TargetIdent (QualifiedIdentifier tSchema tName)) -> - case mutateSqlParts tSchema tName of - Left errorResponse -> return errorResponse - Right (sq, mq) -> do - let stm = createWriteStatement sq mq - (contentType == CTSingularJSON) False - (contentType == CTTextCSV) - (iPreferRepresentation apiRequest) mempty pgVer prepared - row <- H.statement mempty stm - let (_, queryTotal, _, body, gucHeaders, gucStatus) = row - gucs = (,) <$> gucHeaders <*> gucStatus - case gucs of - Left err -> return $ errorResponseFor err - Right (ghdrs, gstatus) -> do - let - defStatus = if iPreferRepresentation apiRequest == Full then status200 else status204 - status = fromMaybe defStatus gstatus - contentRangeHeader = contentRangeH 1 0 $ if shouldCount then Just queryTotal else Nothing - (ctHeaders, rBody) = if iPreferRepresentation apiRequest == Full - then ([Just $ toHeader contentType, profileH], toS body) - else ([], mempty) - headers = addHeadersIfNotIncluded (catMaybes ctHeaders ++ [contentRangeHeader]) (unwrapGucHeader <$> ghdrs) - if contentType == CTSingularJSON - && queryTotal /= 1 - then do - HT.condemn - return . errorResponseFor . singularityError $ queryTotal - else - return $ responseLBS status headers rBody - - (ActionInfo, TargetIdent (QualifiedIdentifier tSchema tTable)) -> - let mTable = find (\t -> tableName t == tTable && tableSchema t == tSchema) (dbTables dbStructure) in - case mTable of - Nothing -> return notFound - Just table -> - let allowH = (hAllow, if tableInsertable table then "GET,POST,PATCH,DELETE" else "GET") - allOrigins = ("Access-Control-Allow-Origin", "*") :: Header in - return $ responseLBS status200 [allOrigins, allowH] mempty - - (ActionInvoke invMethod, TargetProc proc@ProcDescription{pdSchema, pdName} _) -> - let tName = fromMaybe pdName $ procTableName proc in - case readSqlParts pdSchema tName of - Left errorResponse -> return errorResponse - Right (q, cq, bField, returning) -> do - let - preferParams = iPreferParameters apiRequest - pq = requestToCallProcQuery (QualifiedIdentifier pdSchema pdName) (specifiedProcArgs (iColumns apiRequest) proc) - (iPayload apiRequest) returnsScalar preferParams returning - stm = callProcStatement returnsScalar returnsSingle pq q cq shouldCount (contentType == CTSingularJSON) - (contentType == CTTextCSV) (preferParams == Just MultipleObjects) bField pgVer prepared - row <- H.statement mempty stm - let (tableTotal, queryTotal, body, gucHeaders, gucStatus) = row - gucs = (,) <$> gucHeaders <*> gucStatus - case gucs of - Left err -> return $ errorResponseFor err - Right (ghdrs, gstatus) -> do - let (rangeStatus, contentRange) = rangeStatusHeader topLevelRange queryTotal tableTotal - status = fromMaybe rangeStatus gstatus - headers = addHeadersIfNotIncluded - (catMaybes [Just $ toHeader contentType, Just contentRange, profileH]) - (unwrapGucHeader <$> ghdrs) - rBody = if invMethod == InvHead then mempty else toS body - if contentType == CTSingularJSON && queryTotal /= 1 - then do - HT.condemn - return . errorResponseFor . singularityError $ queryTotal - else - return $ responseLBS status headers rBody - - (ActionInspect headersOnly, TargetDefaultSpec tSchema) -> do - let host = configServerHost conf - port = toInteger $ configServerPort conf - proxy = pickProxy $ toS <$> configOpenApiServerProxyUri conf - uri Nothing = ("http", host, port, "/") - uri (Just Proxy { proxyScheme = s, proxyHost = h, proxyPort = p, proxyPath = b }) = (s, h, p, b) - uri' = uri proxy - toTableInfo :: [Table] -> [(Table, [Column], [Text])] - toTableInfo = map (\t -> let (s, tn) = (tableSchema t, tableName t) in (t, tableCols dbStructure s tn, tablePKCols dbStructure s tn)) - encodeApi ti sd procs = encodeOpenAPI (concat $ M.elems procs) (toTableInfo ti) uri' sd $ dbPrimaryKeys dbStructure - - body <- encodeApi <$> - H.statement tSchema (accessibleTables prepared) <*> - H.statement tSchema (schemaDescription prepared) <*> - H.statement tSchema (accessibleProcs prepared) - return $ responseLBS status200 (catMaybes [Just $ toHeader CTOpenAPI, profileH]) (if headersOnly then mempty else toS body) - - _ -> return notFound - - where - notFound = responseLBS status404 mempty "" - maxRows = configDbMaxRows conf - prepared = configDbPreparedStatements conf - exactCount = iPreferCount apiRequest == Just ExactCount - estimatedCount = iPreferCount apiRequest == Just EstimatedCount - plannedCount = iPreferCount apiRequest == Just PlannedCount - shouldCount = exactCount || estimatedCount - topLevelRange = iTopLevelRange apiRequest - returnsScalar = - case iTarget apiRequest of - TargetProc proc _ -> procReturnsScalar proc - _ -> False - returnsSingle = - case iTarget apiRequest of - TargetProc proc _ -> procReturnsSingle proc - _ -> False - pgVer = pgVersion dbStructure - profileH = contentProfileH <$> iProfile apiRequest - - readSqlParts s t = - let - readReq = readRequest s t maxRows (dbRelations dbStructure) apiRequest - returnings :: ReadRequest -> Either Response [FieldName] - returnings rr = Right (returningCols rr []) - in - (,,,) <$> - (readRequestToQuery <$> readReq) <*> - (readRequestToCountQuery <$> readReq) <*> - (binaryField contentType rawContentTypes returnsScalar =<< readReq) <*> - (returnings =<< readReq) - - mutateSqlParts s t = - let - readReq = readRequest s t maxRows (dbRelations dbStructure) apiRequest - mutReq = mutateRequest s t apiRequest (tablePKCols dbStructure s t) =<< readReq - in - (,) <$> - (readRequestToQuery <$> readReq) <*> - (mutateRequestToQuery <$> mutReq) - -responseContentTypeOrError :: [ContentType] -> [ContentType] -> Action -> Target -> Either Response ContentType -responseContentTypeOrError accepts rawContentTypes action target = serves contentTypesForRequest accepts + Just dbStructure -> + return dbStructure + Nothing -> + throwError Error.ConnectionLostError + + apiRequest@ApiRequest{..} <- + liftEither . mapLeft Error.ApiRequestError $ + ApiRequest.userApiRequest configDbSchemas configDbRootSpec dbStructure req body + + -- The JWT must be checked before touching the db + jwtClaims <- Auth.jwtClaims conf (toS iJWT) time + + contentType <- + case ApiRequest.mutuallyAgreeable (requestContentTypes conf apiRequest) iAccepts of + Just ct -> + return ct + Nothing -> + throwError . Error.ContentTypeError $ map toMime iAccepts + + let + handleReq apiReq = + handleRequest $ RequestContext conf dbStructure apiReq contentType + + runDbHandler pool (txMode apiRequest) jwtClaims . + Middleware.optionalRollback conf apiRequest $ + Middleware.runPgLocals conf jwtClaims handleReq apiRequest + +runDbHandler :: SQL.Pool -> SQL.Mode -> Auth.JWTClaims -> DbHandler a -> Handler IO a +runDbHandler pool mode jwtClaims handler = do + dbResp <- + lift . SQL.use pool . SQL.transaction SQL.ReadCommitted mode $ runExceptT handler + + resp <- + liftEither . mapLeft Error.PgErr $ + mapLeft (Error.PgError $ Auth.containsRole jwtClaims) dbResp + + liftEither resp + +handleRequest :: RequestContext -> DbHandler Wai.Response +handleRequest context@(RequestContext _ _ ApiRequest{..} _) = + case (iAction, iTarget) of + (ActionRead headersOnly, TargetIdent identifier) -> + handleRead headersOnly identifier context + (ActionCreate, TargetIdent identifier) -> + handleCreate identifier context + (ActionUpdate, TargetIdent identifier) -> + handleUpdate identifier context + (ActionSingleUpsert, TargetIdent identifier) -> + handleSingleUpsert identifier context + (ActionDelete, TargetIdent identifier) -> + handleDelete identifier context + (ActionInfo, TargetIdent identifier) -> + handleInfo identifier context + (ActionInvoke invMethod, TargetProc proc _) -> + handleInvoke invMethod proc context + (ActionInspect headersOnly, TargetDefaultSpec tSchema) -> + handleOpenApi headersOnly tSchema context + _ -> + throwError Error.NotFound + +handleRead :: Bool -> QualifiedIdentifier -> RequestContext -> DbHandler Wai.Response +handleRead headersOnly identifier context@RequestContext{..} = do + req <- readRequest identifier context + bField <- binaryField context req + + let + ApiRequest{..} = ctxApiRequest + AppConfig{..} = ctxConfig + countQuery = QueryBuilder.readRequestToCountQuery req + + (tableTotal, queryTotal, _ , body, gucHeaders, gucStatus) <- + lift . SQL.statement mempty $ + Statements.createReadStatement + (QueryBuilder.readRequestToQuery req) + (if iPreferCount == Just EstimatedCount then + -- LIMIT maxRows + 1 so we can determine below that maxRows was surpassed + QueryBuilder.limitedQuery countQuery ((+ 1) <$> configDbMaxRows) + else + countQuery + ) + (ctxContentType == CTSingularJSON) + (shouldCount iPreferCount) + (ctxContentType == CTTextCSV) + bField + (pgVersion ctxDbStructure) + configDbPreparedStatements + + total <- readTotal ctxConfig ctxApiRequest tableTotal countQuery + response <- liftEither $ gucResponse <$> gucStatus <*> gucHeaders + + let + (status, contentRange) = RangeQuery.rangeStatusHeader iTopLevelRange queryTotal total + headers = + [ contentRange + , ( "Content-Location" + , "/" + <> toS (qiName identifier) + <> if BS8.null iCanonicalQS then mempty else "?" <> toS iCanonicalQS + ) + ] + ++ contentTypeHeaders context + + failNotSingular ctxContentType queryTotal . response status headers $ + if headersOnly then mempty else toS body + +readTotal :: AppConfig -> ApiRequest -> Maybe Int64 -> SQL.Snippet -> DbHandler (Maybe Int64) +readTotal AppConfig{..} ApiRequest{..} tableTotal countQuery = + case iPreferCount of + Just PlannedCount -> + explain + Just EstimatedCount -> + if tableTotal > (fromIntegral <$> configDbMaxRows) then + max tableTotal <$> explain + else + return tableTotal + _ -> + return tableTotal where - contentTypesForRequest = case action of - ActionRead _ -> [CTApplicationJSON, CTSingularJSON, CTTextCSV] - ++ rawContentTypes - ActionCreate -> [CTApplicationJSON, CTSingularJSON, CTTextCSV] - ActionUpdate -> [CTApplicationJSON, CTSingularJSON, CTTextCSV] - ActionDelete -> [CTApplicationJSON, CTSingularJSON, CTTextCSV] - ActionInvoke _ -> [CTApplicationJSON, CTSingularJSON, CTTextCSV] - ++ rawContentTypes - ++ [CTOpenAPI | tpIsRootSpec target] - ActionInspect _ -> [CTOpenAPI, CTApplicationJSON] - ActionInfo -> [CTTextCSV] - ActionSingleUpsert -> [CTApplicationJSON, CTSingularJSON, CTTextCSV] - serves sProduces cAccepts = - case mutuallyAgreeable sProduces cAccepts of - Nothing -> Left . errorResponseFor . ContentTypeError . map toMime $ cAccepts - Just ct -> Right ct - -{- - | If raw(binary) output is requested, check that ContentType is one of the admitted rawContentTypes and that - | `?select=...` contains only one field other than `*` --} -binaryField :: ContentType -> [ContentType] -> Bool -> ReadRequest -> Either Response (Maybe FieldName) -binaryField ct rawContentTypes isScalarProc readReq - | isScalarProc = - if ct `elem` rawContentTypes - then Right $ Just "pgrst_scalar" - else Right Nothing - | ct `elem` rawContentTypes = - let fieldName = headMay fldNames in - if length fldNames == 1 && fieldName /= Just "*" - then Right fieldName - else Left . errorResponseFor $ BinaryFieldError ct - | otherwise = Right Nothing + explain = + lift . SQL.statement mempty . Statements.createExplainStatement countQuery $ + configDbPreparedStatements + +handleCreate :: QualifiedIdentifier -> RequestContext -> DbHandler Wai.Response +handleCreate identifier@QualifiedIdentifier{..} context@RequestContext{..} = do + let + ApiRequest{..} = ctxApiRequest + pkCols = tablePKCols ctxDbStructure qiSchema qiName + + WriteQueryResult{..} <- writeQuery identifier True pkCols context + + let + response = gucResponse resGucStatus resGucHeaders + headers = + catMaybes + [ if null resFields then + Nothing + else + Just + ( HTTP.hLocation + , "/" + <> toS qiName + <> HTTP.renderSimpleQuery True (splitKeyValue <$> resFields) + ) + , Just . RangeQuery.contentRangeH 1 0 $ + if shouldCount iPreferCount then Just resQueryTotal else Nothing + , if null pkCols && isNothing iOnConflict then + Nothing + else + (\x -> ("Preference-Applied", BS8.pack $ show x)) <$> iPreferResolution + ] + + failNotSingular ctxContentType resQueryTotal $ + if iPreferRepresentation == Full then + response HTTP.status201 (headers ++ contentTypeHeaders context) (toS resBody) + else + response HTTP.status201 headers mempty + +handleUpdate :: QualifiedIdentifier -> RequestContext -> DbHandler Wai.Response +handleUpdate identifier context@(RequestContext _ _ ApiRequest{..} contentType) = do + WriteQueryResult{..} <- writeQuery identifier False mempty context + + let + response = gucResponse resGucStatus resGucHeaders + fullRepr = iPreferRepresentation == Full + updateIsNoOp = Set.null iColumns + status + | resQueryTotal == 0 && not updateIsNoOp = HTTP.status404 + | fullRepr = HTTP.status200 + | otherwise = HTTP.status204 + contentRangeHeader = + RangeQuery.contentRangeH 0 (resQueryTotal - 1) $ + if shouldCount iPreferCount then Just resQueryTotal else Nothing + + failNotSingular contentType resQueryTotal $ + if fullRepr then + response status (contentTypeHeaders context ++ [contentRangeHeader]) (toS resBody) + else + response status [contentRangeHeader] mempty + +handleSingleUpsert :: QualifiedIdentifier -> RequestContext-> DbHandler Wai.Response +handleSingleUpsert identifier context@(RequestContext _ _ ApiRequest{..} _) = do + when (iTopLevelRange /= RangeQuery.allRange) $ + throwError Error.PutRangeNotAllowedError + + WriteQueryResult{..} <- writeQuery identifier False mempty context + + let response = gucResponse resGucStatus resGucHeaders + + -- Makes sure the querystring pk matches the payload pk + -- e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted, + -- PUT /items?id=eq.14 { "id" : 2, .. } is rejected. + -- If this condition is not satisfied then nothing is inserted, + -- check the WHERE for INSERT in QueryBuilder.hs to see how it's done + when (resQueryTotal /= 1) $ do + lift SQL.condemn + throwError Error.PutMatchingPkError + + return $ + if iPreferRepresentation == Full then + response HTTP.status200 (contentTypeHeaders context) (toS resBody) + else + response HTTP.status204 (contentTypeHeaders context) mempty + +handleDelete :: QualifiedIdentifier -> RequestContext -> DbHandler Wai.Response +handleDelete identifier context@(RequestContext _ _ ApiRequest{..} contentType) = do + WriteQueryResult{..} <- writeQuery identifier False mempty context + + let + response = gucResponse resGucStatus resGucHeaders + contentRangeHeader = + RangeQuery.contentRangeH 1 0 $ + if shouldCount iPreferCount then Just resQueryTotal else Nothing + + failNotSingular contentType resQueryTotal $ + if iPreferRepresentation == Full then + response HTTP.status200 + (contentTypeHeaders context ++ [contentRangeHeader]) + (toS resBody) + else + response HTTP.status204 [contentRangeHeader] mempty + +handleInfo :: Monad m => QualifiedIdentifier -> RequestContext -> Handler m Wai.Response +handleInfo identifier RequestContext{..} = + case find tableMatches $ dbTables ctxDbStructure of + Just table -> + return $ Wai.responseLBS HTTP.status200 [allOrigins, allowH table] mempty + Nothing -> + throwError Error.NotFound where - fldNames = fstFieldNames readReq + allOrigins = ("Access-Control-Allow-Origin", "*") + allowH table = + ( HTTP.hAllow + , if tableInsertable table then "GET,POST,PATCH,DELETE" else "GET" + ) + tableMatches table = + tableName table == qiName identifier + && tableSchema table == qiSchema identifier + +handleInvoke :: InvokeMethod -> ProcDescription -> RequestContext -> DbHandler Wai.Response +handleInvoke invMethod proc context@RequestContext{..} = do + let + ApiRequest{..} = ctxApiRequest + + identifier = + QualifiedIdentifier + (pdSchema proc) + (fromMaybe (pdName proc) $ procTableName proc) + + returnsSingle (ApiRequest.TargetProc target _) = procReturnsSingle target + returnsSingle _ = False + + req <- readRequest identifier context + bField <- binaryField context req + + (tableTotal, queryTotal, body, gucHeaders, gucStatus) <- + lift . SQL.statement mempty $ + Statements.callProcStatement + (returnsScalar iTarget) + (returnsSingle iTarget) + (QueryBuilder.requestToCallProcQuery + (QualifiedIdentifier (pdSchema proc) (pdName proc)) + (specifiedProcArgs iColumns proc) + iPayload + (returnsScalar iTarget) + iPreferParameters + (ReqBuilder.returningCols req []) + ) + (QueryBuilder.readRequestToQuery req) + (QueryBuilder.readRequestToCountQuery req) + (shouldCount iPreferCount) + (ctxContentType == CTSingularJSON) + (ctxContentType == CTTextCSV) + (iPreferParameters == Just MultipleObjects) + bField + (pgVersion ctxDbStructure) + (configDbPreparedStatements ctxConfig) + + response <- liftEither $ gucResponse <$> gucStatus <*> gucHeaders -locationH :: TableName -> [BS.ByteString] -> Header -locationH tName fields = let - locationFields = renderSimpleQuery True $ splitKeyValue <$> fields - in - (hLocation, "/" <> toS tName <> locationFields) + (status, contentRange) = + RangeQuery.rangeStatusHeader iTopLevelRange queryTotal tableTotal + + failNotSingular ctxContentType queryTotal $ + response status + (contentTypeHeaders context ++ [contentRange]) + (if invMethod == InvHead then mempty else toS body) + +handleOpenApi :: Bool -> Schema -> RequestContext -> DbHandler Wai.Response +handleOpenApi headersOnly tSchema (RequestContext conf@AppConfig{..} dbStructure apiRequest _) = do + body <- + lift $ + OpenAPI.encode conf dbStructure + <$> SQL.statement tSchema (DbStructure.accessibleTables configDbPreparedStatements) + <*> SQL.statement tSchema (DbStructure.schemaDescription configDbPreparedStatements) + <*> SQL.statement tSchema (DbStructure.accessibleProcs configDbPreparedStatements) + + return $ + Wai.responseLBS HTTP.status200 + (toHeader CTOpenAPI : maybeToList (profileHeader apiRequest)) + (if headersOnly then mempty else toS body) + +txMode :: ApiRequest -> SQL.Mode +txMode ApiRequest{..} = + case (iAction, iTarget) of + (ActionRead _, _) -> + SQL.Read + (ActionInfo, _) -> + SQL.Read + (ActionInspect _, _) -> + SQL.Read + (ActionInvoke InvGet, _) -> + SQL.Read + (ActionInvoke InvHead, _) -> + SQL.Read + (ActionInvoke InvPost, TargetProc ProcDescription{pdVolatility=Stable} _) -> + SQL.Read + (ActionInvoke InvPost, TargetProc ProcDescription{pdVolatility=Immutable} _) -> + SQL.Read + _ -> + SQL.Write + +-- | Result from executing a write query on the database +data WriteQueryResult = WriteQueryResult + { resQueryTotal :: Int64 + , resFields :: [ByteString] + , resBody :: ByteString + , resGucStatus :: Maybe HTTP.Status + , resGucHeaders :: [GucHeader] + } + +writeQuery :: QualifiedIdentifier -> Bool -> [Text] -> RequestContext -> DbHandler WriteQueryResult +writeQuery identifier@QualifiedIdentifier{..} isInsert pkCols context@RequestContext{..} = do + readReq <- readRequest identifier context + + mutateReq <- + liftEither $ + ReqBuilder.mutateRequest qiSchema qiName ctxApiRequest + (tablePKCols ctxDbStructure qiSchema qiName) + readReq + + (_, queryTotal, fields, body, gucHeaders, gucStatus) <- + lift . SQL.statement mempty $ + Statements.createWriteStatement + (QueryBuilder.readRequestToQuery readReq) + (QueryBuilder.mutateRequestToQuery mutateReq) + (ctxContentType == CTSingularJSON) + isInsert + (ctxContentType == CTTextCSV) + (iPreferRepresentation ctxApiRequest) + pkCols + (pgVersion ctxDbStructure) + (configDbPreparedStatements ctxConfig) + + liftEither $ WriteQueryResult queryTotal fields body <$> gucStatus <*> gucHeaders + +-- | Response with headers and status overridden from GUCs. +gucResponse + :: Maybe HTTP.Status + -> [GucHeader] + -> HTTP.Status + -> [HTTP.Header] + -> LBS.ByteString + -> Wai.Response +gucResponse gucStatus gucHeaders status headers = + Wai.responseLBS (fromMaybe status gucStatus) $ + addHeadersIfNotIncluded headers (map unwrapGucHeader gucHeaders) + +-- | +-- Fail a response if a single JSON object was requested and not exactly one +-- was found. +failNotSingular :: ContentType -> Int64 -> Wai.Response -> DbHandler Wai.Response +failNotSingular contentType queryTotal response = + if contentType == CTSingularJSON && queryTotal /= 1 then + do + lift SQL.condemn + throwError $ Error.singularityError queryTotal + else + return response + +shouldCount :: Maybe PreferCount -> Bool +shouldCount preferCount = + preferCount == Just ExactCount || preferCount == Just EstimatedCount + +returnsScalar :: ApiRequest.Target -> Bool +returnsScalar (TargetProc proc _) = procReturnsScalar proc +returnsScalar _ = False + +readRequest :: Monad m => QualifiedIdentifier -> RequestContext -> Handler m ReadRequest +readRequest QualifiedIdentifier{..} (RequestContext AppConfig{..} dbStructure apiRequest _) = + liftEither $ + ReqBuilder.readRequest qiSchema qiName configDbMaxRows + (dbRelations dbStructure) + apiRequest + +contentTypeHeaders :: RequestContext -> [HTTP.Header] +contentTypeHeaders RequestContext{..} = + toHeader ctxContentType : maybeToList (profileHeader ctxApiRequest) + +requestContentTypes :: AppConfig -> ApiRequest -> [ContentType] +requestContentTypes conf ApiRequest{..} = + case iAction of + ActionRead _ -> defaultContentTypes ++ rawContentTypes conf + ActionInvoke _ -> invokeContentTypes + ActionInspect _ -> [CTOpenAPI, CTApplicationJSON] + ActionInfo -> [CTTextCSV] + _ -> defaultContentTypes + where + invokeContentTypes = + defaultContentTypes + ++ rawContentTypes conf + ++ [CTOpenAPI | ApiRequest.tpIsRootSpec iTarget] + defaultContentTypes = + [CTApplicationJSON, CTSingularJSON, CTTextCSV] + +-- | +-- If raw(binary) output is requested, check that ContentType is one of the admitted +-- rawContentTypes and that`?select=...` contains only one field other than `*` +binaryField :: Monad m => RequestContext -> ReadRequest -> Handler m (Maybe FieldName) +binaryField RequestContext{..} readReq + | returnsScalar (iTarget ctxApiRequest) && ctxContentType `elem` rawContentTypes ctxConfig = + return $ Just "pgrst_scalar" + | ctxContentType `elem` rawContentTypes ctxConfig = + let + fldNames = fstFieldNames readReq + fieldName = headMay fldNames + in + if length fldNames == 1 && fieldName /= Just "*" then + return fieldName + else + throwError $ Error.BinaryFieldError ctxContentType + | otherwise = + return Nothing + +rawContentTypes :: AppConfig -> [ContentType] +rawContentTypes AppConfig{..} = + (decodeContentType <$> configRawMediaTypes) `union` [CTOctetStream, CTTextPlain] + +profileHeader :: ApiRequest -> Maybe HTTP.Header +profileHeader ApiRequest{..} = + (,) "Content-Profile" <$> (toS <$> iProfile) + +splitKeyValue :: ByteString -> (ByteString, ByteString) +splitKeyValue kv = + (k, BS8.tail v) where - splitKeyValue :: BS.ByteString -> (BS.ByteString, BS.ByteString) - splitKeyValue kv = - let (k, v) = BS.break (== '=') kv - in (k, BS.tail v) - -contentLocationH :: TableName -> ByteString -> Header -contentLocationH tName qString = - ("Content-Location", "/" <> toS tName <> if BS.null qString then mempty else "?" <> toS qString) - -contentProfileH :: Schema -> Header -contentProfileH schema = - ("Content-Profile", toS schema) + (k, v) = BS8.break (== '=') kv diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs index 98d728fb90c..3d73e622f46 100644 --- a/src/PostgREST/Auth.hs +++ b/src/PostgREST/Auth.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} {-| Module : PostgREST.Auth Description : PostgREST authorization functions. @@ -12,108 +10,70 @@ Authentication should always be implemented in an external service. In the test suite there is an example of simple login function that can be used for a very simple authentication system inside the PostgreSQL database. -} -module PostgREST.Auth ( - containsRole - , jwtClaims - , attemptJwtClaims - , parseSecret - ) where +{-# LANGUAGE RecordWildCards #-} +module PostgREST.Auth (containsRole, jwtClaims, JWTClaims) where -import qualified Crypto.JOSE.Types as JOSE.Types +import qualified Crypto.JWT as JWT import qualified Data.Aeson as JSON import qualified Data.HashMap.Strict as M -import Data.Vector as V - -import Control.Lens (set) -import Data.Time.Clock (UTCTime) - -import Crypto.JWT - -import PostgREST.Error (SimpleError (..)) -import PostgREST.Types -import Protolude hiding (toS) -import Protolude.Conv (toS) - -{-| - Possible situations encountered with client JWTs --} -data JWTAttempt = JWTInvalid JWTError - | JWTMissingSecret - | JWTClaims (M.HashMap Text JSON.Value) - - -jwtClaims :: JWTAttempt -> Either SimpleError (M.HashMap Text JSON.Value) -jwtClaims attempt = - case attempt of - JWTMissingSecret -> Left JwtTokenMissing - JWTInvalid JWTExpired -> Left $ JwtTokenInvalid "JWT expired" - JWTInvalid e -> Left $ JwtTokenInvalid $ show e - JWTClaims claims -> Right claims - -{-| - Receives the JWT secret and audience (from config) and a JWT and returns a map - of JWT claims. --} -attemptJwtClaims :: Maybe JWKSet -> Maybe StringOrURI -> LByteString -> UTCTime -> JSPath -> IO JWTAttempt -attemptJwtClaims _ _ "" _ _ = return $ JWTClaims M.empty -attemptJwtClaims maybeSecret audience payload time jspath = - case maybeSecret of - Nothing -> return JWTMissingSecret - Just secret -> do - let validation = set allowedSkew 1 $ defaultJWTValidationSettings (maybe (const True) (==) audience) - eJwt <- runExceptT $ do - jwt <- decodeCompact payload - verifyClaimsAt validation secret time jwt - return $ case eJwt of - Left e -> JWTInvalid e - Right jwt -> JWTClaims $ claims2map jwt jspath - -{-| - Turn JWT ClaimSet into something easier to work with, - also here the jspath is applied to put the "role" in the map --} -claims2map :: ClaimsSet -> JSPath -> M.HashMap Text JSON.Value -claims2map claims jspath = (\case +import qualified Data.Vector as V + +import Control.Lens (set) +import Control.Monad.Except (liftEither) +import Data.Either.Combinators (mapLeft) +import Data.Time.Clock (UTCTime) + +import PostgREST.Config (AppConfig (..)) +import PostgREST.Error (Error (..)) +import PostgREST.Types (JSPath, JSPathExp (..)) + +import Protolude + + +type JWTClaims = M.HashMap Text JSON.Value + +-- | Receives the JWT secret and audience (from config) and a JWT and returns a +-- map of JWT claims. +jwtClaims :: Monad m => + AppConfig -> LByteString -> UTCTime -> ExceptT Error m JWTClaims +jwtClaims _ "" _ = return M.empty +jwtClaims AppConfig{..} payload time = do + secret <- liftEither . maybeToRight JwtTokenMissing $ configJWKS + eitherClaims <- + lift . runExceptT $ + JWT.verifyClaimsAt validation secret time =<< JWT.decodeCompact payload + liftEither . mapLeft jwtClaimsError $ claimsMap configJwtRoleClaimKey <$> eitherClaims + where + validation = + JWT.defaultJWTValidationSettings audienceCheck & set JWT.allowedSkew 1 + + audienceCheck :: JWT.StringOrURI -> Bool + audienceCheck = maybe (const True) (==) configJwtAudience + + jwtClaimsError :: JWT.JWTError -> Error + jwtClaimsError JWT.JWTExpired = JwtTokenInvalid "JWT expired" + jwtClaimsError e = JwtTokenInvalid $ show e + +-- | Turn JWT ClaimSet into something easier to work with. +-- +-- Also, here the jspath is applied to put the "role" in the map. +claimsMap :: JSPath -> JWT.ClaimsSet -> JWTClaims +claimsMap jspath claims = + case JSON.toJSON claims of val@(JSON.Object o) -> - let role = maybe M.empty (M.singleton "role") $ - walkJSPath (Just val) jspath in - M.delete "role" o `M.union` role -- mutating the map - _ -> M.empty - ) $ JSON.toJSON claims - -walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value -walkJSPath x [] = x -walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (M.lookup key o) rest -walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest -walkJSPath _ _ = Nothing - -{-| - Whether a response from jwtClaims contains a role claim --} -containsRole :: M.HashMap Text JSON.Value -> Bool + M.delete "role" o `M.union` role val + _ -> + M.empty + where + role value = + maybe M.empty (M.singleton "role") $ walkJSPath (Just value) jspath + + walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value + walkJSPath x [] = x + walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (M.lookup key o) rest + walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest + walkJSPath _ _ = Nothing + +-- | Whether a response from jwtClaims contains a role claim +containsRole :: JWTClaims -> Bool containsRole = M.member "role" - -{-| - Parse `jwt-secret` configuration option and turn into a JWKSet. - - There are three ways to specify `jwt-secret`: text secret, JSON Web Key - (JWK), or JSON Web Key Set (JWKS). The first two are converted into a JWKSet - with one key and the last is converted as is. --} -parseSecret :: ByteString -> JWKSet -parseSecret str = - fromMaybe (maybe secret (\jwk' -> JWKSet [jwk']) maybeJWK) - maybeJWKSet - where - maybeJWKSet = JSON.decode (toS str) :: Maybe JWKSet - maybeJWK = JSON.decode (toS str) :: Maybe JWK - secret = JWKSet [jwkFromSecret str] - -{-| - Internal helper to generate a symmetric HMAC-SHA256 JWK from a text secret. --} -jwkFromSecret :: ByteString -> JWK -jwkFromSecret key = - fromKeyMaterial km - where - km = OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets key)) diff --git a/src/PostgREST/Config.hs b/src/PostgREST/Config.hs index 6471c471f64..554b6022b38 100644 --- a/src/PostgREST/Config.hs +++ b/src/PostgREST/Config.hs @@ -20,20 +20,24 @@ Other hardcoded options such as the minimum version number also belong here. {-# LANGUAGE TemplateHaskell #-} {-# OPTIONS_GHC -fno-warn-type-defaults #-} -module PostgREST.Config ( prettyVersion - , docsVersion - , CLI (..) - , Command (..) - , AppConfig (..) - , configDbPoolTimeout' - , dumpAppConfig - , Environment - , readCLIShowHelp - , readEnvironment - , readConfig - ) - where - +module PostgREST.Config + ( prettyVersion + , docsVersion + , CLI (..) + , Command (..) + , AppConfig (..) + , configDbPoolTimeout' + , dumpAppConfig + , Environment + , readCLIShowHelp + , readEnvironment + , readConfig + , parseSecret + ) where + +import qualified Crypto.JOSE.Types as JOSE +import qualified Crypto.JWT as JWT +import qualified Data.Aeson as JSON import qualified Data.ByteString as B import qualified Data.ByteString.Base64 as B64 import qualified Data.ByteString.Char8 as BS @@ -42,7 +46,7 @@ import qualified Data.Map.Strict as M import Control.Lens (preview) import Control.Monad (fail) -import Crypto.JWT (JWKSet, StringOrURI, stringOrUri) +import Crypto.JWT (JWK, JWKSet, StringOrURI, stringOrUri) import Data.Aeson (encode, toJSON) import Data.Either.Combinators (mapLeft) import Data.List (lookup) @@ -63,9 +67,8 @@ import System.Posix.Types (FileMode) import Control.Applicative import Data.Monoid import Options.Applicative hiding (str) -import Text.Heredoc +import Text.Heredoc (str) -import PostgREST.Auth (parseSecret) import PostgREST.Parsers (pRoleClaimKey) import PostgREST.Private.ProxyUri (isMalformedProxyUri) import PostgREST.Types (JSPath, JSPathExp (..), @@ -593,3 +596,21 @@ loadDbUriFile conf = extractDbUri mDbUri Nothing -> return dbUri Just filename -> strip <$> readFile (toS filename) setDbUri dbUri = conf {configDbUri = dbUri} + + +{-| + Parse `jwt-secret` configuration option and turn into a JWKSet. + + There are three ways to specify `jwt-secret`: text secret, JSON Web Key + (JWK), or JSON Web Key Set (JWKS). The first two are converted into a JWKSet + with one key and the last is converted as is. +-} +parseSecret :: ByteString -> JWKSet +parseSecret bytes = + fromMaybe (maybe secret (\jwk' -> JWT.JWKSet [jwk']) maybeJWK) + maybeJWKSet + where + maybeJWKSet = JSON.decode (toS bytes) :: Maybe JWKSet + maybeJWK = JSON.decode (toS bytes) :: Maybe JWK + secret = JWT.JWKSet [JWT.fromKeyMaterial keyMaterial] + keyMaterial = JWT.OctKeyMaterial . JWT.OctKeyParameters $ JOSE.Base64Octets bytes diff --git a/src/PostgREST/DbRequestBuilder.hs b/src/PostgREST/DbRequestBuilder.hs index 020d06c4607..a52a30f085c 100644 --- a/src/PostgREST/DbRequestBuilder.hs +++ b/src/PostgREST/DbRequestBuilder.hs @@ -7,7 +7,6 @@ This module is in charge of building an intermediate representation(ReadRequest, A query tree is built in case of resource embedding. By inferring the relationship between tables, join conditions are added for every embedded resource. -} {-# LANGUAGE DuplicateRecordFields #-} -{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RecordWildCards #-} @@ -28,10 +27,9 @@ import Data.Text (isInfixOf) import Control.Applicative import Data.Tree -import Network.Wai import PostgREST.ApiRequest (Action (..), ApiRequest (..)) -import PostgREST.Error (ApiRequestError (..), errorResponseFor) +import PostgREST.Error (ApiRequestError (..), Error (..)) import PostgREST.Parsers import PostgREST.RangeQuery (NonnegRange, allRange, restrictRange) import PostgREST.Types @@ -40,9 +38,9 @@ import Protolude hiding (from) -- | Builds the ReadRequest tree on a number of stages. -- | Adds filters, order, limits on its respective nodes. -- | Adds joins conditions obtained from resource embedding. -readRequest :: Schema -> TableName -> Maybe Integer -> [Relation] -> ApiRequest -> Either Response ReadRequest +readRequest :: Schema -> TableName -> Maybe Integer -> [Relation] -> ApiRequest -> Either Error ReadRequest readRequest schema rootTableName maxRows allRels apiRequest = - mapLeft errorResponseFor $ + mapLeft ApiRequestError $ treeRestrictRange maxRows =<< augmentRequestWithJoin schema rootRels =<< addFiltersOrdersRanges apiRequest =<< @@ -281,8 +279,8 @@ addProperty f (targetNodeName:remainingPath, a) (Node rn forest) = where pathNode = find (\(Node (_,(nodeName,_,alias,_,_)) _) -> nodeName == targetNodeName || alias == Just targetNodeName) forest -mutateRequest :: Schema -> TableName -> ApiRequest -> [FieldName] -> ReadRequest -> Either Response MutateRequest -mutateRequest schema tName apiRequest pkCols readReq = mapLeft errorResponseFor $ +mutateRequest :: Schema -> TableName -> ApiRequest -> [FieldName] -> ReadRequest -> Either Error MutateRequest +mutateRequest schema tName apiRequest pkCols readReq = mapLeft ApiRequestError $ case action of ActionCreate -> do confCols <- case iOnConflict apiRequest of diff --git a/src/PostgREST/Error.hs b/src/PostgREST/Error.hs index da18263cb99..c5d97fbbe73 100644 --- a/src/PostgREST/Error.hs +++ b/src/PostgREST/Error.hs @@ -3,14 +3,13 @@ Module : PostgREST.Error Description : PostgREST error HTTP responses -} {-# OPTIONS_GHC -fno-warn-orphans #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE RecordWildCards #-} module PostgREST.Error ( errorResponseFor , ApiRequestError(..) , PgError(..) -, SimpleError(..) +, Error(..) , errorPayload , checkIsFatal , singularityError @@ -220,7 +219,7 @@ checkIsFatal (PgError _ (P.SessionError (H.QueryError _ _ (H.ResultError (H.Serv checkIsFatal _ = Nothing -data SimpleError +data Error = GucHeadersError | GucStatusError | BinaryFieldError ContentType @@ -231,8 +230,11 @@ data SimpleError | JwtTokenInvalid Text | SingularityError Integer | ContentTypeError [ByteString] + | NotFound + | ApiRequestError ApiRequestError + | PgErr PgError -instance PgrstError SimpleError where +instance PgrstError Error where status GucHeadersError = HT.status500 status GucStatusError = HT.status500 status (BinaryFieldError _) = HT.status406 @@ -243,12 +245,17 @@ instance PgrstError SimpleError where status (JwtTokenInvalid _) = HT.unauthorized401 status (SingularityError _) = HT.status406 status (ContentTypeError _) = HT.status415 + status NotFound = HT.status404 + status (PgErr err) = status err + status (ApiRequestError err) = status err headers (SingularityError _) = [toHeader CTSingularJSON] headers (JwtTokenInvalid m) = [toHeader CTApplicationJSON, invalidTokenHeader m] + headers (PgErr err) = headers err + headers (ApiRequestError err) = headers err headers _ = [toHeader CTApplicationJSON] -instance JSON.ToJSON SimpleError where +instance JSON.ToJSON Error where toJSON GucHeadersError = JSON.object [ "message" .= ("response.headers guc must be a JSON array composed of objects with a single key and a string value" :: Text)] toJSON GucStatusError = JSON.object [ @@ -273,10 +280,13 @@ instance JSON.ToJSON SimpleError where "message" .= ("Server lacks JWT secret" :: Text)] toJSON (JwtTokenInvalid message) = JSON.object [ "message" .= (message :: Text)] + toJSON NotFound = JSON.object [] + toJSON (PgErr err) = JSON.toJSON err + toJSON (ApiRequestError err) = JSON.toJSON err invalidTokenHeader :: Text -> Header invalidTokenHeader m = ("WWW-Authenticate", "Bearer error=\"invalid_token\", " <> "error_description=" <> encodeUtf8 (show m)) -singularityError :: (Integral a) => a -> SimpleError +singularityError :: (Integral a) => a -> Error singularityError = SingularityError . toInteger diff --git a/src/PostgREST/Middleware.hs b/src/PostgREST/Middleware.hs index 844b65d1c33..5e6732ba9e6 100644 --- a/src/PostgREST/Middleware.hs +++ b/src/PostgREST/Middleware.hs @@ -2,11 +2,15 @@ Module : PostgREST.Middleware Description : Sets CORS policy. Also the PostgreSQL GUCs, role, search_path and pre-request function. -} -{-# OPTIONS_GHC -fno-warn-orphans #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE ScopedTypeVariables #-} - -module PostgREST.Middleware where +{-# LANGUAGE RecordWildCards #-} +module PostgREST.Middleware + ( runPgLocals + , pgrstFormat + , pgrstMiddleware + , defaultCorsPolicy + , corsPolicy + , optionalRollback + ) where import qualified Hasql.Decoders as HD import qualified Hasql.DynamicStatements.Statement as H @@ -23,8 +27,10 @@ import Data.Scientific (FPFormat (..), isInteger) import qualified Data.Text as T import qualified Hasql.Transaction as H +import qualified Network.HTTP.Types.Header as HTTP import Network.HTTP.Types.Status (Status, status400, status500, statusCode) +import qualified Network.Wai as Wai import Network.Wai.Logger (showSockAddr) import System.Log.FastLogger (toLogStr) @@ -35,8 +41,11 @@ import Network.Wai.Middleware.Gzip (def, gzip) import Network.Wai.Middleware.RequestLogger import Network.Wai.Middleware.Static (only, staticPolicy) +import qualified PostgREST.Types as Types + import PostgREST.ApiRequest (ApiRequest (..)) import PostgREST.Config (AppConfig (..)) +import PostgREST.Error (Error, errorResponseFor) import PostgREST.QueryBuilder (setConfigLocal) import PostgREST.Types (LogLevel (..)) import Protolude hiding (head, toS) @@ -45,13 +54,13 @@ import System.IO.Unsafe (unsafePerformIO) -- | Runs local(transaction scoped) GUCs for every request, plus the pre-request function runPgLocals :: AppConfig -> M.HashMap Text JSON.Value -> - (ApiRequest -> H.Transaction Response) -> - ApiRequest -> H.Transaction Response + (ApiRequest -> ExceptT Error H.Transaction Response) -> + ApiRequest -> ExceptT Error H.Transaction Response runPgLocals conf claims app req = do - H.statement mempty $ H.dynamicallyParameterized + lift $ H.statement mempty $ H.dynamicallyParameterized ("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql)) HD.noResult (configDbPreparedStatements conf) - traverse_ H.sql preReqSql + lift $ traverse_ H.sql preReqSql app req where methodSql = setConfigLocal mempty ("request.method", toS $ iMethod req) @@ -140,3 +149,30 @@ unquoted (JSON.Number n) = toS $ formatScientific Fixed (if isInteger n then Just 0 else Nothing) n unquoted (JSON.Bool b) = show b unquoted v = toS $ JSON.encode v + +-- | Set a transaction to eventually roll back if requested and set respective +-- headers on the response. +optionalRollback + :: AppConfig + -> ApiRequest + -> ExceptT Error H.Transaction Wai.Response + -> ExceptT Error H.Transaction Wai.Response +optionalRollback AppConfig{..} ApiRequest{..} transaction = do + resp <- catchError transaction $ return . errorResponseFor + when (shouldRollback || (configDbTxRollbackAll && not shouldCommit)) + (lift H.condemn) + return $ Wai.mapResponseHeaders preferenceApplied resp + where + shouldCommit = + configDbTxAllowOverride && iPreferTransaction == Just Types.Commit + shouldRollback = + configDbTxAllowOverride && iPreferTransaction == Just Types.Rollback + preferenceApplied + | shouldCommit = + Types.addHeadersIfNotIncluded + [(HTTP.hPreferenceApplied, BS.pack (show Types.Commit))] + | shouldRollback = + Types.addHeadersIfNotIncluded + [(HTTP.hPreferenceApplied, BS.pack (show Types.Rollback))] + | otherwise = + identity diff --git a/src/PostgREST/OpenAPI.hs b/src/PostgREST/OpenAPI.hs index c11b3316573..2d8f33c3443 100644 --- a/src/PostgREST/OpenAPI.hs +++ b/src/PostgREST/OpenAPI.hs @@ -2,17 +2,15 @@ Module : PostgREST.OpenAPI Description : Generates the OpenAPI output -} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +module PostgREST.OpenAPI (encode) where -module PostgREST.OpenAPI ( - encodeOpenAPI -, pickProxy -) where - -import qualified Data.HashSet.InsOrd as Set +import qualified Data.Aeson as JSON +import qualified Data.ByteString.Lazy as LBS +import qualified Data.HashMap.Strict as HashMap +import qualified Data.HashSet.InsOrd as Set import Control.Arrow ((&&&)) -import Data.Aeson (decode, encode) import Data.HashMap.Strict.InsOrd (InsOrdHashMap, fromList) import Data.Maybe (fromJust) import Data.String (IsString (..)) @@ -25,16 +23,29 @@ import Control.Lens import Data.Swagger import PostgREST.ApiRequest (ContentType (..)) -import PostgREST.Config (docsVersion, prettyVersion) +import PostgREST.Config (AppConfig (..), docsVersion, + prettyVersion) import PostgREST.Private.ProxyUri (isMalformedProxyUri, toURI) -import PostgREST.Types (Column (..), ForeignKey (..), - PgArg (..), PrimaryKey (..), +import PostgREST.Types (Column (..), DbStructure (..), + ForeignKey (..), PgArg (..), + PrimaryKey (..), ProcDescription (..), Proxy (..), - Table (..), toMime) + Table (..), tableCols, tableName, + tablePKCols, tableSchema, toMime) import Protolude hiding (Proxy, dropWhile, get, intercalate, toLower, toS, (&)) import Protolude.Conv (toS) +encode :: AppConfig -> DbStructure -> [Table] -> Maybe Text -> HashMap.HashMap k [ProcDescription] -> LBS.ByteString +encode conf dbStructure tables schemaDescription procs = + JSON.encode $ + postgrestSpec + (concat $ HashMap.elems procs) + (openApiTableInfo dbStructure <$> tables) + (proxyUri conf) + schemaDescription + (dbPrimaryKeys dbStructure) + makeMimeList :: [ContentType] -> MimeList makeMimeList cs = MimeList $ map (fromString . toS . toMime) cs @@ -63,7 +74,7 @@ makeTableDef pks (t, cs, _) = makeProperty :: [PrimaryKey] -> Column -> (Text, Referenced Schema) makeProperty pks c = (colName c, Inline s) where - e = if null $ colEnum c then Nothing else decode $ encode $ colEnum c + e = if null $ colEnum c then Nothing else JSON.decode $ JSON.encode $ colEnum c fk ForeignKey{fkCol=Column{colTable=Table{tableName=a}, colName=b}} = intercalate "" ["This is a Foreign Key to `", a, ".", b, "`."] pk :: Bool @@ -80,7 +91,7 @@ makeProperty pks c = (colName c, Inline s) colDescription c s = (mempty :: Schema) - & default_ .~ (decode . toS =<< colDefault c) + & default_ .~ (JSON.decode . toS =<< colDefault c) & description .~ d & enum_ .~ e & format ?~ colType c @@ -111,7 +122,7 @@ makePreferParam ts = & schema .~ ParamOther ((mempty :: ParamOtherSchema) & in_ .~ ParamHeader & type_ ?~ SwaggerString - & enum_ .~ decode (encode ts)) + & enum_ .~ JSON.decode (JSON.encode ts)) makeProcParam :: ProcDescription -> [Referenced Param] makeProcParam pd = @@ -162,7 +173,7 @@ makeParamDefs ti = & schema .~ ParamOther ((mempty :: ParamOtherSchema) & in_ .~ ParamHeader & type_ ?~ SwaggerString - & default_ .~ decode "\"items\"")) + & default_ .~ JSON.decode "\"items\"")) , ("offset", (mempty :: Param) & name .~ "offset" & description ?~ "Limiting and Pagination" @@ -303,9 +314,6 @@ postgrestSpec pds ti (s, h, p, b) sd pks = (mempty :: Swagger) h' = Just $ Host (unpack $ escapeHostName h) (Just (fromInteger p)) d = fromMaybe "This is a dynamic API generated by PostgREST" sd -encodeOpenAPI :: [ProcDescription] -> [(Table, [Column], [Text])] -> (Text, Text, Integer, Text) -> Maybe Text -> [PrimaryKey] -> LByteString -encodeOpenAPI pds ti uri sd pks = encode $ postgrestSpec pds ti uri sd pks - pickProxy :: Maybe Text -> Maybe Proxy pickProxy proxy | isNothing proxy = Nothing @@ -334,3 +342,18 @@ pickProxy proxy ("", "http") -> 80 ("", "https") -> 443 _ -> readPort $ unpack $ tail $ pack port' + +proxyUri :: AppConfig -> (Text, Text, Integer, Text) +proxyUri AppConfig{..} = + case pickProxy $ toS <$> configOpenApiServerProxyUri of + Just Proxy{..} -> + (proxyScheme, proxyHost, proxyPort, proxyPath) + Nothing -> + ("http", configServerHost, toInteger configServerPort, "/") + +openApiTableInfo :: DbStructure -> Table -> (Table, [Column], [Text]) +openApiTableInfo dbStructure table = + ( table + , tableCols dbStructure (tableSchema table) (tableName table) + , tablePKCols dbStructure (tableSchema table) (tableName table) + ) diff --git a/src/PostgREST/Statements.hs b/src/PostgREST/Statements.hs index c8e38554ce5..7bfebfa3cb2 100644 --- a/src/PostgREST/Statements.hs +++ b/src/PostgREST/Statements.hs @@ -45,7 +45,7 @@ import Text.InterpolatedString.Perl6 (q) is represented as a list of strings containing variable bindings like @"k1=eq.42"@, or the empty list if there is no location header. -} -type ResultsWithCount = (Maybe Int64, Int64, [BS.ByteString], BS.ByteString, Either SimpleError [GucHeader], Either SimpleError (Maybe Status)) +type ResultsWithCount = (Maybe Int64, Int64, [BS.ByteString], BS.ByteString, Either Error [GucHeader], Either Error (Maybe Status)) createWriteStatement :: H.Snippet -> H.Snippet -> Bool -> Bool -> Bool -> PreferRepresentation -> [Text] -> PgVersion -> Bool -> @@ -130,7 +130,7 @@ standardRow = (,,,,,) <$> nullableColumn HD.int8 <*> column HD.int8 <*> (fromMaybe (Right []) <$> nullableColumn decodeGucHeaders) <*> (fromMaybe (Right Nothing) <$> nullableColumn decodeGucStatus) -type ProcResults = (Maybe Int64, Int64, ByteString, Either SimpleError [GucHeader], Either SimpleError (Maybe Status)) +type ProcResults = (Maybe Int64, Int64, ByteString, Either Error [GucHeader], Either Error (Maybe Status)) callProcStatement :: Bool -> Bool -> H.Snippet -> H.Snippet -> H.Snippet -> Bool -> Bool -> Bool -> Bool -> Maybe FieldName -> PgVersion -> Bool -> @@ -189,10 +189,10 @@ createExplainStatement countQuery = let row = HD.singleRow $ column HD.bytea in (^? L.nth 0 . L.key "Plan" . L.key "Plan Rows" . L._Integral) <$> row -decodeGucHeaders :: HD.Value (Either SimpleError [GucHeader]) +decodeGucHeaders :: HD.Value (Either Error [GucHeader]) decodeGucHeaders = first (const GucHeadersError) . JSON.eitherDecode . toS <$> HD.bytea -decodeGucStatus :: HD.Value (Either SimpleError (Maybe Status)) +decodeGucStatus :: HD.Value (Either Error (Maybe Status)) decodeGucStatus = first (const GucStatusError) . fmap (Just . toEnum . fst) . decimal <$> HD.text -- | Get db settings from the connection role. Only used for configuration. diff --git a/test/SpecHelper.hs b/test/SpecHelper.hs index 4f92ab23e8f..80ebe3decaa 100644 --- a/test/SpecHelper.hs +++ b/test/SpecHelper.hs @@ -22,8 +22,7 @@ import Test.Hspec import Test.Hspec.Wai import Text.Heredoc -import PostgREST.Auth (parseSecret) -import PostgREST.Config (AppConfig (..)) +import PostgREST.Config (AppConfig (..), parseSecret) import PostgREST.Types (JSPathExp (..), LogLevel (..)) import Protolude hiding (toS) import Protolude.Conv (toS) diff --git a/test/io-tests/test_io.py b/test/io-tests/test_io.py index ac66a22cfc0..a53bb172092 100644 --- a/test/io-tests/test_io.py +++ b/test/io-tests/test_io.py @@ -297,6 +297,7 @@ def test_expected_config_from_db_settings(defaultenv, role, expectedconfig): assert dumpconfig(configpath=config, env=env) == expected + @pytest.mark.parametrize( "config", [conf for conf in CONFIGSDIR.iterdir() if conf.suffix == ".config"], @@ -604,6 +605,7 @@ def test_max_rows_notify_reload(defaultenv): # reset max-rows config on the db postgrest.session.post("/rpc/reset_max_rows_config") + def test_invalid_role_claim_key_notify_reload(defaultenv): "NOTIFY reload config should show an error if role-claim-key is invalid"