From ab090c2ed9cec4a02901b5b665bf2efe844e44a2 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Wed, 13 Mar 2024 15:16:51 +0000 Subject: [PATCH] search unit tests --- .github/workflows/ci.yml | 1 - elasticsearch_dsl/_async/search.py | 18 +- elasticsearch_dsl/_sync/search.py | 16 +- .../{document.py => document_base.py} | 0 elasticsearch_dsl/search_base.py | 6 - setup.cfg | 1 + setup.py | 2 + tests/_async/__init__.py | 16 + tests/_async/test_search.py | 687 ++++++++++++++++++ tests/_sync/__init__.py | 16 + tests/{ => _sync}/test_search.py | 8 +- tests/conftest.py | 50 +- utils/run-unasync.py | 18 +- 13 files changed, 818 insertions(+), 21 deletions(-) rename elasticsearch_dsl/{document.py => document_base.py} (100%) create mode 100644 tests/_async/__init__.py create mode 100644 tests/_async/test_search.py create mode 100644 tests/_sync/__init__.py rename tests/{ => _sync}/test_search.py (99%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b4541fe84..7fac2d9b7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,7 +61,6 @@ jobs: fail-fast: false matrix: python-version: [ - "3.7", "3.8", "3.9", "3.10", diff --git a/elasticsearch_dsl/_async/search.py b/elasticsearch_dsl/_async/search.py index 1716e9654..534ae93c8 100644 --- a/elasticsearch_dsl/_async/search.py +++ b/elasticsearch_dsl/_async/search.py @@ -24,12 +24,26 @@ from ..utils import AttrDict -class Search(SearchBase): +class AsyncSearch(SearchBase): def __aiter__(self): """ Iterate over the hits. """ - return aiter(await self.execute()) + + class ResultsIterator: + def __init__(self, search): + self.search = search + self.iterator = None + + async def __anext__(self): + if self.iterator is None: + self.iterator = iter(await self.search.execute()) + try: + return next(self.iterator) + except StopIteration: + raise StopAsyncIteration() + + return ResultsIterator(self) async def count(self): """ diff --git a/elasticsearch_dsl/_sync/search.py b/elasticsearch_dsl/_sync/search.py index 0da2efe09..a15a08aa0 100644 --- a/elasticsearch_dsl/_sync/search.py +++ b/elasticsearch_dsl/_sync/search.py @@ -29,7 +29,21 @@ def __iter__(self): """ Iterate over the hits. """ - return iter(self.execute()) + + class ResultsIterator: + def __init__(self, search): + self.search = search + self.iterator = None + + def __next__(self): + if self.iterator is None: + self.iterator = iter(self.search.execute()) + try: + return next(self.iterator) + except StopIteration: + raise StopIteration() + + return ResultsIterator(self) def count(self): """ diff --git a/elasticsearch_dsl/document.py b/elasticsearch_dsl/document_base.py similarity index 100% rename from elasticsearch_dsl/document.py rename to elasticsearch_dsl/document_base.py diff --git a/elasticsearch_dsl/search_base.py b/elasticsearch_dsl/search_base.py index ec6c02dda..c85798628 100644 --- a/elasticsearch_dsl/search_base.py +++ b/elasticsearch_dsl/search_base.py @@ -333,12 +333,6 @@ def filter(self, *args, **kwargs): def exclude(self, *args, **kwargs): return self.query(Bool(filter=[~Q(*args, **kwargs)])) - def __iter__(self): - """ - Iterate over the hits. - """ - return iter(self.execute()) - def __getitem__(self, n): """ Support slicing the `Search` instance for pagination. 
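Note: the rewritten __aiter__ and __iter__ above defer query execution until the first
hit is requested, rather than executing the search inside the iteration entry point as
the removed search_base.py code did. The removed async line also awaited inside a
non-async method, which is invalid Python; since __aiter__ must return its iterator
synchronously, the execute() call has to move into __anext__. A minimal sketch of that
protocol, assuming a simplified execute() coroutine that returns a list (LazyHits and
fake_execute are illustrative names, not part of the patch):

    import asyncio

    class LazyHits:
        """Async-iterable wrapper that defers the search until iteration starts."""

        def __init__(self, execute):
            self._execute = execute  # coroutine function that produces the hits
            self._iterator = None

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self._iterator is None:
                # First call: run the search once and cache a plain iterator.
                self._iterator = iter(await self._execute())
            try:
                return next(self._iterator)
            except StopIteration:
                # Translate the sync end-of-iteration signal into the async one.
                raise StopAsyncIteration

    async def main():
        async def fake_execute():
            return [1, 2, 3]

        assert [hit async for hit in LazyHits(fake_execute)] == [1, 2, 3]

    asyncio.run(main())

Wrapping the state in a dedicated iterator object is also what lets unasync mechanically
produce the sync version, since the structure of __iter__/__next__ mirrors
__aiter__/__anext__ line for line.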
diff --git a/setup.cfg b/setup.cfg index d902949f3..2a34c2fab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,3 +11,4 @@ filterwarnings = error ignore:Legacy index templates are deprecated in favor of composable templates.:elasticsearch.exceptions.ElasticsearchWarning ignore:datetime.datetime.utcfromtimestamp\(\) is deprecated and scheduled for removal in a future version..*:DeprecationWarning +asyncio_mode = auto diff --git a/setup.py b/setup.py index 22c6a5e63..492821de4 100644 --- a/setup.py +++ b/setup.py @@ -37,8 +37,10 @@ "pytest", "pytest-cov", "pytest-mock", + "pytest-asyncio", "pytz", "coverage", + "mock", # needed to have AsyncMock in Python <= 3.7 # Override Read the Docs default (sphinx<2 and sphinx-rtd-theme<0.5) "sphinx>2", "sphinx-rtd-theme>0.5", diff --git a/tests/_async/__init__.py b/tests/_async/__init__.py new file mode 100644 index 000000000..2a87d183f --- /dev/null +++ b/tests/_async/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/_async/test_search.py b/tests/_async/test_search.py new file mode 100644 index 000000000..2aaec92d7 --- /dev/null +++ b/tests/_async/test_search.py @@ -0,0 +1,687 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from copy import deepcopy + +from pytest import raises + +from elasticsearch_dsl import Document, Q, query +from elasticsearch_dsl._async import search +from elasticsearch_dsl.exceptions import IllegalOperation + + +def test_expand__to_dot_is_respected(): + s = search.AsyncSearch().query("match", a__b=42, _expand__to_dot=False) + + assert {"query": {"match": {"a__b": 42}}} == s.to_dict() + + +async def test_execute_uses_cache(): + s = search.AsyncSearch() + r = object() + s._response = r + + assert r is await s.execute() + + +async def test_cache_can_be_ignored(async_mock_client): + s = search.AsyncSearch(using="mock") + r = object() + s._response = r + await s.execute(ignore_cache=True) + + async_mock_client.search.assert_awaited_once_with(index=None, body={}) + + +async def test_iter_iterates_over_hits(): + s = search.AsyncSearch() + s._response = [1, 2, 3] + + r = [] + async for hit in s: + r.append(hit) + assert [1, 2, 3] == r + + +def test_cache_isnt_cloned(): + s = search.AsyncSearch() + s._response = object() + + assert not hasattr(s._clone(), "_response") + + +def test_search_starts_with_no_query(): + s = search.AsyncSearch() + + assert s.query._proxied is None + + +def test_search_query_combines_query(): + s = search.AsyncSearch() + + s2 = s.query("match", f=42) + assert s2.query._proxied == query.Match(f=42) + assert s.query._proxied is None + + s3 = s2.query("match", f=43) + assert s2.query._proxied == query.Match(f=42) + assert s3.query._proxied == query.Bool(must=[query.Match(f=42), query.Match(f=43)]) + + +def test_query_can_be_assigned_to(): + s = search.AsyncSearch() + + q = Q("match", title="python") + s.query = q + + assert s.query._proxied is q + + +def test_query_can_be_wrapped(): + s = search.AsyncSearch().query("match", title="python") + + s.query = Q("function_score", query=s.query, field_value_factor={"field": "rating"}) + + assert { + "query": { + "function_score": { + "functions": [{"field_value_factor": {"field": "rating"}}], + "query": {"match": {"title": "python"}}, + } + } + } == s.to_dict() + + +def test_using(): + o = object() + o2 = object() + s = search.AsyncSearch(using=o) + assert s._using is o + s2 = s.using(o2) + assert s._using is o + assert s2._using is o2 + + +def test_methods_are_proxied_to_the_query(): + s = search.AsyncSearch().query("match_all") + + assert s.query.to_dict() == {"match_all": {}} + + +def test_query_always_returns_search(): + s = search.AsyncSearch() + + assert isinstance(s.query("match", f=42), search.AsyncSearch) + + +def test_source_copied_on_clone(): + s = search.AsyncSearch().source(False) + assert s._clone()._source == s._source + assert s._clone()._source is False + + s2 = search.AsyncSearch().source([]) + assert s2._clone()._source == s2._source + assert s2._source == [] + + s3 = search.AsyncSearch().source(["some", "fields"]) + assert s3._clone()._source == s3._source + assert s3._clone()._source == ["some", "fields"] + + +def test_copy_clones(): + from copy import copy + + s1 = search.AsyncSearch().source(["some", "fields"]) + s2 = copy(s1) + + assert s1 == s2 + assert s1 is not s2 + + +def test_aggs_allow_two_metric(): + s = search.AsyncSearch() + + s.aggs.metric("a", "max", field="a").metric("b", "max", field="b") + + assert s.to_dict() == { + "aggs": {"a": {"max": {"field": "a"}}, "b": {"max": {"field": "b"}}} + } + + +def test_aggs_get_copied_on_change(): + s = search.AsyncSearch().query("match_all") + s.aggs.bucket("per_tag", "terms", field="f").metric( + "max_score", "max", field="score" + ) + + s2 = 
s.query("match_all") + s2.aggs.bucket("per_month", "date_histogram", field="date", interval="month") + s3 = s2.query("match_all") + s3.aggs["per_month"].metric("max_score", "max", field="score") + s4 = s3._clone() + s4.aggs.metric("max_score", "max", field="score") + + d = { + "query": {"match_all": {}}, + "aggs": { + "per_tag": { + "terms": {"field": "f"}, + "aggs": {"max_score": {"max": {"field": "score"}}}, + } + }, + } + + assert d == s.to_dict() + d["aggs"]["per_month"] = {"date_histogram": {"field": "date", "interval": "month"}} + assert d == s2.to_dict() + d["aggs"]["per_month"]["aggs"] = {"max_score": {"max": {"field": "score"}}} + assert d == s3.to_dict() + d["aggs"]["max_score"] = {"max": {"field": "score"}} + assert d == s4.to_dict() + + +def test_search_index(): + s = search.AsyncSearch(index="i") + assert s._index == ["i"] + s = s.index("i2") + assert s._index == ["i", "i2"] + s = s.index("i3") + assert s._index == ["i", "i2", "i3"] + s = s.index() + assert s._index is None + s = search.AsyncSearch(index=("i", "i2")) + assert s._index == ["i", "i2"] + s = search.AsyncSearch(index=["i", "i2"]) + assert s._index == ["i", "i2"] + s = search.AsyncSearch() + s = s.index("i", "i2") + assert s._index == ["i", "i2"] + s2 = s.index("i3") + assert s._index == ["i", "i2"] + assert s2._index == ["i", "i2", "i3"] + s = search.AsyncSearch() + s = s.index(["i", "i2"], "i3") + assert s._index == ["i", "i2", "i3"] + s2 = s.index("i4") + assert s._index == ["i", "i2", "i3"] + assert s2._index == ["i", "i2", "i3", "i4"] + s2 = s.index(["i4"]) + assert s2._index == ["i", "i2", "i3", "i4"] + s2 = s.index(("i4", "i5")) + assert s2._index == ["i", "i2", "i3", "i4", "i5"] + + +def test_doc_type_document_class(): + class MyDocument(Document): + pass + + s = search.AsyncSearch(doc_type=MyDocument) + assert s._doc_type == [MyDocument] + assert s._doc_type_map == {} + + s = search.AsyncSearch().doc_type(MyDocument) + assert s._doc_type == [MyDocument] + assert s._doc_type_map == {} + + +def test_knn(): + s = search.AsyncSearch() + + with raises(TypeError): + s.knn() + with raises(TypeError): + s.knn("field") + with raises(TypeError): + s.knn("field", 5) + with raises(ValueError): + s.knn("field", 5, 100) + with raises(ValueError): + s.knn("field", 5, 100, query_vector=[1, 2, 3], query_vector_builder={}) + + s = s.knn("field", 5, 100, query_vector=[1, 2, 3]) + assert { + "knn": { + "field": "field", + "k": 5, + "num_candidates": 100, + "query_vector": [1, 2, 3], + } + } == s.to_dict() + + s = s.knn( + k=4, + num_candidates=40, + boost=0.8, + field="name", + query_vector_builder={ + "text_embedding": {"model_id": "foo", "model_text": "search text"} + }, + ) + assert { + "knn": [ + { + "field": "field", + "k": 5, + "num_candidates": 100, + "query_vector": [1, 2, 3], + }, + { + "field": "name", + "k": 4, + "num_candidates": 40, + "query_vector_builder": { + "text_embedding": {"model_id": "foo", "model_text": "search text"} + }, + "boost": 0.8, + }, + ] + } == s.to_dict() + + +def test_rank(): + s = search.AsyncSearch() + s.rank(rrf=False) + assert {} == s.to_dict() + + s = s.rank(rrf=True) + assert {"rank": {"rrf": {}}} == s.to_dict() + + s = s.rank(rrf={"window_size": 50, "rank_constant": 20}) + assert {"rank": {"rrf": {"window_size": 50, "rank_constant": 20}}} == s.to_dict() + + +def test_sort(): + s = search.AsyncSearch() + s = s.sort("fielda", "-fieldb") + + assert ["fielda", {"fieldb": {"order": "desc"}}] == s._sort + assert {"sort": ["fielda", {"fieldb": {"order": "desc"}}]} == s.to_dict() + + s = 
s.sort() + assert [] == s._sort + assert search.AsyncSearch().to_dict() == s.to_dict() + + +def test_sort_by_score(): + s = search.AsyncSearch() + s = s.sort("_score") + assert {"sort": ["_score"]} == s.to_dict() + + s = search.AsyncSearch() + with raises(IllegalOperation): + s.sort("-_score") + + +def test_collapse(): + s = search.AsyncSearch() + + inner_hits = {"name": "most_recent", "size": 5, "sort": [{"@timestamp": "desc"}]} + s = s.collapse("user.id", inner_hits=inner_hits, max_concurrent_group_searches=4) + + assert { + "field": "user.id", + "inner_hits": { + "name": "most_recent", + "size": 5, + "sort": [{"@timestamp": "desc"}], + }, + "max_concurrent_group_searches": 4, + } == s._collapse + assert { + "collapse": { + "field": "user.id", + "inner_hits": { + "name": "most_recent", + "size": 5, + "sort": [{"@timestamp": "desc"}], + }, + "max_concurrent_group_searches": 4, + } + } == s.to_dict() + + s = s.collapse() + assert {} == s._collapse + assert search.AsyncSearch().to_dict() == s.to_dict() + + +def test_slice(): + s = search.AsyncSearch() + assert {"from": 3, "size": 7} == s[3:10].to_dict() + assert {"from": 0, "size": 5} == s[:5].to_dict() + assert {"from": 3, "size": 10} == s[3:].to_dict() + assert {"from": 0, "size": 0} == s[0:0].to_dict() + assert {"from": 20, "size": 0} == s[20:0].to_dict() + + +def test_index(): + s = search.AsyncSearch() + assert {"from": 3, "size": 1} == s[3].to_dict() + + +def test_search_to_dict(): + s = search.AsyncSearch() + assert {} == s.to_dict() + + s = s.query("match", f=42) + assert {"query": {"match": {"f": 42}}} == s.to_dict() + + assert {"query": {"match": {"f": 42}}, "size": 10} == s.to_dict(size=10) + + s.aggs.bucket("per_tag", "terms", field="f").metric( + "max_score", "max", field="score" + ) + d = { + "aggs": { + "per_tag": { + "terms": {"field": "f"}, + "aggs": {"max_score": {"max": {"field": "score"}}}, + } + }, + "query": {"match": {"f": 42}}, + } + assert d == s.to_dict() + + s = search.AsyncSearch(extra={"size": 5}) + assert {"size": 5} == s.to_dict() + s = s.extra(from_=42) + assert {"size": 5, "from": 42} == s.to_dict() + + +def test_complex_example(): + s = search.AsyncSearch() + s = ( + s.query("match", title="python") + .query(~Q("match", title="ruby")) + .filter(Q("term", category="meetup") | Q("term", category="conference")) + .collapse("user_id") + .post_filter("terms", tags=["prague", "czech"]) + .script_fields(more_attendees="doc['attendees'].value + 42") + ) + + s.aggs.bucket("per_country", "terms", field="country").metric( + "avg_attendees", "avg", field="attendees" + ) + + s.query.minimum_should_match = 2 + + s = s.highlight_options(order="score").highlight("title", "body", fragment_size=50) + + assert { + "query": { + "bool": { + "filter": [ + { + "bool": { + "should": [ + {"term": {"category": "meetup"}}, + {"term": {"category": "conference"}}, + ] + } + } + ], + "must": [{"match": {"title": "python"}}], + "must_not": [{"match": {"title": "ruby"}}], + "minimum_should_match": 2, + } + }, + "post_filter": {"terms": {"tags": ["prague", "czech"]}}, + "aggs": { + "per_country": { + "terms": {"field": "country"}, + "aggs": {"avg_attendees": {"avg": {"field": "attendees"}}}, + } + }, + "collapse": {"field": "user_id"}, + "highlight": { + "order": "score", + "fields": {"title": {"fragment_size": 50}, "body": {"fragment_size": 50}}, + }, + "script_fields": {"more_attendees": {"script": "doc['attendees'].value + 42"}}, + } == s.to_dict() + + +def test_reverse(): + d = { + "query": { + "filtered": { + "filter": { + "bool": { + 
"should": [ + {"term": {"category": "meetup"}}, + {"term": {"category": "conference"}}, + ] + } + }, + "query": { + "bool": { + "must": [{"match": {"title": "python"}}], + "must_not": [{"match": {"title": "ruby"}}], + "minimum_should_match": 2, + } + }, + } + }, + "post_filter": {"bool": {"must": [{"terms": {"tags": ["prague", "czech"]}}]}}, + "aggs": { + "per_country": { + "terms": {"field": "country"}, + "aggs": {"avg_attendees": {"avg": {"field": "attendees"}}}, + } + }, + "sort": ["title", {"category": {"order": "desc"}}, "_score"], + "size": 5, + "highlight": {"order": "score", "fields": {"title": {"fragment_size": 50}}}, + "suggest": { + "my-title-suggestions-1": { + "text": "devloping distibutd saerch engies", + "term": {"size": 3, "field": "title"}, + } + }, + "script_fields": {"more_attendees": {"script": "doc['attendees'].value + 42"}}, + } + + d2 = deepcopy(d) + + s = search.AsyncSearch.from_dict(d) + + # make sure we haven't modified anything in place + assert d == d2 + assert {"size": 5} == s._extra + assert d == s.to_dict() + + +def test_from_dict_doesnt_need_query(): + s = search.AsyncSearch.from_dict({"size": 5}) + + assert {"size": 5} == s.to_dict() + + +async def test_params_being_passed_to_search(async_mock_client): + s = search.AsyncSearch(using="mock") + s = s.params(routing="42") + await s.execute() + + async_mock_client.search.assert_awaited_once_with(index=None, body={}, routing="42") + + +def test_source(): + assert {} == search.AsyncSearch().source().to_dict() + + assert { + "_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]} + } == search.AsyncSearch().source( + includes=["foo.bar.*"], excludes=["foo.one"] + ).to_dict() + + assert {"_source": False} == search.AsyncSearch().source(False).to_dict() + + assert {"_source": ["f1", "f2"]} == search.AsyncSearch().source( + includes=["foo.bar.*"], excludes=["foo.one"] + ).source(["f1", "f2"]).to_dict() + + +def test_source_on_clone(): + assert { + "_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]}, + "query": {"bool": {"filter": [{"term": {"title": "python"}}]}}, + } == search.AsyncSearch().source(includes=["foo.bar.*"]).source( + excludes=["foo.one"] + ).filter( + "term", title="python" + ).to_dict() + assert { + "_source": False, + "query": {"bool": {"filter": [{"term": {"title": "python"}}]}}, + } == search.AsyncSearch().source(False).filter("term", title="python").to_dict() + + +def test_source_on_clear(): + assert ( + {} + == search.AsyncSearch() + .source(includes=["foo.bar.*"]) + .source(includes=None, excludes=None) + .to_dict() + ) + + +def test_suggest_accepts_global_text(): + s = search.AsyncSearch.from_dict( + { + "suggest": { + "text": "the amsterdma meetpu", + "my-suggest-1": {"term": {"field": "title"}}, + "my-suggest-2": {"text": "other", "term": {"field": "body"}}, + } + } + ) + + assert { + "suggest": { + "my-suggest-1": { + "term": {"field": "title"}, + "text": "the amsterdma meetpu", + }, + "my-suggest-2": {"term": {"field": "body"}, "text": "other"}, + } + } == s.to_dict() + + +def test_suggest(): + s = search.AsyncSearch() + s = s.suggest("my_suggestion", "pyhton", term={"field": "title"}) + + assert { + "suggest": {"my_suggestion": {"term": {"field": "title"}, "text": "pyhton"}} + } == s.to_dict() + + +def test_exclude(): + s = search.AsyncSearch() + s = s.exclude("match", title="python") + + assert { + "query": { + "bool": { + "filter": [{"bool": {"must_not": [{"match": {"title": "python"}}]}}] + } + } + } == s.to_dict() + + +async def 
test_delete_by_query(async_mock_client): + s = search.AsyncSearch(using="mock").query("match", lang="java") + await s.delete() + + async_mock_client.delete_by_query.assert_awaited_once_with( + index=None, body={"query": {"match": {"lang": "java"}}} + ) + + +def test_update_from_dict(): + s = search.AsyncSearch() + s.update_from_dict({"indices_boost": [{"important-documents": 2}]}) + s.update_from_dict({"_source": ["id", "name"]}) + s.update_from_dict({"collapse": {"field": "user_id"}}) + + assert { + "indices_boost": [{"important-documents": 2}], + "_source": ["id", "name"], + "collapse": {"field": "user_id"}, + } == s.to_dict() + + +def test_rescore_query_to_dict(): + s = search.AsyncSearch(index="index-name") + + positive_query = Q( + "function_score", + query=Q("term", tags="a"), + script_score={"script": "_score * 1"}, + ) + + negative_query = Q( + "function_score", + query=Q("term", tags="b"), + script_score={"script": "_score * -100"}, + ) + + s = s.query(positive_query) + s = s.extra( + rescore={"window_size": 100, "query": {"rescore_query": negative_query}} + ) + assert s.to_dict() == { + "query": { + "function_score": { + "query": {"term": {"tags": "a"}}, + "functions": [{"script_score": {"script": "_score * 1"}}], + } + }, + "rescore": { + "window_size": 100, + "query": { + "rescore_query": { + "function_score": { + "query": {"term": {"tags": "b"}}, + "functions": [{"script_score": {"script": "_score * -100"}}], + } + } + }, + }, + } + + assert s.to_dict( + rescore={"window_size": 10, "query": {"rescore_query": positive_query}} + ) == { + "query": { + "function_score": { + "query": {"term": {"tags": "a"}}, + "functions": [{"script_score": {"script": "_score * 1"}}], + } + }, + "rescore": { + "window_size": 10, + "query": { + "rescore_query": { + "function_score": { + "query": {"term": {"tags": "a"}}, + "functions": [{"script_score": {"script": "_score * 1"}}], + } + } + }, + }, + } diff --git a/tests/_sync/__init__.py b/tests/_sync/__init__.py new file mode 100644 index 000000000..2a87d183f --- /dev/null +++ b/tests/_sync/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
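Note on the test plumbing in the remaining hunks: setup.cfg gains asyncio_mode = auto,
so pytest-asyncio collects the bare async def tests above without per-test markers, and
the conftest.py changes below build a mock client whose search and delete_by_query
methods are AsyncMock objects, which is what makes assertions like
assert_awaited_once_with work. A self-contained sketch of that pattern, assuming
Python 3.8+ where AsyncMock ships in unittest.mock (the patch also pulls in the mock
backport for 3.7); the response shape here is illustrative:

    import asyncio
    from unittest.mock import AsyncMock, Mock

    # A stand-in client in the style of the async_mock_client fixture: a plain
    # Mock whose awaitable methods are AsyncMock instances.
    client = Mock()
    client.search = AsyncMock(return_value={"hits": {"hits": [1, 2, 3]}})

    async def exercise():
        response = await client.search(index=None, body={})
        assert response["hits"]["hits"] == [1, 2, 3]
        # assert_awaited_once_with is only available on AsyncMock attributes.
        client.search.assert_awaited_once_with(index=None, body={})

    asyncio.run(exercise())

The run-unasync.py additions at the end then rewrite these async names (AsyncSearch,
async_mock_client, assert_awaited_once_with) into their sync counterparts, so
tests/_sync is generated from tests/_async rather than maintained by hand.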
diff --git a/tests/test_search.py b/tests/_sync/test_search.py similarity index 99% rename from tests/test_search.py rename to tests/_sync/test_search.py index 841caa7cc..8cd2f1eea 100644 --- a/tests/test_search.py +++ b/tests/_sync/test_search.py @@ -19,7 +19,8 @@ from pytest import raises -from elasticsearch_dsl import Document, Q, query, search +from elasticsearch_dsl import Document, Q, query +from elasticsearch_dsl._sync import search from elasticsearch_dsl.exceptions import IllegalOperation @@ -50,7 +51,10 @@ def test_iter_iterates_over_hits(): s = search.Search() s._response = [1, 2, 3] - assert [1, 2, 3] == list(s) + r = [] + for hit in s: + r.append(hit) + assert [1, 2, 3] == r def test_cache_isnt_cloned(): diff --git a/tests/conftest.py b/tests/conftest.py index 0e5e082de..27f606fde 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,14 +21,16 @@ import time from datetime import datetime from unittest import SkipTest, TestCase -from unittest.mock import Mock +from unittest.mock import Mock, AsyncMock from elastic_transport import ObjectApiResponse -from elasticsearch import Elasticsearch +from elasticsearch import AsyncElasticsearch, Elasticsearch from elasticsearch.exceptions import ConnectionError from elasticsearch.helpers import bulk from pytest import fixture, skip +from elasticsearch_dsl.async_connections import add_connection as add_async_connection +from elasticsearch_dsl.async_connections import connections as async_connections from elasticsearch_dsl.connections import add_connection, connections from .test_integration.test_data import ( @@ -46,7 +48,7 @@ ELASTICSEARCH_URL = "http://localhost:9200" -def get_test_client(wait=True, **kwargs): +def get_test_client(wait=True, _async=False, **kwargs): # construct kwargs from the environment kw = {"request_timeout": 30} @@ -58,7 +60,11 @@ def get_test_client(wait=True, **kwargs): ) kw.update(kwargs) - client = Elasticsearch(ELASTICSEARCH_URL, **kw) + client = ( + Elasticsearch(ELASTICSEARCH_URL, **kw) + if not _async + else AsyncElasticsearch(ELASTICSEARCH_URL, **kw) + ) # wait for yellow status for tries_left in range(100 if wait else 1, 0, -1): @@ -119,6 +125,16 @@ def client(): skip() +@fixture(scope="session") +def async_client(): + try: + connection = get_test_client(wait="WAIT_FOR_ES" in os.environ, _async=True) + add_async_connection("default", connection) + return connection + except SkipTest: + skip() + + @fixture(scope="session") def es_version(client): info = client.info() @@ -142,11 +158,24 @@ def mock_client(dummy_response): client = Mock() client.search.return_value = dummy_response add_connection("mock", client) + yield client connections._conn = {} connections._kwargs = {} +@fixture +def async_mock_client(dummy_response): + client = Mock() + client.search = AsyncMock(return_value=dummy_response) + client.delete_by_query = AsyncMock() + add_async_connection("mock", client) + + yield client + async_connections._conn = {} + async_connections._kwargs = {} + + @fixture(scope="session") def data_client(client): # create mappings @@ -160,6 +189,19 @@ def data_client(client): client.indices.delete(index="flat-git") +@fixture(scope="session") +def async_data_client(client, async_client): + # create mappings + create_git_index(client, "git") + create_flat_git_index(client, "flat-git") + # load data + bulk(client, DATA, raise_on_error=True, refresh=True) + bulk(client, FLAT_DATA, raise_on_error=True, refresh=True) + yield async_client + client.indices.delete(index="git") + client.indices.delete(index="flat-git") 
+ + @fixture def dummy_response(): return ObjectApiResponse( diff --git a/utils/run-unasync.py b/utils/run-unasync.py index bff890ad3..d34006110 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -28,13 +28,15 @@ def main(check=False): # Unasync all the generated async code additional_replacements = { + "_async": "_sync", "aiter": "iter", + "anext": "next", "AsyncElasticsearch": "Elasticsearch", "AsyncSearch": "Search", "async_connections": "connections", "async_scan": "scan", - # Handling typing.Awaitable[...] isn't done yet by unasync. - "_TYPE_ASYNC_SNIFF_CALLBACK": "_TYPE_SYNC_SNIFF_CALLBACK", + "async_mock_client": "mock_client", + "assert_awaited_once_with": "assert_called_once_with", } rules = [ unasync.Rule( @@ -43,11 +45,17 @@ def main(check=False): additional_replacements=additional_replacements, ), ] + if not check: + rules.append( + unasync.Rule( + fromdir="/tests/_async/", + todir="/tests/_sync/", + additional_replacements=additional_replacements, + ) + ) filepaths = [] - for root, _, filenames in os.walk( - Path(__file__).absolute().parent.parent / "elasticsearch_dsl/_async" - ): + for root, _, filenames in os.walk(Path(__file__).absolute().parent.parent): for filename in filenames: if filename.rpartition(".")[-1] in ( "py",