Skip to content

Commit

Permalink
shuffle things around to clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
monacoremo committed Jan 16, 2021
1 parent b90515c commit bd0dd75
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 144 deletions.
128 changes: 25 additions & 103 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import Data.Time.Clock (UTCTime)

import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as LazyBS
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Set as Set
import qualified Hasql.DynamicStatements.Snippet as SQL
import qualified Hasql.Pool as SQL
Expand All @@ -46,6 +45,7 @@ import qualified PostgREST.Types as Types

import PostgREST.ApiRequest (ApiRequest)
import PostgREST.Config (AppConfig)
import PostgREST.Error (Error)
import PostgREST.Types (ContentType, DbStructure,
QualifiedIdentifier, ReadRequest)

Expand All @@ -61,11 +61,10 @@ data RequestContext =
, rContentType :: ContentType
}

type Handler = ExceptT Error.SimpleError
type Handler = ExceptT Error

type DbHandler = Handler SQL.Transaction


-- | PostgREST application
postgrest
:: Types.LogLevel
Expand All @@ -84,7 +83,7 @@ postgrest logLev refConf refDbStructure pool getTime connWorker =
maybeDbStructure <- readIORef refDbStructure

let
eitherResponse :: IO (Either Error.SimpleError Wai.Response)
eitherResponse :: IO (Either Error Wai.Response)
eitherResponse =
runExceptT $ postgrestResponse conf maybeDbStructure pool time req

Expand Down Expand Up @@ -128,11 +127,17 @@ postgrestResponse conf maybeDbStructure pool time req =
-- The JWT must be checked before touching the db
jwtClaims <- Auth.jwtClaims conf (toS $ Req.iJWT apiRequest) time

let
reqContentTypes = requestContentTypes conf apiRequest
acceptContentType = Req.iAccepts apiRequest

contentType <-
liftEither . maybeToRight (contentTypeError apiRequest) $
Req.mutuallyAgreeable
(requestContentTypes conf apiRequest)
(Req.iAccepts apiRequest)
case Req.mutuallyAgreeable reqContentTypes acceptContentType of
Just ct ->
return ct

Nothing ->
throwError . Error.ContentTypeError $ map Types.toMime acceptContentType

let
context apiReq =
Expand All @@ -142,7 +147,7 @@ postgrestResponse conf maybeDbStructure pool time req =
handleRequest (context apiReq)

runDbHandler pool (txMode apiRequest) jwtClaims $
optionalRollback conf apiRequest $
Middleware.optionalRollback conf apiRequest $
Middleware.runPgLocals conf jwtClaims handleReq apiRequest

runDbHandler :: SQL.Pool -> SQL.Mode -> Auth.JWTClaims -> DbHandler a -> Handler IO a
Expand Down Expand Up @@ -383,8 +388,7 @@ handleDelete identifier context@(RequestContext _ _ apiRequest contentType) =
else
response HTTP.status204 [contentRangeHeader] mempty

handleInfo :: Monad m =>
QualifiedIdentifier -> DbStructure -> Handler m Wai.Response
handleInfo :: Monad m => QualifiedIdentifier -> DbStructure -> Handler m Wai.Response
handleInfo identifier dbStructure =
let
allowH table =
Expand All @@ -394,8 +398,12 @@ handleInfo identifier dbStructure =

allOrigins =
("Access-Control-Allow-Origin", "*")

tableMatches table =
Types.tableName table == Types.qiName identifier
&& Types.tableSchema table == Types.qiSchema identifier
in
case findTable dbStructure identifier of
case find tableMatches (Types.dbTables dbStructure) of
Just table ->
return $ Wai.responseLBS HTTP.status200 [allOrigins, allowH table] mempty

Expand Down Expand Up @@ -465,21 +473,13 @@ handleInvoke invMethod proc context@(RequestContext conf dbStructure apiRequest

handleOpenApi :: Bool -> Types.Schema -> RequestContext -> DbHandler Wai.Response
handleOpenApi headersOnly tSchema (RequestContext conf dbStructure apiRequest _) =
let
encodeApi tables schemaDescription procs =
OpenAPI.encodeOpenAPI
(concat $ HashMap.elems procs)
(fmap (openApiTableInfo dbStructure) tables)
(openApiUri conf)
schemaDescription
(Types.dbPrimaryKeys dbStructure)
in
do
body <-
lift $ encodeApi
<$> SQL.statement tSchema DbStructure.accessibleTables
<*> SQL.statement tSchema DbStructure.schemaDescription
<*> SQL.statement tSchema DbStructure.accessibleProcs
lift $
OpenAPI.encode conf dbStructure
<$> SQL.statement tSchema DbStructure.accessibleTables
<*> SQL.statement tSchema DbStructure.schemaDescription
<*> SQL.statement tSchema DbStructure.accessibleProcs

return $
Wai.responseLBS HTTP.status200
Expand Down Expand Up @@ -662,86 +662,8 @@ profileHeader :: ApiRequest -> Maybe HTTP.Header
profileHeader apiRequest =
(,) "Content-Profile" <$> (toS <$> Req.iProfile apiRequest)

contentTypeError :: ApiRequest -> Error.SimpleError
contentTypeError apiRequest =
Error.ContentTypeError $ map Types.toMime (Req.iAccepts apiRequest)

-- MIDDLEWARE

-- |
-- Set a transaction to eventually roll back if requested and set respective
-- headers on the response.
optionalRollback
:: AppConfig
-> ApiRequest
-> DbHandler Wai.Response
-> DbHandler Wai.Response
optionalRollback conf apiRequest transaction =
let
shouldCommit =
Config.configDbTxAllowOverride conf
&& Req.iPreferTransaction apiRequest == Just Types.Commit

shouldRollback =
Config.configDbTxAllowOverride conf
&& Req.iPreferTransaction apiRequest == Just Types.Rollback

preferenceApplied
| shouldCommit =
Types.addHeadersIfNotIncluded
[(HTTP.hPreferenceApplied, BS.pack (show Types.Commit))]
| shouldRollback =
Types.addHeadersIfNotIncluded
[(HTTP.hPreferenceApplied, BS.pack (show Types.Rollback))]
| otherwise =
identity
in
do
resp <- transaction

when (shouldRollback || (Config.configDbTxRollbackAll conf && not shouldCommit))
(lift SQL.condemn)

return $ Wai.mapResponseHeaders preferenceApplied resp

openApiUri :: AppConfig -> (Text, Text, Integer, Text)
openApiUri conf =
case OpenAPI.pickProxy $ toS <$> Config.configOpenApiServerProxyUri conf of
Just proxy ->
( Types.proxyScheme proxy
, Types.proxyHost proxy
, Types.proxyPort proxy
, Types.proxyPath proxy
)

Nothing ->
("http"
, Config.configServerHost conf
, toInteger $ Config.configServerPort conf
, "/"
)

findTable :: DbStructure -> Types.QualifiedIdentifier -> Maybe Types.Table
findTable dbStructure identifier =
find tableMatches (Types.dbTables dbStructure)
where
tableMatches table =
Types.tableName table == Types.qiName identifier
&& Types.tableSchema table == Types.qiSchema identifier

splitKeyValue :: ByteString -> (ByteString, ByteString)
splitKeyValue kv =
(k, BS.tail v)
where
(k, v) = BS.break (== '=') kv

openApiTableInfo :: DbStructure -> Types.Table -> (Types.Table, [Types.Column], [Text])
openApiTableInfo dbStructure table =
let
schema = Types.tableSchema table
name = Types.tableName table
in
( table
, Types.tableCols dbStructure schema name
, Types.tablePKCols dbStructure schema name
)
12 changes: 5 additions & 7 deletions src/PostgREST/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import Data.Time.Clock (UTCTime)

import Crypto.JWT (JWTError (..))

import PostgREST.Error (SimpleError (..))
import PostgREST.Error (Error (..))
import PostgREST.Types (JSPath, JSPathExp (..))

import PostgREST.Config as Config
Expand All @@ -37,7 +37,7 @@ type JWTClaims = M.HashMap Text JSON.Value
-- | Receives the JWT secret and audience (from config) and a JWT and returns a
-- map of JWT claims.
jwtClaims :: Monad m =>
Config.AppConfig -> LByteString -> UTCTime -> ExceptT SimpleError m JWTClaims
Config.AppConfig -> LByteString -> UTCTime -> ExceptT Error m JWTClaims
jwtClaims _ "" _ = return M.empty
jwtClaims conf payload time =
do
Expand All @@ -51,15 +51,13 @@ jwtClaims conf payload time =
set JWT.allowedSkew 1 $ JWT.defaultJWTValidationSettings audienceCheck

audienceCheck :: JWT.StringOrURI -> Bool
audienceCheck =
maybe (const True) (==) (Config.configJwtAudience conf)
audienceCheck = maybe (const True) (==) (Config.configJwtAudience conf)

jwtClaimsError :: JWTError -> SimpleError
jwtClaimsError :: JWTError -> Error
jwtClaimsError JWTExpired = JwtTokenInvalid "JWT expired"
jwtClaimsError e = JwtTokenInvalid $ show e

jspath =
rightToMaybe $ Config.configJwtRoleClaimKey conf
jspath = rightToMaybe $ Config.configJwtRoleClaimKey conf

-- | Turn JWT ClaimSet into something easier to work with.
--
Expand Down
4 changes: 2 additions & 2 deletions src/PostgREST/Config.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ import System.Posix.Types (FileMode)

import Control.Applicative
import Data.Monoid
import Options.Applicative hiding (str)
import Text.Heredoc (str)
import Options.Applicative hiding (str)
import Text.Heredoc (str)

import PostgREST.Parsers (pRoleClaimKey)
import PostgREST.Private.ProxyUri (isMalformedProxyUri)
Expand Down
6 changes: 3 additions & 3 deletions src/PostgREST/DbRequestBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import Control.Applicative
import Data.Tree

import PostgREST.ApiRequest (Action (..), ApiRequest (..))
import PostgREST.Error (ApiRequestError (..), SimpleError (..))
import PostgREST.Error (ApiRequestError (..), Error (..))
import PostgREST.Parsers
import PostgREST.RangeQuery (NonnegRange, allRange, restrictRange)
import PostgREST.Types
Expand All @@ -39,7 +39,7 @@ import Protolude hiding (from)
-- | Builds the ReadRequest tree on a number of stages.
-- | Adds filters, order, limits on its respective nodes.
-- | Adds joins conditions obtained from resource embedding.
readRequest :: Schema -> TableName -> Maybe Integer -> [Relation] -> ApiRequest -> Either SimpleError ReadRequest
readRequest :: Schema -> TableName -> Maybe Integer -> [Relation] -> ApiRequest -> Either Error ReadRequest
readRequest schema rootTableName maxRows allRels apiRequest =
mapLeft ApiRequestError $
treeRestrictRange maxRows =<<
Expand Down Expand Up @@ -280,7 +280,7 @@ addProperty f (targetNodeName:remainingPath, a) (Node rn forest) =
where
pathNode = find (\(Node (_,(nodeName,_,alias,_,_)) _) -> nodeName == targetNodeName || alias == Just targetNodeName) forest

mutateRequest :: Schema -> TableName -> ApiRequest -> [FieldName] -> ReadRequest -> Either SimpleError MutateRequest
mutateRequest :: Schema -> TableName -> ApiRequest -> [FieldName] -> ReadRequest -> Either Error MutateRequest
mutateRequest schema tName apiRequest pkCols readReq = mapLeft ApiRequestError $
case action of
ActionCreate -> do
Expand Down
10 changes: 5 additions & 5 deletions src/PostgREST/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module PostgREST.Error (
errorResponseFor
, ApiRequestError(..)
, PgError(..)
, SimpleError(..)
, Error(..)
, errorPayload
, checkIsFatal
, singularityError
Expand Down Expand Up @@ -220,7 +220,7 @@ checkIsFatal (PgError _ (P.SessionError (H.QueryError _ _ (H.ResultError (H.Serv
checkIsFatal _ = Nothing


data SimpleError
data Error
= GucHeadersError
| GucStatusError
| BinaryFieldError ContentType
Expand All @@ -235,7 +235,7 @@ data SimpleError
| ApiRequestError ApiRequestError
| PgErr PgError

instance PgrstError SimpleError where
instance PgrstError Error where
status GucHeadersError = HT.status500
status GucStatusError = HT.status500
status (BinaryFieldError _) = HT.status406
Expand All @@ -256,7 +256,7 @@ instance PgrstError SimpleError where
headers (ApiRequestError err) = headers err
headers _ = [toHeader CTApplicationJSON]

instance JSON.ToJSON SimpleError where
instance JSON.ToJSON Error where
toJSON GucHeadersError = JSON.object [
"message" .= ("response.headers guc must be a JSON array composed of objects with a single key and a string value" :: Text)]
toJSON GucStatusError = JSON.object [
Expand Down Expand Up @@ -289,5 +289,5 @@ invalidTokenHeader :: Text -> Header
invalidTokenHeader m =
("WWW-Authenticate", "Bearer error=\"invalid_token\", " <> "error_description=" <> encodeUtf8 (show m))

singularityError :: (Integral a) => a -> SimpleError
singularityError :: (Integral a) => a -> Error
singularityError = SingularityError . toInteger
Loading

0 comments on commit bd0dd75

Please sign in to comment.