From 7189ca2e461d4f83b4bd5e8249f0649d04672242 Mon Sep 17 00:00:00 2001 From: Jacob Eisenberg Date: Wed, 4 Jan 2023 13:34:54 +0100 Subject: [PATCH 1/4] Add support for greedy catch-all path variables --- chalice/app.py | 8 ++++---- chalice/local.py | 16 ++++++++++++++-- tests/aws/test_features.py | 4 ++++ tests/aws/testapp/app.py | 5 +++++ tests/functional/test_local.py | 14 ++++++++++++++ tests/unit/test_app.py | 6 ++++++ tests/unit/test_local.py | 8 +++++++- 7 files changed, 54 insertions(+), 7 deletions(-) diff --git a/chalice/app.py b/chalice/app.py index de6726bd6..fc4a59be3 100644 --- a/chalice/app.py +++ b/chalice/app.py @@ -30,7 +30,9 @@ if TYPE_CHECKING: from chalice.local import LambdaContext -_PARAMS = re.compile(r'{\w+}') +# the optional + at the end is for supporting the special greedy parameter in +# API Gateway (ie. "{proxy+}") +_PARAMS = re.compile(r'{(\w+)\+?}') MiddlewareFuncType = Callable[[Any, Callable[[Any], Any]], Any] UserHandlerFuncType = Callable[..., Any] @@ -577,9 +579,7 @@ def __init__(self, view_function: Callable[..., Any], view_name: str, def _parse_view_args(self) -> List[str]: if '{' not in self.uri_pattern: return [] - # The [1:-1] slice is to remove the braces - # e.g {foobar} -> foobar - results = [r[1:-1] for r in _PARAMS.findall(self.uri_pattern)] + results = _PARAMS.findall(self.uri_pattern) return results def __eq__(self, other: object) -> bool: diff --git a/chalice/local.py b/chalice/local.py index 6cdb0bee2..b5d32c3e0 100644 --- a/chalice/local.py +++ b/chalice/local.py @@ -47,6 +47,7 @@ ResponseType = Dict[str, Any] HandlerCls = Callable[..., 'ChaliceRequestHandler'] ServerCls = Callable[..., 'HTTPServer'] +CatchAllParamMatcher = re.compile(r'\{(\w+)\+\}') class Clock(object): @@ -147,8 +148,19 @@ def match_route(self, url): captured = {} for route_url in self.route_urls: url_parts = route_url.split('/') - if len(parts) == len(url_parts): - for i, j in zip(parts, url_parts): + parts_copy = parts.copy() + + # Capture the special greedy (catch-all) ie. "proxy+" path variable + catch_all = CatchAllParamMatcher.match(url_parts[-1]) + if catch_all is not None: + i = len(url_parts) - 1 + if len(parts) > i: + captured[catch_all.group(1)] = '/'.join(parts[i:]) + url_parts = url_parts[:-1] + parts_copy = parts_copy[:i] + + if len(parts_copy) == len(url_parts): + for i, j in zip(parts_copy, url_parts): if j.startswith('{') and j.endswith('}'): captured[j[1:-1]] = i continue diff --git a/tests/aws/test_features.py b/tests/aws/test_features.py index 44cbe99ff..16329ec29 100644 --- a/tests/aws/test_features.py +++ b/tests/aws/test_features.py @@ -265,6 +265,10 @@ def test_supports_path_params(smoke_test_app): assert smoke_test_app.get_json('/path/bar') == {'path': 'bar'} +def test_supports_catch_all_param(smoke_test_app): + assert smoke_test_app.get_json('/catch-all/foo/bar') == {'proxy': 'foo/bar'} + + def test_path_params_mapped_in_api(smoke_test_app, apig_client): # Use the API Gateway API to ensure that path parameters # are modeled as such. Otherwise this will break diff --git a/tests/aws/testapp/app.py b/tests/aws/testapp/app.py index df01f385b..d9292b8e5 100644 --- a/tests/aws/testapp/app.py +++ b/tests/aws/testapp/app.py @@ -227,6 +227,11 @@ def repr_raw_body(): return {'repr-raw-body': app.current_request.raw_body.decode('utf-8')} +@app.route('/catch-all/{proxy+}', methods=['GET']) +def catch_all(proxy): + return {'proxy': proxy} + + SOCKET_MESSAGES = [] diff --git a/tests/functional/test_local.py b/tests/functional/test_local.py index 2be412115..e984f5f6f 100644 --- a/tests/functional/test_local.py +++ b/tests/functional/test_local.py @@ -318,3 +318,17 @@ def test_can_reload_server(unused_tcp_port, basic_app, http_session): assert http_session.json_get(url) == {'version': 'reloaded'} finally: p.terminate() + + +def test_can_parse_proxy_catch_all_route( + config, local_server_factory): + demo = app.Chalice('app-name') + + @demo.route('/{proxy+}') + def proxy_view(proxy): + return proxy + + local_server, port = local_server_factory(demo, config) + response = local_server.make_call(requests.get, '/any/thing', port) + assert response.status_code == 200 + assert response.text == 'any/thing' diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index ea33f3ce3..aaf69c85c 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -463,6 +463,12 @@ def test_can_parse_route_view_args(): assert entry.view_args == ['bar', 'qux'] +def test_can_parse_catch_all_route_view_args(): + entry = app.RouteEntry(lambda: {"foo": "bar"}, 'view-name', + '/foo/{proxy+}', method='GET') + assert entry.view_args == ['proxy'] + + def test_can_route_single_view(): demo = app.Chalice('app-name') diff --git a/tests/unit/test_local.py b/tests/unit/test_local.py index b4135f650..14289b201 100644 --- a/tests/unit/test_local.py +++ b/tests/unit/test_local.py @@ -595,12 +595,18 @@ def test_multi_value_header(handler): ('/names/bar/wrong', None), ('/a/z/c', '/a/{capture}/c'), ('/a/b/c', '/a/b/c'), + ('/x', None), + ('/x/foo', '/x/{proxy+}'), + ('/x/foo/bar', '/x/{proxy+}'), + ('/y/foo/bar', '/y/{capture}/{proxy+}'), + ('/y/foo/bar/baz', '/y/{capture}/{proxy+}'), ]) def test_can_match_exact_route(actual_url, matched_url): matcher = local.RouteMatcher([ '/foo', '/foo/{capture}', '/foo/bar', '/names/{capture}', - '/a/{capture}/c', '/a/b/c' + '/a/{capture}/c', '/a/b/c', + '/x/{proxy+}', '/y/{capture}/{proxy+}' ]) if matched_url is not None: assert matcher.match_route(actual_url).route == matched_url From 776004add429208b12b618466aa169a9aa44e259 Mon Sep 17 00:00:00 2001 From: Jacob Eisenberg Date: Wed, 4 Jan 2023 14:27:19 +0100 Subject: [PATCH 2/4] Improve performance by not using regex --- chalice/local.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/chalice/local.py b/chalice/local.py index b5d32c3e0..17226f26a 100644 --- a/chalice/local.py +++ b/chalice/local.py @@ -47,7 +47,6 @@ ResponseType = Dict[str, Any] HandlerCls = Callable[..., 'ChaliceRequestHandler'] ServerCls = Callable[..., 'HTTPServer'] -CatchAllParamMatcher = re.compile(r'\{(\w+)\+\}') class Clock(object): @@ -150,12 +149,12 @@ def match_route(self, url): url_parts = route_url.split('/') parts_copy = parts.copy() - # Capture the special greedy (catch-all) ie. "proxy+" path variable - catch_all = CatchAllParamMatcher.match(url_parts[-1]) - if catch_all is not None: + # Handle a greedy catch-all path variable (ie. "proxy+") + if url_parts[-1].endswith('+}'): i = len(url_parts) - 1 if len(parts) > i: - captured[catch_all.group(1)] = '/'.join(parts[i:]) + catch_all_param = url_parts[-1][1:-2] + captured[catch_all_param] = '/'.join(parts[i:]) url_parts = url_parts[:-1] parts_copy = parts_copy[:i] From 5f2b479993e65839e3b7cad5e438672d937b82d9 Mon Sep 17 00:00:00 2001 From: Jacob Eisenberg Date: Thu, 5 Jan 2023 09:00:48 +0100 Subject: [PATCH 3/4] Fix long line --- tests/aws/test_features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/aws/test_features.py b/tests/aws/test_features.py index 16329ec29..26afa4be3 100644 --- a/tests/aws/test_features.py +++ b/tests/aws/test_features.py @@ -266,7 +266,7 @@ def test_supports_path_params(smoke_test_app): def test_supports_catch_all_param(smoke_test_app): - assert smoke_test_app.get_json('/catch-all/foo/bar') == {'proxy': 'foo/bar'} + assert smoke_test_app.get_json('/catch-all/a/b') == {'proxy': 'a/b'} def test_path_params_mapped_in_api(smoke_test_app, apig_client): From 6a92e01e80d61de91cb05200d0e6818d7aba61ff Mon Sep 17 00:00:00 2001 From: Jacob Eisenberg Date: Thu, 5 Jan 2023 09:29:39 +0100 Subject: [PATCH 4/4] Fix variable name to aboid typing confusion --- chalice/local.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chalice/local.py b/chalice/local.py index 17226f26a..32c88f2aa 100644 --- a/chalice/local.py +++ b/chalice/local.py @@ -151,12 +151,12 @@ def match_route(self, url): # Handle a greedy catch-all path variable (ie. "proxy+") if url_parts[-1].endswith('+}'): - i = len(url_parts) - 1 - if len(parts) > i: + pos = len(url_parts) - 1 + if len(parts) > pos: catch_all_param = url_parts[-1][1:-2] - captured[catch_all_param] = '/'.join(parts[i:]) + captured[catch_all_param] = '/'.join(parts[pos:]) url_parts = url_parts[:-1] - parts_copy = parts_copy[:i] + parts_copy = parts_copy[:pos] if len(parts_copy) == len(url_parts): for i, j in zip(parts_copy, url_parts):