Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement some basic guards on our /search endpoint #16812

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion tests/unit/search/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pretend

from warehouse import search
from warehouse.rate_limiting import IRateLimiter, RateLimit

from ...common.db.packaging import ProjectFactory, ReleaseFactory

Expand Down Expand Up @@ -118,11 +119,15 @@ def test_includeme(monkeypatch):
"aws.key_id": "AAAAAAAAAAAA",
"aws.secret_key": "deadbeefdeadbeefdeadbeef",
"opensearch.url": opensearch_url,
"warehouse.search.ratelimit_string": "10 per second",
},
__setitem__=registry.__setitem__,
),
add_request_method=pretend.call_recorder(lambda *a, **kw: None),
add_periodic_task=pretend.call_recorder(lambda *a, **kw: None),
register_service_factory=pretend.call_recorder(
lambda factory, iface, name=None: None
),
)

search.includeme(config)
Expand All @@ -132,7 +137,7 @@ def test_includeme(monkeypatch):
]
assert len(opensearch_client_init.calls) == 1
assert opensearch_client_init.calls[0].kwargs["hosts"] == ["https://some.url"]
assert opensearch_client_init.calls[0].kwargs["timeout"] == 2
assert opensearch_client_init.calls[0].kwargs["timeout"] == 0.5
assert opensearch_client_init.calls[0].kwargs["retry_on_timeout"] is False
assert (
opensearch_client_init.calls[0].kwargs["connection_class"]
Expand All @@ -147,3 +152,7 @@ def test_includeme(monkeypatch):
assert config.add_request_method.calls == [
pretend.call(search.opensearch, name="opensearch", reify=True)
]

assert config.register_service_factory.calls == [
pretend.call(RateLimit("10 per second"), IRateLimiter, name="search")
]
1 change: 1 addition & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def __init__(self):
"warehouse.manage.oidc.ip_registration_ratelimit_string": "100 per day",
"warehouse.packaging.project_create_user_ratelimit_string": "20 per hour",
"warehouse.packaging.project_create_ip_ratelimit_string": "40 per hour",
"warehouse.search.ratelimit_string": "5 per second",
"oidc.backend": "warehouse.oidc.services.OIDCPublisherService",
"integrity.backend": "warehouse.attestations.services.IntegrityService",
"warehouse.organizations.max_undecided_organization_applications": 3,
Expand Down
110 changes: 102 additions & 8 deletions tests/unit/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
HTTPRequestEntityTooLarge,
HTTPSeeOther,
HTTPServiceUnavailable,
HTTPTooManyRequests,
)
from trove_classifiers import sorted_classifiers
from webob.multidict import MultiDict

from warehouse import views
from warehouse.errors import WarehouseDenied
from warehouse.packaging.models import ProjectFactory as DBProjectFactory
from warehouse.rate_limiting.interfaces import IRateLimiter
from warehouse.utils.row_counter import compute_row_counts
from warehouse.views import (
SecurityKeyGiveaway,
Expand Down Expand Up @@ -476,12 +478,21 @@ def test_csi_sidebar_sponsor_logo():

class TestSearch:
@pytest.mark.parametrize("page", [None, 1, 5])
def test_with_a_query(self, monkeypatch, db_request, metrics, page):
def test_with_a_query(
self, monkeypatch, pyramid_services, db_request, metrics, page
):
params = MultiDict({"q": "foo bar"})
if page is not None:
params["page"] = page
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

db_request.opensearch = pretend.stub()
opensearch_query = pretend.stub()
get_opensearch_query = pretend.call_recorder(lambda *a, **kw: opensearch_query)
Expand Down Expand Up @@ -514,12 +525,21 @@ def test_with_a_query(self, monkeypatch, db_request, metrics, page):
]

@pytest.mark.parametrize("page", [None, 1, 5])
def test_with_classifiers(self, monkeypatch, db_request, metrics, page):
def test_with_classifiers(
self, monkeypatch, pyramid_services, db_request, metrics, page
):
params = MultiDict([("q", "foo bar"), ("c", "foo :: bar"), ("c", "fiz :: buz")])
if page is not None:
params["page"] = page
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

opensearch_query = pretend.stub()
db_request.opensearch = pretend.stub()
get_opensearch_query = pretend.call_recorder(lambda *a, **kw: opensearch_query)
Expand Down Expand Up @@ -562,6 +582,7 @@ def test_with_classifiers(self, monkeypatch, db_request, metrics, page):
assert page_cls.calls == [
pretend.call(opensearch_query, url_maker=url_maker, page=page or 1)
]

assert url_maker_factory.calls == [pretend.call(db_request)]
assert get_opensearch_query.calls == [
pretend.call(db_request.opensearch, params.get("q"), "", params.getall("c"))
Expand All @@ -570,10 +591,19 @@ def test_with_classifiers(self, monkeypatch, db_request, metrics, page):
pretend.call("warehouse.views.search.results", 1000)
]

def test_returns_404_with_pagenum_too_high(self, monkeypatch, db_request, metrics):
def test_returns_404_with_pagenum_too_high(
self, monkeypatch, pyramid_services, db_request, metrics
):
params = MultiDict({"page": 15})
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

opensearch_query = pretend.stub()
db_request.opensearch = pretend.stub(query=lambda *a, **kw: opensearch_query)

Expand All @@ -594,10 +624,19 @@ def test_returns_404_with_pagenum_too_high(self, monkeypatch, db_request, metric
assert url_maker_factory.calls == [pretend.call(db_request)]
assert metrics.histogram.calls == []

def test_raises_400_with_pagenum_type_str(self, monkeypatch, db_request, metrics):
def test_raises_400_with_pagenum_type_str(
self, monkeypatch, pyramid_services, db_request, metrics
):
params = MultiDict({"page": "abc"})
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

opensearch_query = pretend.stub()
db_request.opensearch = pretend.stub(query=lambda *a, **kw: opensearch_query)

Expand All @@ -615,23 +654,40 @@ def test_raises_400_with_pagenum_type_str(self, monkeypatch, db_request, metrics
assert page_cls.calls == []
assert metrics.histogram.calls == []

def test_return_413_when_query_too_long(self, db_request, metrics):
def test_return_413_when_query_too_long(
self, pyramid_services, db_request, metrics
):
params = MultiDict({"q": "a" * 1001})
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

with pytest.raises(HTTPRequestEntityTooLarge):
search(db_request)

assert metrics.increment.calls == [
pretend.call("warehouse.views.search.error", tags=["error:query_too_long"])
pretend.call("warehouse.search.ratelimiter.hit", tags=[]),
pretend.call("warehouse.views.search.error", tags=["error:query_too_long"]),
]

def test_returns_503_when_opensearch_unavailable(
self, monkeypatch, db_request, metrics
self, monkeypatch, pyramid_services, db_request, metrics
):
params = MultiDict({"page": 15})
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

opensearch_query = pretend.stub()
db_request.opensearch = pretend.stub(query=lambda *a, **kw: opensearch_query)

Expand All @@ -648,9 +704,47 @@ def raiser(*args, **kwargs):
search(db_request)

assert url_maker_factory.calls == [pretend.call(db_request)]
assert metrics.increment.calls == [pretend.call("warehouse.views.search.error")]
assert metrics.increment.calls == [
pretend.call("warehouse.search.ratelimiter.hit", tags=[]),
pretend.call("warehouse.views.search.error"),
]
assert metrics.histogram.calls == []

@pytest.mark.parametrize("resets_in", [None, 1, 5])
def test_returns_429_when_ratelimited(
    self, monkeypatch, pyramid_services, db_request, metrics, resets_in
):
    """An exhausted rate limiter makes /search raise HTTPTooManyRequests.

    Parametrized over ``resets_in`` so both message branches are covered:
    no reset hint available (None) versus a concrete seconds-until-reset.
    """
    db_request.params = MultiDict({"q": "foo bar"})

    # Build the resets_in stub first: either "unknown" (None) or a
    # timedelta-like stub exposing total_seconds(), as the view expects.
    if resets_in is None:
        resets_stub = lambda *a: None  # noqa: E731 - tiny test stub
    else:
        resets_stub = lambda *a: pretend.stub(  # noqa: E731
            total_seconds=lambda *a: resets_in
        )

    # Limiter whose test() always reports the limit as exceeded.
    limiter = pretend.stub(
        test=lambda *a: False,
        hit=lambda *a: True,
        resets_in=resets_stub,
    )
    pyramid_services.register_service(limiter, IRateLimiter, None, name="search")

    with pytest.raises(HTTPTooManyRequests) as exc_info:
        search(db_request)

    expected = (
        "Your search query could not be performed because there were too "
        "many requests by the client."
    )
    if resets_in is not None:
        expected += f" Limit may reset in {resets_in} seconds."
    assert exc_info.value.args[0] == expected

    # Only the "exceeded" counter fires; the "hit" counter is skipped
    # because the request is rejected before the search runs.
    assert metrics.increment.calls == [
        pretend.call("warehouse.search.ratelimiter.exceeded", tags=[])
    ]


def test_classifiers(db_request):
assert list_classifiers(db_request) == {"classifiers": sorted_classifiers}
Expand Down
6 changes: 6 additions & 0 deletions warehouse/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,12 @@ def configure(settings=None):
"PROJECT_CREATE_IP_RATELIMIT_STRING",
default="40 per hour",
)
maybe_set(
settings,
"warehouse.search.ratelimit_string",
"SEARCH_RATELIMIT_STRING",
default="5 per second",
)

# OIDC feature flags and settings
maybe_set(settings, "warehouse.oidc.audience", "OIDC_AUDIENCE")
Expand Down
6 changes: 3 additions & 3 deletions warehouse/locale/messages.pot
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
#: warehouse/views.py:147
#: warehouse/views.py:149
msgid ""
"You must verify your **primary** email address before you can perform "
"this action."
msgstr ""

#: warehouse/views.py:163
#: warehouse/views.py:165
msgid ""
"Two-factor authentication must be enabled on your account to perform this"
" action."
msgstr ""

#: warehouse/views.py:299
#: warehouse/views.py:301
msgid "Locale updated"
msgstr ""

Expand Down
8 changes: 7 additions & 1 deletion warehouse/search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from warehouse import db
from warehouse.packaging.models import Project, Release
from warehouse.rate_limiting import IRateLimiter, RateLimit
from warehouse.search.utils import get_index


Expand Down Expand Up @@ -79,13 +80,18 @@ def opensearch(request):


def includeme(config):
ratelimit_string = config.registry.settings.get("warehouse.search.ratelimit_string")
config.register_service_factory(
RateLimit(ratelimit_string), IRateLimiter, name="search"
)

p = parse_url(config.registry.settings["opensearch.url"])
qs = urllib.parse.parse_qs(p.query)
kwargs = {
"hosts": [urllib.parse.urlunparse((p.scheme, p.netloc) + ("",) * 4)],
"verify_certs": True,
"ca_certs": certifi.where(),
"timeout": 2,
"timeout": 0.5,
"retry_on_timeout": False,
"serializer": opensearchpy.serializer.serializer,
"max_retries": 1,
Expand Down
17 changes: 17 additions & 0 deletions warehouse/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
HTTPRequestEntityTooLarge,
HTTPSeeOther,
HTTPServiceUnavailable,
HTTPTooManyRequests,
exception_response,
)
from pyramid.i18n import make_localizer
Expand Down Expand Up @@ -60,6 +61,7 @@
Release,
ReleaseClassifiers,
)
from warehouse.rate_limiting import IRateLimiter
from warehouse.search.queries import SEARCH_FILTER_ORDER, get_opensearch_query
from warehouse.utils.cors import _CORS_HEADERS
from warehouse.utils.http import is_safe_url
Expand Down Expand Up @@ -322,8 +324,23 @@ def list_classifiers(request):
has_translations=True,
)
def search(request):
ratelimiter = request.find_service(IRateLimiter, name="search", context=None)
metrics = request.find_service(IMetricsService, context=None)

ratelimiter.hit(request.remote_addr)
if not ratelimiter.test(request.remote_addr):
metrics.increment("warehouse.search.ratelimiter.exceeded", tags=[])
ewdurbin marked this conversation as resolved.
Show resolved Hide resolved
message = (
"Your search query could not be performed because there were too "
"many requests by the client."
)
_resets_in = ratelimiter.resets_in(request.remote_addr)
if _resets_in is not None:
_resets_in = max(1, int(_resets_in.total_seconds()))
message += f" Limit may reset in {_resets_in} seconds."
raise HTTPTooManyRequests(message)
metrics.increment("warehouse.search.ratelimiter.hit", tags=[])

querystring = request.params.get("q", "").replace("'", '"')
# Bail early for really long queries before ES raises an error
if len(querystring) > 1000:
Expand Down