diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index 4ffff1883..cf82832a1 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -18,7 +18,7 @@ import collections.abc import copy -from elasticsearch.exceptions import ApiError +from elasticsearch.exceptions import ApiError, NotFoundError from elasticsearch.helpers import scan from .aggs import A, AggBase @@ -854,6 +854,54 @@ def scan(self): for hit in scan(es, query=self.to_dict(), index=self._index, **self._params): yield self._get_result(hit) + def page(self): + """ + Turn the search into a paged search utilizing Point in Time (PIT) and search_after. + Returns a generator that will iterate over all the documents matching the query. + """ + search = self._clone() + + # A sort is required to page search results. We use the optimized default if sort is None. + # https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html + if not search._sort: + search._sort = ["_shard_doc"] + + keep_alive = search._params.pop("keep_alive", "30s") + es = get_connection(search._using) + + pit = es.open_point_in_time( + index=search._index, + keep_alive=keep_alive, + ) + pit_id = pit["id"] + + # The index is passed with Point in Time (PIT). + search._index = None + search._extra.update(pit={"id": pit_id, "keep_alive": keep_alive}) + + response = es.search(body=search.to_dict(), **search._params) + while hits := response["hits"]["hits"]: + for hit in hits: + yield self._get_result(hit) + + # If we have fewer hits than our batch size, we know there are no more results. + if len(hits) < search._params.get("size", 0): + break + + last_document = hits[-1] + pit_id = response["pit_id"] + search._extra.update( + pit={"id": pit_id, "keep_alive": keep_alive}, + search_after=last_document["sort"], + ) + response = es.search(body=search.to_dict(), **search._params) + + # Try to close the PIT unless it is already closed. + try: + es.close_point_in_time(body={"id": pit_id}) + except NotFoundError: + pass + def delete(self): """ delete() executes the query by delegating to delete_by_query() diff --git a/tests/test_integration/test_search.py b/tests/test_integration/test_search.py index 99fb51847..2b6d7e028 100644 --- a/tests/test_integration/test_search.py +++ b/tests/test_integration/test_search.py @@ -110,6 +110,15 @@ def test_scan_iterates_through_all_docs(data_client): assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} +def test_page_iterates_through_all_docs(data_client): + s = Search(index="flat-git") + + commits = list(s.page()) + + assert 52 == len(commits) + assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} + + def test_response_is_cached(data_client): s = Repository.search() repos = list(s)