diff --git a/chalice/app.py b/chalice/app.py index de6726bd6..fc4a59be3 100644 --- a/chalice/app.py +++ b/chalice/app.py @@ -30,7 +30,9 @@ if TYPE_CHECKING: from chalice.local import LambdaContext -_PARAMS = re.compile(r'{\w+}') +# the optional + at the end is for supporting the special greedy parameter in +# API Gateway (ie. "{proxy+}") +_PARAMS = re.compile(r'{(\w+)\+?}') MiddlewareFuncType = Callable[[Any, Callable[[Any], Any]], Any] UserHandlerFuncType = Callable[..., Any] @@ -577,9 +579,7 @@ def __init__(self, view_function: Callable[..., Any], view_name: str, def _parse_view_args(self) -> List[str]: if '{' not in self.uri_pattern: return [] - # The [1:-1] slice is to remove the braces - # e.g {foobar} -> foobar - results = [r[1:-1] for r in _PARAMS.findall(self.uri_pattern)] + results = _PARAMS.findall(self.uri_pattern) return results def __eq__(self, other: object) -> bool: diff --git a/chalice/local.py b/chalice/local.py index 6cdb0bee2..32c88f2aa 100644 --- a/chalice/local.py +++ b/chalice/local.py @@ -147,8 +147,19 @@ def match_route(self, url): captured = {} for route_url in self.route_urls: url_parts = route_url.split('/') - if len(parts) == len(url_parts): - for i, j in zip(parts, url_parts): + parts_copy = parts.copy() + + # Handle a greedy catch-all path variable (ie. "proxy+") + if url_parts[-1].endswith('+}'): + pos = len(url_parts) - 1 + if len(parts) > pos: + catch_all_param = url_parts[-1][1:-2] + captured[catch_all_param] = '/'.join(parts[pos:]) + url_parts = url_parts[:-1] + parts_copy = parts_copy[:pos] + + if len(parts_copy) == len(url_parts): + for i, j in zip(parts_copy, url_parts): if j.startswith('{') and j.endswith('}'): captured[j[1:-1]] = i continue diff --git a/tests/aws/test_features.py b/tests/aws/test_features.py index 44cbe99ff..26afa4be3 100644 --- a/tests/aws/test_features.py +++ b/tests/aws/test_features.py @@ -265,6 +265,10 @@ def test_supports_path_params(smoke_test_app): assert smoke_test_app.get_json('/path/bar') == {'path': 'bar'} +def test_supports_catch_all_param(smoke_test_app): + assert smoke_test_app.get_json('/catch-all/a/b') == {'proxy': 'a/b'} + + def test_path_params_mapped_in_api(smoke_test_app, apig_client): # Use the API Gateway API to ensure that path parameters # are modeled as such. Otherwise this will break diff --git a/tests/aws/testapp/app.py b/tests/aws/testapp/app.py index df01f385b..d9292b8e5 100644 --- a/tests/aws/testapp/app.py +++ b/tests/aws/testapp/app.py @@ -227,6 +227,11 @@ def repr_raw_body(): return {'repr-raw-body': app.current_request.raw_body.decode('utf-8')} +@app.route('/catch-all/{proxy+}', methods=['GET']) +def catch_all(proxy): + return {'proxy': proxy} + + SOCKET_MESSAGES = [] diff --git a/tests/functional/test_local.py b/tests/functional/test_local.py index 2be412115..e984f5f6f 100644 --- a/tests/functional/test_local.py +++ b/tests/functional/test_local.py @@ -318,3 +318,17 @@ def test_can_reload_server(unused_tcp_port, basic_app, http_session): assert http_session.json_get(url) == {'version': 'reloaded'} finally: p.terminate() + + +def test_can_parse_proxy_catch_all_route( + config, local_server_factory): + demo = app.Chalice('app-name') + + @demo.route('/{proxy+}') + def proxy_view(proxy): + return proxy + + local_server, port = local_server_factory(demo, config) + response = local_server.make_call(requests.get, '/any/thing', port) + assert response.status_code == 200 + assert response.text == 'any/thing' diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index ea33f3ce3..aaf69c85c 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -463,6 +463,12 @@ def test_can_parse_route_view_args(): assert entry.view_args == ['bar', 'qux'] +def test_can_parse_catch_all_route_view_args(): + entry = app.RouteEntry(lambda: {"foo": "bar"}, 'view-name', + '/foo/{proxy+}', method='GET') + assert entry.view_args == ['proxy'] + + def test_can_route_single_view(): demo = app.Chalice('app-name') diff --git a/tests/unit/test_local.py b/tests/unit/test_local.py index b4135f650..14289b201 100644 --- a/tests/unit/test_local.py +++ b/tests/unit/test_local.py @@ -595,12 +595,18 @@ def test_multi_value_header(handler): ('/names/bar/wrong', None), ('/a/z/c', '/a/{capture}/c'), ('/a/b/c', '/a/b/c'), + ('/x', None), + ('/x/foo', '/x/{proxy+}'), + ('/x/foo/bar', '/x/{proxy+}'), + ('/y/foo/bar', '/y/{capture}/{proxy+}'), + ('/y/foo/bar/baz', '/y/{capture}/{proxy+}'), ]) def test_can_match_exact_route(actual_url, matched_url): matcher = local.RouteMatcher([ '/foo', '/foo/{capture}', '/foo/bar', '/names/{capture}', - '/a/{capture}/c', '/a/b/c' + '/a/{capture}/c', '/a/b/c', + '/x/{proxy+}', '/y/{capture}/{proxy+}' ]) if matched_url is not None: assert matcher.match_route(actual_url).route == matched_url