Skip to content

Commit

Permalink
feat(cors): disallow cors_enable & additional CORSMiddleware combo (
Browse files Browse the repository at this point in the history
#2201)

* feat(cors): disallow `cors_enable` & additional CORSMiddleware combo

* chore: ignore typing for 1 line because it is unclear how to fix it
  • Loading branch information
vytas7 authored Dec 19, 2023
1 parent 0aac950 commit 5b6e4c4
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 17 deletions.
4 changes: 4 additions & 0 deletions docs/_newsfragments/1977.breakingchange.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Previously, it was possible to create an :class:`~falcon.App` with the
``cors_enable`` option, and add additional :class:`~falcon.CORSMiddleware`,
leading to unexpected behavior and dysfunctional CORS. This combination now
explicitly results in a :class:`ValueError`.
12 changes: 9 additions & 3 deletions docs/api/cors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,15 @@ Usage
allow_origins='example.com', allow_credentials='*'))
.. note::
Passing the ``cors_enable`` parameter set to ``True`` should be seen as
mutually exclusive with directly passing an instance of
:class:`~falcon.CORSMiddleware` to the application's initializer.
Passing the ``cors_enable`` parameter set to ``True`` is mutually exclusive
with directly passing an instance of :class:`~falcon.CORSMiddleware` to the
application's initializer.

.. versionchanged:: 4.0

Attempt to use the combination of ``cors_enable=True`` and an additional
instance of :class:`~falcon.CORSMiddleware` now results in a
:class:`ValueError`.

CORSMiddleware
--------------
Expand Down
25 changes: 22 additions & 3 deletions falcon/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def __init__(
cors_enable=False,
sink_before_static_route=True,
):
self._cors_enable = cors_enable
self._sink_before_static_route = sink_before_static_route
self._sinks = []
self._static_routes = []
Expand Down Expand Up @@ -447,7 +448,7 @@ def __call__( # noqa: C901
def router_options(self):
return self._router.options

def add_middleware(self, middleware: object) -> None:
def add_middleware(self, middleware: Union[object, Iterable]) -> None:
"""Add one or more additional middleware components.
Arguments:
Expand All @@ -461,10 +462,28 @@ def add_middleware(self, middleware: object) -> None:
# the chance that middleware may be None.
if middleware:
try:
self._unprepared_middleware += middleware
middleware = list(middleware) # type: ignore
except TypeError:
# middleware is not iterable; assume it is just one bare component
self._unprepared_middleware.append(middleware)
middleware = [middleware]

if (
self._cors_enable
and len(
[
mc
for mc in self._unprepared_middleware + middleware
if isinstance(mc, CORSMiddleware)
]
)
> 1
):
raise ValueError(
'CORSMiddleware is not allowed in conjunction with '
'cors_enable (which already constructs one instance)'
)

self._unprepared_middleware += middleware

# NOTE(kgriffs): Even if middleware is None or an empty list, we still
# need to make sure self._middleware is initialized if this is the
Expand Down
30 changes: 19 additions & 11 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,20 +1000,28 @@ def test_process_resource_cached(self, asgi, independent_middleware):

class TestCORSMiddlewareWithAnotherMiddleware(TestMiddleware):
@pytest.mark.parametrize(
'mw',
'mw,allowed',
[
CaptureResponseMiddleware(),
[CaptureResponseMiddleware()],
(CaptureResponseMiddleware(),),
iter([CaptureResponseMiddleware()]),
(CaptureResponseMiddleware(), True),
([CaptureResponseMiddleware()], True),
((CaptureResponseMiddleware(),), True),
(iter([CaptureResponseMiddleware()]), True),
(falcon.CORSMiddleware(), False),
([falcon.CORSMiddleware()], False),
],
)
def test_api_initialization_with_cors_enabled_and_middleware_param(self, mw, asgi):
app = create_app(asgi, middleware=mw, cors_enable=True)
app.add_route('/', TestCorsResource())
client = testing.TestClient(app)
result = client.simulate_get(headers={'Origin': 'localhost'})
assert result.headers['Access-Control-Allow-Origin'] == '*'
def test_api_initialization_with_cors_enabled_and_middleware_param(
self, mw, asgi, allowed
):
if allowed:
app = create_app(asgi, middleware=mw, cors_enable=True)
app.add_route('/', TestCorsResource())
client = testing.TestClient(app)
result = client.simulate_get(headers={'Origin': 'localhost'})
assert result.headers['Access-Control-Allow-Origin'] == '*'
else:
with pytest.raises(ValueError, match='CORSMiddleware'):
app = create_app(asgi, middleware=mw, cors_enable=True)


@pytest.mark.skipif(cython, reason='Cythonized coroutine functions cannot be detected')
Expand Down

0 comments on commit 5b6e4c4

Please sign in to comment.