Skip to content

Commit

Permalink
refactor: App.hs and related changes (PostgREST#1725)
Browse files Browse the repository at this point in the history
* Use ExceptT to avoid 'staircasing' case analysis in App.hs
* Split large function in App.hs into individual handler functions
* Adapt API of Auth.hs, OpenApi.hs etc. to simplify the use of those modules in App.hs
* Split optional rollback functionality into Middleware
* Unify SimpleError and ApiRequestError into one Error type, so it can be used across modules
  • Loading branch information
monacoremo authored Feb 23, 2021
1 parent e12685d commit 2e6e4f6
Show file tree
Hide file tree
Showing 11 changed files with 752 additions and 585 deletions.
1 change: 1 addition & 0 deletions postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ library
, jose >= 0.8.1 && < 0.9
, lens >= 4.14 && < 4.20
, lens-aeson >= 1.0.1 && < 1.2
, mtl >= 2.2.2 && < 2.3
, network-uri >= 2.6.1 && < 2.8
, optparse-applicative >= 0.13 && < 0.17
, parsec >= 3.1.11 && < 3.2
Expand Down
953 changes: 535 additions & 418 deletions src/PostgREST/App.hs

Large diffs are not rendered by default.

164 changes: 62 additions & 102 deletions src/PostgREST/Auth.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-|
Module : PostgREST.Auth
Description : PostgREST authorization functions.
Expand All @@ -12,108 +10,70 @@ Authentication should always be implemented in an external service.
In the test suite there is an example of simple login function that can be used for a
very simple authentication system inside the PostgreSQL database.
-}
module PostgREST.Auth (
containsRole
, jwtClaims
, attemptJwtClaims
, parseSecret
) where
{-# LANGUAGE RecordWildCards #-}
module PostgREST.Auth (containsRole, jwtClaims, JWTClaims) where

import qualified Crypto.JOSE.Types as JOSE.Types
import qualified Crypto.JWT as JWT
import qualified Data.Aeson as JSON
import qualified Data.HashMap.Strict as M
import Data.Vector as V

import Control.Lens (set)
import Data.Time.Clock (UTCTime)

import Crypto.JWT

import PostgREST.Error (SimpleError (..))
import PostgREST.Types
import Protolude hiding (toS)
import Protolude.Conv (toS)

{-|
Possible situations encountered with client JWTs
-}
data JWTAttempt = JWTInvalid JWTError
| JWTMissingSecret
| JWTClaims (M.HashMap Text JSON.Value)


jwtClaims :: JWTAttempt -> Either SimpleError (M.HashMap Text JSON.Value)
jwtClaims attempt =
case attempt of
JWTMissingSecret -> Left JwtTokenMissing
JWTInvalid JWTExpired -> Left $ JwtTokenInvalid "JWT expired"
JWTInvalid e -> Left $ JwtTokenInvalid $ show e
JWTClaims claims -> Right claims

{-|
Receives the JWT secret and audience (from config) and a JWT and returns a map
of JWT claims.
-}
attemptJwtClaims :: Maybe JWKSet -> Maybe StringOrURI -> LByteString -> UTCTime -> JSPath -> IO JWTAttempt
attemptJwtClaims _ _ "" _ _ = return $ JWTClaims M.empty
attemptJwtClaims maybeSecret audience payload time jspath =
case maybeSecret of
Nothing -> return JWTMissingSecret
Just secret -> do
let validation = set allowedSkew 1 $ defaultJWTValidationSettings (maybe (const True) (==) audience)
eJwt <- runExceptT $ do
jwt <- decodeCompact payload
verifyClaimsAt validation secret time jwt
return $ case eJwt of
Left e -> JWTInvalid e
Right jwt -> JWTClaims $ claims2map jwt jspath

{-|
Turn JWT ClaimSet into something easier to work with,
also here the jspath is applied to put the "role" in the map
-}
claims2map :: ClaimsSet -> JSPath -> M.HashMap Text JSON.Value
claims2map claims jspath = (\case
import qualified Data.Vector as V

import Control.Lens (set)
import Control.Monad.Except (liftEither)
import Data.Either.Combinators (mapLeft)
import Data.Time.Clock (UTCTime)

import PostgREST.Config (AppConfig (..))
import PostgREST.Error (Error (..))
import PostgREST.Types (JSPath, JSPathExp (..))

import Protolude


type JWTClaims = M.HashMap Text JSON.Value

-- | Receives the JWT secret and audience (from config) and a JWT and returns a
-- map of JWT claims.
jwtClaims :: Monad m =>
AppConfig -> LByteString -> UTCTime -> ExceptT Error m JWTClaims
jwtClaims _ "" _ = return M.empty
jwtClaims AppConfig{..} payload time = do
secret <- liftEither . maybeToRight JwtTokenMissing $ configJWKS
eitherClaims <-
lift . runExceptT $
JWT.verifyClaimsAt validation secret time =<< JWT.decodeCompact payload
liftEither . mapLeft jwtClaimsError $ claimsMap configJwtRoleClaimKey <$> eitherClaims
where
validation =
JWT.defaultJWTValidationSettings audienceCheck & set JWT.allowedSkew 1

audienceCheck :: JWT.StringOrURI -> Bool
audienceCheck = maybe (const True) (==) configJwtAudience

jwtClaimsError :: JWT.JWTError -> Error
jwtClaimsError JWT.JWTExpired = JwtTokenInvalid "JWT expired"
jwtClaimsError e = JwtTokenInvalid $ show e

-- | Turn JWT ClaimSet into something easier to work with.
--
-- Also, here the jspath is applied to put the "role" in the map.
claimsMap :: JSPath -> JWT.ClaimsSet -> JWTClaims
claimsMap jspath claims =
case JSON.toJSON claims of
val@(JSON.Object o) ->
let role = maybe M.empty (M.singleton "role") $
walkJSPath (Just val) jspath in
M.delete "role" o `M.union` role -- mutating the map
_ -> M.empty
) $ JSON.toJSON claims

walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value
walkJSPath x [] = x
walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (M.lookup key o) rest
walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest
walkJSPath _ _ = Nothing

{-|
Whether a response from jwtClaims contains a role claim
-}
containsRole :: M.HashMap Text JSON.Value -> Bool
M.delete "role" o `M.union` role val
_ ->
M.empty
where
role value =
maybe M.empty (M.singleton "role") $ walkJSPath (Just value) jspath

walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value
walkJSPath x [] = x
walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (M.lookup key o) rest
walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest
walkJSPath _ _ = Nothing

-- | Whether a response from jwtClaims contains a role claim
containsRole :: JWTClaims -> Bool
containsRole = M.member "role"

{-|
Parse `jwt-secret` configuration option and turn into a JWKSet.
There are three ways to specify `jwt-secret`: text secret, JSON Web Key
(JWK), or JSON Web Key Set (JWKS). The first two are converted into a JWKSet
with one key and the last is converted as is.
-}
parseSecret :: ByteString -> JWKSet
parseSecret str =
fromMaybe (maybe secret (\jwk' -> JWKSet [jwk']) maybeJWK)
maybeJWKSet
where
maybeJWKSet = JSON.decode (toS str) :: Maybe JWKSet
maybeJWK = JSON.decode (toS str) :: Maybe JWK
secret = JWKSet [jwkFromSecret str]

{-|
Internal helper to generate a symmetric HMAC-SHA256 JWK from a text secret.
-}
jwkFromSecret :: ByteString -> JWK
jwkFromSecret key =
fromKeyMaterial km
where
km = OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets key))
55 changes: 38 additions & 17 deletions src/PostgREST/Config.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,24 @@ Other hardcoded options such as the minimum version number also belong here.
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}

module PostgREST.Config ( prettyVersion
, docsVersion
, CLI (..)
, Command (..)
, AppConfig (..)
, configDbPoolTimeout'
, dumpAppConfig
, Environment
, readCLIShowHelp
, readEnvironment
, readConfig
)
where

module PostgREST.Config
( prettyVersion
, docsVersion
, CLI (..)
, Command (..)
, AppConfig (..)
, configDbPoolTimeout'
, dumpAppConfig
, Environment
, readCLIShowHelp
, readEnvironment
, readConfig
, parseSecret
) where

import qualified Crypto.JOSE.Types as JOSE
import qualified Crypto.JWT as JWT
import qualified Data.Aeson as JSON
import qualified Data.ByteString as B
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Char8 as BS
Expand All @@ -42,7 +46,7 @@ import qualified Data.Map.Strict as M

import Control.Lens (preview)
import Control.Monad (fail)
import Crypto.JWT (JWKSet, StringOrURI, stringOrUri)
import Crypto.JWT (JWK, JWKSet, StringOrURI, stringOrUri)
import Data.Aeson (encode, toJSON)
import Data.Either.Combinators (mapLeft)
import Data.List (lookup)
Expand All @@ -63,9 +67,8 @@ import System.Posix.Types (FileMode)
import Control.Applicative
import Data.Monoid
import Options.Applicative hiding (str)
import Text.Heredoc
import Text.Heredoc (str)

import PostgREST.Auth (parseSecret)
import PostgREST.Parsers (pRoleClaimKey)
import PostgREST.Private.ProxyUri (isMalformedProxyUri)
import PostgREST.Types (JSPath, JSPathExp (..),
Expand Down Expand Up @@ -593,3 +596,21 @@ loadDbUriFile conf = extractDbUri mDbUri
Nothing -> return dbUri
Just filename -> strip <$> readFile (toS filename)
setDbUri dbUri = conf {configDbUri = dbUri}


{-|
Parse `jwt-secret` configuration option and turn into a JWKSet.
There are three ways to specify `jwt-secret`: text secret, JSON Web Key
(JWK), or JSON Web Key Set (JWKS). The first two are converted into a JWKSet
with one key and the last is converted as is.
-}
parseSecret :: ByteString -> JWKSet
parseSecret bytes =
fromMaybe (maybe secret (\jwk' -> JWT.JWKSet [jwk']) maybeJWK)
maybeJWKSet
where
maybeJWKSet = JSON.decode (toS bytes) :: Maybe JWKSet
maybeJWK = JSON.decode (toS bytes) :: Maybe JWK
secret = JWT.JWKSet [JWT.fromKeyMaterial keyMaterial]
keyMaterial = JWT.OctKeyMaterial . JWT.OctKeyParameters $ JOSE.Base64Octets bytes
12 changes: 5 additions & 7 deletions src/PostgREST/DbRequestBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ This module is in charge of building an intermediate representation(ReadRequest,
A query tree is built in case of resource embedding. By inferring the relationship between tables, join conditions are added for every embedded resource.
-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
Expand All @@ -28,10 +27,9 @@ import Data.Text (isInfixOf)

import Control.Applicative
import Data.Tree
import Network.Wai

import PostgREST.ApiRequest (Action (..), ApiRequest (..))
import PostgREST.Error (ApiRequestError (..), errorResponseFor)
import PostgREST.Error (ApiRequestError (..), Error (..))
import PostgREST.Parsers
import PostgREST.RangeQuery (NonnegRange, allRange, restrictRange)
import PostgREST.Types
Expand All @@ -40,9 +38,9 @@ import Protolude hiding (from)
-- | Builds the ReadRequest tree on a number of stages.
-- | Adds filters, order, limits on its respective nodes.
-- | Adds joins conditions obtained from resource embedding.
readRequest :: Schema -> TableName -> Maybe Integer -> [Relation] -> ApiRequest -> Either Response ReadRequest
readRequest :: Schema -> TableName -> Maybe Integer -> [Relation] -> ApiRequest -> Either Error ReadRequest
readRequest schema rootTableName maxRows allRels apiRequest =
mapLeft errorResponseFor $
mapLeft ApiRequestError $
treeRestrictRange maxRows =<<
augmentRequestWithJoin schema rootRels =<<
addFiltersOrdersRanges apiRequest =<<
Expand Down Expand Up @@ -281,8 +279,8 @@ addProperty f (targetNodeName:remainingPath, a) (Node rn forest) =
where
pathNode = find (\(Node (_,(nodeName,_,alias,_,_)) _) -> nodeName == targetNodeName || alias == Just targetNodeName) forest

mutateRequest :: Schema -> TableName -> ApiRequest -> [FieldName] -> ReadRequest -> Either Response MutateRequest
mutateRequest schema tName apiRequest pkCols readReq = mapLeft errorResponseFor $
mutateRequest :: Schema -> TableName -> ApiRequest -> [FieldName] -> ReadRequest -> Either Error MutateRequest
mutateRequest schema tName apiRequest pkCols readReq = mapLeft ApiRequestError $
case action of
ActionCreate -> do
confCols <- case iOnConflict apiRequest of
Expand Down
24 changes: 17 additions & 7 deletions src/PostgREST/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ Module : PostgREST.Error
Description : PostgREST error HTTP responses
-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RecordWildCards #-}

module PostgREST.Error (
errorResponseFor
, ApiRequestError(..)
, PgError(..)
, SimpleError(..)
, Error(..)
, errorPayload
, checkIsFatal
, singularityError
Expand Down Expand Up @@ -220,7 +219,7 @@ checkIsFatal (PgError _ (P.SessionError (H.QueryError _ _ (H.ResultError (H.Serv
checkIsFatal _ = Nothing


data SimpleError
data Error
= GucHeadersError
| GucStatusError
| BinaryFieldError ContentType
Expand All @@ -231,8 +230,11 @@ data SimpleError
| JwtTokenInvalid Text
| SingularityError Integer
| ContentTypeError [ByteString]
| NotFound
| ApiRequestError ApiRequestError
| PgErr PgError

instance PgrstError SimpleError where
instance PgrstError Error where
status GucHeadersError = HT.status500
status GucStatusError = HT.status500
status (BinaryFieldError _) = HT.status406
Expand All @@ -243,12 +245,17 @@ instance PgrstError SimpleError where
status (JwtTokenInvalid _) = HT.unauthorized401
status (SingularityError _) = HT.status406
status (ContentTypeError _) = HT.status415
status NotFound = HT.status404
status (PgErr err) = status err
status (ApiRequestError err) = status err

headers (SingularityError _) = [toHeader CTSingularJSON]
headers (JwtTokenInvalid m) = [toHeader CTApplicationJSON, invalidTokenHeader m]
headers (PgErr err) = headers err
headers (ApiRequestError err) = headers err
headers _ = [toHeader CTApplicationJSON]

instance JSON.ToJSON SimpleError where
instance JSON.ToJSON Error where
toJSON GucHeadersError = JSON.object [
"message" .= ("response.headers guc must be a JSON array composed of objects with a single key and a string value" :: Text)]
toJSON GucStatusError = JSON.object [
Expand All @@ -273,10 +280,13 @@ instance JSON.ToJSON SimpleError where
"message" .= ("Server lacks JWT secret" :: Text)]
toJSON (JwtTokenInvalid message) = JSON.object [
"message" .= (message :: Text)]
toJSON NotFound = JSON.object []
toJSON (PgErr err) = JSON.toJSON err
toJSON (ApiRequestError err) = JSON.toJSON err

invalidTokenHeader :: Text -> Header
invalidTokenHeader m =
("WWW-Authenticate", "Bearer error=\"invalid_token\", " <> "error_description=" <> encodeUtf8 (show m))

singularityError :: (Integral a) => a -> SimpleError
singularityError :: (Integral a) => a -> Error
singularityError = SingularityError . toInteger
Loading

0 comments on commit 2e6e4f6

Please sign in to comment.