diff --git a/src/PostgREST/DbStructure/Proc.hs b/src/PostgREST/DbStructure/Proc.hs index 59927b4034..1990f6fad7 100644 --- a/src/PostgREST/DbStructure/Proc.hs +++ b/src/PostgREST/DbStructure/Proc.hs @@ -8,7 +8,6 @@ module PostgREST.DbStructure.Proc , ProcVolatility(..) , ProcsMap , RetType(..) - , findProc , procReturnsScalar , procReturnsSingle , procTableName @@ -72,28 +71,6 @@ instance Ord ProcDescription where -- | It uses a HashMap for a faster lookup. type ProcsMap = M.HashMap QualifiedIdentifier [ProcDescription] -{-| - Search a pg procedure by its parameters. Since a function can be overloaded, the name is not enough to find it. - An overloaded function can have a different volatility or even a different return type. - Ideally, handling overloaded functions should be left to pg itself. But we need to know certain proc attributes in advance. --} -findProc :: QualifiedIdentifier -> S.Set Text -> Bool -> ProcsMap -> ProcDescription -findProc qi payloadKeys paramsAsSingleObject allProcs = fromMaybe fallback bestMatch - where - -- instead of passing Maybe ProcDescription around, we create a fallback description here when we can't find a matching function - -- args is empty, but because "specifiedProcArgs" will fill the missing arguments with default type text, this is not a problem - fallback = ProcDescription (qiSchema qi) (qiName qi) Nothing mempty (SetOf $ Composite $ QualifiedIdentifier mempty "record") Volatile False - bestMatch = - case M.lookup qi allProcs of - Nothing -> Nothing - Just [proc] -> Just proc -- if it's not an overloaded function then immediately get the ProcDescription - Just procs -> find matches procs -- Handle overloaded functions case - matches proc = - if paramsAsSingleObject - -- if the arg is not of json type let the db give the err - then length (pdArgs proc) == 1 - else payloadKeys `S.isSubsetOf` S.fromList (pgaName <$> pdArgs proc) - {-| Search the procedure parameters by matching them with the specified keys. If the key doesn't match a parameter, a parameter with a default type "text" is assumed. diff --git a/src/PostgREST/Error.hs b/src/PostgREST/Error.hs index 690acdc968..de5708eec6 100644 --- a/src/PostgREST/Error.hs +++ b/src/PostgREST/Error.hs @@ -29,6 +29,8 @@ import Network.HTTP.Types.Header (Header) import PostgREST.ContentType (ContentType (..)) import qualified PostgREST.ContentType as ContentType +import PostgREST.DbStructure.Proc (PgArg (..), + ProcDescription (..)) import PostgREST.DbStructure.Relationship (Cardinality (..), Junction (..), Relationship (..)) @@ -57,6 +59,8 @@ data ApiRequestError | ParseRequestError Text Text | NoRelBetween Text Text | AmbiguousRelBetween Text Text [Relationship] + | AmbiguousRpc [ProcDescription] + | NoRpc Text Text [Text] Bool | InvalidFilters | UnacceptableSchema [Text] | ContentTypeError [ByteString] @@ -71,6 +75,8 @@ instance PgrstError ApiRequestError where status (ParseRequestError _ _) = HT.status400 status (NoRelBetween _ _) = HT.status400 status AmbiguousRelBetween{} = HT.status300 + status (AmbiguousRpc _) = HT.status300 + status NoRpc{} = HT.status404 status (UnacceptableSchema _) = HT.status406 status (ContentTypeError _) = HT.status415 @@ -92,6 +98,12 @@ instance JSON.ToJSON ApiRequestError where "hint" .= ("By following the 'details' key, disambiguate the request by changing the url to /origin?select=relationship(*) or /origin?select=target!relationship(*)" :: Text), "message" .= ("More than one relationship was found for " <> parent <> " and " <> child :: Text), "details" .= (compressedRel <$> rels) ] + toJSON (AmbiguousRpc procs) = JSON.object [ + "hint" .= ("Overloaded functions with the same argument name but different types are not supported" :: Text), + "message" .= ("Could not choose the best candidate function between: " <> T.intercalate ", " [pdSchema p <> "." <> pdName p <> "(" <> T.intercalate ", " [pgaName a <> " => " <> pgaType a | a <- pdArgs p] <> ")" | p <- procs])] + toJSON (NoRpc schema procName payloadKeys hasPreferSingleObject) = JSON.object [ + "hint" .= ("If a new function was created in the database with this name and arguments, try reloading the schema cache." :: Text), + "message" .= ("Could not find the " <> schema <> "." <> procName <> (if hasPreferSingleObject then " function with a single json or jsonb argument" else "(" <> T.intercalate ", " payloadKeys <> ")" <> " function") <> " in the schema cache")] toJSON UnsupportedVerb = JSON.object [ "message" .= ("Unsupported HTTP verb" :: Text)] toJSON InvalidFilters = JSON.object [ diff --git a/src/PostgREST/Request/ApiRequest.hs b/src/PostgREST/Request/ApiRequest.hs index 0ea07584ab..dab8da1b7d 100644 --- a/src/PostgREST/Request/ApiRequest.hs +++ b/src/PostgREST/Request/ApiRequest.hs @@ -53,7 +53,7 @@ import PostgREST.DbStructure.Identifiers (FieldName, Schema) import PostgREST.DbStructure.Proc (PgArg (..), ProcDescription (..), - findProc) + ProcsMap) import PostgREST.Error (ApiRequestError (..)) import PostgREST.Query.SqlFragment (ftsOperators, operators) import PostgREST.RangeQuery (NonnegRange, allRange, @@ -95,6 +95,16 @@ data Action = ActionCreate | ActionRead{isHead :: Bool} | ActionSingleUpsert | ActionInvoke InvokeMethod | ActionInfo | ActionInspect{isHead :: Bool} deriving Eq +-- | The path info that will be mapped to a target (used to handle validations and errors before defining the Target) +data Path + = PathInfo + { pSchema :: Schema, + pName :: Text, + pHasRpc :: Bool, + pIsDefaultSpec :: Bool, + pIsRootSpec :: Bool + } + | PathUnknown -- | The target db object of a user action data Target = TargetIdent QualifiedIdentifier | TargetProc{tProc :: ProcDescription, tpIsRootSpec :: Bool} @@ -127,6 +137,12 @@ jsonRpcParams proc prms = mergeParams (Variadic a) (Variadic b) = Variadic $ b ++ a mergeParams v _ = v -- repeated params for non-variadic arguments are not merged +targetToJsonRpcParams :: Maybe Target -> [(Text, Text)] -> Maybe PayloadJSON +targetToJsonRpcParams target params = + case target of + Just TargetProc{tProc} -> Just $ jsonRpcParams tProc params + _ -> Nothing + {-| Describes what the user wants to do. This data type is a translation of the raw elements of an HTTP request into domain @@ -171,10 +187,11 @@ userApiRequest conf@AppConfig{..} dbStructure req reqBody | shouldParsePayload && isLeft payload = either (Left . InvalidBody . toS) witness payload | isLeft parsedColumns = either Left witness parsedColumns | otherwise = do - acceptContentType <- findAcceptContentType conf action target accepts + acceptContentType <- findAcceptContentType conf action path accepts + checkedTarget <- target return ApiRequest { iAction = action - , iTarget = target + , iTarget = checkedTarget , iRange = ranges , iTopLevelRange = topLevelRange , iPayload = relevantPayload @@ -231,12 +248,12 @@ userApiRequest conf@AppConfig{..} dbStructure req reqBody ((<> ".") <$> "not":M.keys operators) ++ ((<> "(") <$> M.keys ftsOperators) isEmbedPath = T.isInfixOf "." - isTargetingProc = case target of - TargetProc _ _ -> True - _ -> False - isTargetingDefaultSpec = case target of - TargetDefaultSpec _ -> True - _ -> False + isTargetingProc = case path of + PathInfo{pHasRpc, pIsRootSpec} -> pHasRpc || pIsRootSpec + _ -> False + isTargetingDefaultSpec = case path of + PathInfo{pIsDefaultSpec=True} -> True + _ -> False contentType = ContentType.decodeContentType . fromMaybe "application/json" $ lookupHeader "content-type" columns | action `elem` [ActionCreate, ActionUpdate, ActionInvoke InvPost] = toS <$> join (lookup "columns" qParams) @@ -263,13 +280,8 @@ userApiRequest conf@AppConfig{..} dbStructure req reqBody json <- csvToJson <$> CSV.decodeByName reqBody note "All lines must have same number of fields" $ payloadAttributes (JSON.encode json) json CTUrlEncoded -> - let urlEncodedBody = parseSimpleQuery $ toS reqBody in - case target of - TargetProc{tProc} -> - Right $ jsonRpcParams tProc $ (toS *** toS) <$> urlEncodedBody - _ -> - let paramsMap = M.fromList $ (toS *** JSON.String . toS) <$> urlEncodedBody in - Right $ ProcessedJSON (JSON.encode paramsMap) $ S.fromList (M.keys paramsMap) + let paramsMap = M.fromList $ (toS *** JSON.String . toS) <$> parseSimpleQuery (toS reqBody) in + Right $ ProcessedJSON (JSON.encode paramsMap) $ S.fromList (M.keys paramsMap) ct -> Left $ toS $ "Content-Type not acceptable: " <> ContentType.toMime ct topLevelRange = fromMaybe allRange $ M.lookup "limit" ranges -- if no limit is specified, get all the request rows @@ -313,23 +325,32 @@ userApiRequest conf@AppConfig{..} dbStructure req reqBody callFindProc procSch procNam = findProc (QualifiedIdentifier procSch procNam) payloadColumns (hasPrefer (show SingleObject)) $ dbProcs dbStructure in case path of - [] -> case configDbRootSpec of - Just (QualifiedIdentifier pSch pName) -> TargetProc (callFindProc (if pSch == mempty then schema else pSch) pName) True - Nothing | configOpenApiMode == OADisabled -> TargetUnknown - | otherwise -> TargetDefaultSpec schema - [table] -> TargetIdent $ QualifiedIdentifier schema table - ["rpc", pName] -> TargetProc (callFindProc schema pName) False - _ -> TargetUnknown - - shouldParsePayload = action `elem` [ActionCreate, ActionUpdate, ActionSingleUpsert, ActionInvoke InvPost] - relevantPayload = case (target, action) of + PathInfo{pSchema, pName, pHasRpc, pIsRootSpec, pIsDefaultSpec} + | pHasRpc || pIsRootSpec -> (`TargetProc` pIsRootSpec) <$> callFindProc pSchema pName + | pIsDefaultSpec -> Right $ TargetDefaultSpec pSchema + | otherwise -> Right $ TargetIdent $ QualifiedIdentifier pSchema pName + PathUnknown -> Right TargetUnknown + + shouldParsePayload = case (contentType, action) of + (CTUrlEncoded, ActionInvoke InvPost) -> False + (_, act) -> act `elem` [ActionCreate, ActionUpdate, ActionSingleUpsert, ActionInvoke InvPost] + relevantPayload = case (contentType, action) of -- Though ActionInvoke GET/HEAD doesn't really have a payload, we use the payload variable as a way -- to store the query string arguments to the function. - (TargetProc{tProc}, ActionInvoke InvGet) -> Just $ jsonRpcParams tProc rpcQParams - (TargetProc{tProc}, ActionInvoke InvHead) -> Just $ jsonRpcParams tProc rpcQParams - _ | shouldParsePayload -> rightToMaybe payload - | otherwise -> Nothing - path = pathInfo req + (_, ActionInvoke InvGet) -> targetToJsonRpcParams (rightToMaybe target) rpcQParams + (_, ActionInvoke InvHead) -> targetToJsonRpcParams (rightToMaybe target) rpcQParams + (CTUrlEncoded, ActionInvoke InvPost) -> targetToJsonRpcParams (rightToMaybe target) $ (toS *** toS) <$> parseSimpleQuery (toS reqBody) + _ | shouldParsePayload -> rightToMaybe payload + | otherwise -> Nothing + path = + case pathInfo req of + [] -> case configDbRootSpec of + Just (QualifiedIdentifier pSch pName) -> PathInfo (if pSch == mempty then schema else pSch) pName False False True + Nothing | configOpenApiMode == OADisabled -> PathUnknown + | otherwise -> PathInfo schema "" False True False + [table] -> PathInfo schema table False False False + ["rpc", pName] -> PathInfo schema pName True False False + _ -> PathUnknown method = requestMethod req hdrs = requestHeaders req qParams = [(toS k, v)|(k,v) <- qString] @@ -430,16 +451,16 @@ payloadAttributes raw json = where emptyPJArray = ProcessedJSON (JSON.encode emptyArray) S.empty -findAcceptContentType :: AppConfig -> Action -> Target -> [ContentType] -> Either ApiRequestError ContentType -findAcceptContentType conf action target accepts = - case mutuallyAgreeable (requestContentTypes conf action target) accepts of +findAcceptContentType :: AppConfig -> Action -> Path -> [ContentType] -> Either ApiRequestError ContentType +findAcceptContentType conf action path accepts = + case mutuallyAgreeable (requestContentTypes conf action path) accepts of Just ct -> Right ct Nothing -> Left . ContentTypeError $ map ContentType.toMime accepts -requestContentTypes :: AppConfig -> Action -> Target -> [ContentType] -requestContentTypes conf action target = +requestContentTypes :: AppConfig -> Action -> Path -> [ContentType] +requestContentTypes conf action path = case action of ActionRead _ -> defaultContentTypes ++ rawContentTypes conf ActionInvoke _ -> invokeContentTypes @@ -450,10 +471,46 @@ requestContentTypes conf action target = invokeContentTypes = defaultContentTypes ++ rawContentTypes conf - ++ [CTOpenAPI | tpIsRootSpec target] + ++ [CTOpenAPI | pIsRootSpec path] defaultContentTypes = [CTApplicationJSON, CTSingularJSON, CTTextCSV] rawContentTypes :: AppConfig -> [ContentType] rawContentTypes AppConfig{..} = (ContentType.decodeContentType <$> configRawMediaTypes) `union` [CTOctetStream, CTTextPlain] + +{-| + Search a pg procedure by its parameters. Since a function can be overloaded, the name is not enough to find it. + An overloaded function can have a different volatility or even a different return type. +-} +findProc :: QualifiedIdentifier -> S.Set Text -> Bool -> ProcsMap -> Either ApiRequestError ProcDescription +findProc qi payloadKeys paramsAsSingleObject allProcs = + case bestMatch of + [] -> Left $ NoRpc (qiSchema qi) (qiName qi) (S.toList payloadKeys) paramsAsSingleObject + [proc] -> Right proc + procs -> Left $ AmbiguousRpc (toList procs) + where + bestMatch = + case M.lookup qi allProcs of + Nothing -> [] + Just [proc] -> [proc | matches proc] + Just procs -> filter matches procs + -- Find the exact arguments match + matches proc + | paramsAsSingleObject = case pdArgs proc of + [arg] -> pgaType arg `elem` ["json", "jsonb"] + _ -> False + | otherwise = case pdArgs proc of + [] -> null payloadKeys + args -> matchesArg args + matchesArg args = + -- The function's required arguments are separated from the ones with a default value assigned. + -- The set of names of those arguments is compared to the set of keys supplied by the client + -- 1. If only required arguments are found, the keys must be exactly the same as those arguments + -- 2. If only optional arguments are found, the keys must be a subset of those arguments + -- 3. If both required and optional arguments are found, the result of taking away the optional arguments + -- from the keys must be exactly the same as the required arguments + case L.partition pgaReq args of + (reqArgs, []) -> payloadKeys == S.fromList (pgaName <$> reqArgs) + ([], defArgs) -> payloadKeys `S.isSubsetOf` S.fromList (pgaName <$> defArgs) + (reqArgs, defArgs) -> payloadKeys `S.difference` S.fromList (pgaName <$> defArgs) == S.fromList (pgaName <$> reqArgs) diff --git a/test/Feature/RpcSpec.hs b/test/Feature/RpcSpec.hs index 93727f3f32..c8b8758da8 100644 --- a/test/Feature/RpcSpec.hs +++ b/test/Feature/RpcSpec.hs @@ -96,22 +96,61 @@ spec actualPgVersion = context "unknown function" $ do it "returns 404" $ post "/rpc/fakefunc" [json| {} |] `shouldRespondWith` 404 + it "should fail with 404 on unknown proc name" $ get "/rpc/fake" `shouldRespondWith` 404 + it "should fail with 404 on unknown proc args" $ do get "/rpc/sayhello" `shouldRespondWith` 404 get "/rpc/sayhello?any_arg=value" `shouldRespondWith` 404 + it "should not ignore unknown args and fail with 404" $ get "/rpc/add_them?a=1&b=2&smthelse=blabla" `shouldRespondWith` [json| { - "code": "42883", - "details": null, - "hint": "No function matches the given name and argument types. You might need to add explicit type casts.", - "message": "function test.add_them(a => integer, b => integer, smthelse => text) does not exist" } |] + "hint":"If a new function was created in the database with this name and arguments, try reloading the schema cache.", + "message":"Could not find the test.add_them(a, b, smthelse) function in the schema cache" } |] + { matchStatus = 404 + , matchHeaders = [matchContentTypeJson] + } + + it "should fail with 404 when no json arg is found with prefer single object" $ + request methodPost "/rpc/sayhello" + [("Prefer","params=single-object")] + [json|{}|] + `shouldRespondWith` + [json| { + "hint":"If a new function was created in the database with this name and arguments, try reloading the schema cache.", + "message":"Could not find the test.sayhello function with a single json or jsonb argument in the schema cache" } |] { matchStatus = 404 , matchHeaders = [matchContentTypeJson] } + it "should fail with 404 for overloaded functions with unknown args" $ do + get "/rpc/overloaded?wrong_arg=value" `shouldRespondWith` + [json| { + "hint":"If a new function was created in the database with this name and arguments, try reloading the schema cache.", + "message":"Could not find the test.overloaded(wrong_arg) function in the schema cache" } |] + { matchStatus = 404 + , matchHeaders = [matchContentTypeJson] + } + get "/rpc/overloaded?a=1&b=2&wrong_arg=value" `shouldRespondWith` + [json| { + "hint":"If a new function was created in the database with this name and arguments, try reloading the schema cache.", + "message":"Could not find the test.overloaded(a, b, wrong_arg) function in the schema cache" } |] + { matchStatus = 404 + , matchHeaders = [matchContentTypeJson] + } + + context "ambiguous overloaded functions with same arguments but different types" $ do + it "should fail with 300 Multiple Choices without explicit argument type casts" $ + get "/rpc/overloaded_same_args?arg=value" `shouldRespondWith` + [json| { + "hint":"Overloaded functions with the same argument name but different types are not supported", + "message":"Could not choose the best candidate function between: test.overloaded_same_args(arg => integer), test.overloaded_same_args(arg => xml), test.overloaded_same_args(arg => text, num => integer)" } |] + { matchStatus = 300 + , matchHeaders = [matchContentTypeJson] + } + it "works when having uppercase identifiers" $ do get "/rpc/quotedFunction?user=mscott&fullName=Michael Scott&SSN=401-32-XXXX" `shouldRespondWith` [json|{"user": "mscott", "fullName": "Michael Scott", "SSN": "401-32-XXXX"}|] @@ -689,7 +728,7 @@ spec actualPgVersion = context "only for POST rpc" $ do it "gives a parse filter error if GET style proc args are specified" $ - post "/rpc/sayhello?name=John" [json|{}|] `shouldRespondWith` 400 + post "/rpc/sayhello?name=John" [json|{name: "John"}|] `shouldRespondWith` 400 it "ignores json keys not included in ?columns" $ post "/rpc/sayhello?columns=name" diff --git a/test/fixtures/schema.sql b/test/fixtures/schema.sql index ed6bd53ba3..524b367af0 100644 --- a/test/fixtures/schema.sql +++ b/test/fixtures/schema.sql @@ -1305,6 +1305,27 @@ create or replace function test.overloaded_html_form(a text, b text, c text) ret select a || b || c $$ language sql; +create or replace function test.overloaded_same_args(arg integer) returns json as $$ +select json_build_object( + 'type', pg_typeof(arg), + 'value', arg + ); +$$ language sql; + +create or replace function test.overloaded_same_args(arg xml) returns json as $$ +select json_build_object( + 'type', pg_typeof(arg), + 'value', arg + ); +$$ language sql; + +create or replace function test.overloaded_same_args(arg text, num integer default 0) returns json as $$ +select json_build_object( + 'type', pg_typeof(arg), + 'value', arg + ); +$$ language sql; + create table test.leak( id serial primary key, blob bytea