diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 3f1830f81..cd2c85488 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -25,22 +25,34 @@ The process for contributing to any of the Elasticsearch repositories is similar assure our users of the origin and continuing existence of the code. You only need to sign the CLA once. -2. Run the test suite to ensure your changes do not break existing code: +2. Many classes included in this library are offered in two versions, for + asynchronous and synchronous Python. When working with these classes, you only + need to make changes to the asynchronous code, located in *_async* + subdirectories in the source and tests trees. Once you've made your changes, + run the following command to automatically generate the corresponding + synchronous code: + + .. code:: bash + + $ nox -rs format + +3. Run the test suite to ensure your changes do not break existing code: .. code:: bash $ nox -rs lint test -3. Rebase your changes. +4. Rebase your changes. Update your local repository with the most recent code from the main elasticsearch-dsl-py repository, and rebase your branch on top of the latest master branch. We prefer your changes to be squashed into a single commit. -4. Submit a pull request. Push your local changes to your forked copy of the +5. Submit a pull request. Push your local changes to your forked copy of the repository and submit a pull request. In the pull request, describe what your changes do and mention the number of the issue where discussion has taken place, eg “Closes #123″. Please consider adding or modifying tests related to - your changes. + your changes. Include any generated files in the *_sync* subdirectory in your + pull request. Then sit back and wait. 
There will probably be discussion about the pull request and, if any changes are needed, we would love to work with you to get diff --git a/docs/api.rst b/docs/api.rst index f914f69e8..a89943b12 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,38 +4,25 @@ API Documentation ================= Below please find the documentation for the public classes and functions of ``elasticsearch_dsl``. +The :ref:`Asynchronous API <async_api>` classes are documented separately. .. py:module:: elasticsearch_dsl -Search ------- - .. autoclass:: Search :members: .. autoclass:: MultiSearch :members: -Document --------- - .. autoclass:: Document :members: -Index ------ - .. autoclass:: Index :members: -Faceted Search --------------- - .. autoclass:: FacetedSearch :members: -Update By Query ----------------- .. autoclass:: UpdateByQuery :members: @@ -108,5 +95,3 @@ Common field options: ``required`` Indicates if a field requires a value for the document to be valid. - - diff --git a/docs/async_api.rst b/docs/async_api.rst new file mode 100644 index 000000000..644a9b247 --- /dev/null +++ b/docs/async_api.rst @@ -0,0 +1,33 @@ +.. _async_api: + +Asynchronous API Documentation +============================== + +Below please find the documentation for the asynchronous classes of ``elasticsearch_dsl``. + +.. py:module:: elasticsearch_dsl + :no-index: + +.. autoclass:: AsyncSearch + :inherited-members: + :members: + +.. autoclass:: AsyncMultiSearch + :inherited-members: + :members: + +.. autoclass:: AsyncDocument + :inherited-members: + :members: + +.. autoclass:: AsyncIndex + :inherited-members: + :members: + +.. autoclass:: AsyncFacetedSearch + :inherited-members: + :members: + +.. autoclass:: AsyncUpdateByQuery + :inherited-members: + :members: diff --git a/docs/asyncio.rst b/docs/asyncio.rst new file mode 100644 index 000000000..33ea2a7a7 --- /dev/null +++ b/docs/asyncio.rst @@ -0,0 +1,94 @@ +.. 
_asyncio: + +Using asyncio with Elasticsearch DSL +==================================== + +The ``elasticsearch-dsl`` package supports async/await with `asyncio <https://docs.python.org/3/library/asyncio.html>`__. +To ensure that you have all the required dependencies, install the ``[async]`` extra: + + .. code:: bash + + $ python -m pip install elasticsearch-dsl[async] + +Connections +----------- + +Use the ``async_connections`` module to manage your asynchronous connections. + + .. code:: python + + from elasticsearch_dsl import async_connections + + async_connections.create_connection(hosts=['localhost'], timeout=20) + +All the options available in the ``connections`` module can be used with ``async_connections``. + +How to avoid 'Unclosed client session / connector' warnings on exit +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +These warnings come from the ``aiohttp`` package, which is used internally by the +``AsyncElasticsearch`` client. They appear often when the application exits and +are caused by HTTP connections that are open when they are garbage collected. To +avoid these warnings, make sure that you close your connections. + + .. code:: python + + es = async_connections.get_connection() + await es.close() + +Search DSL +---------- + +Use the ``AsyncSearch`` class to perform asynchronous searches. + + .. code:: python + + from elasticsearch_dsl import AsyncSearch + + s = AsyncSearch().query("match", title="python") + async for hit in s: + print(hit.title) + +Instead of using the ``AsyncSearch`` object as an asynchronous iterator, you can +explicitly call the ``execute()`` method to get a ``Response`` object. + + .. code:: python + + s = AsyncSearch().query("match", title="python") + response = await s.execute() + for hit in response: + print(hit.title) + +An ``AsyncMultiSearch`` is available as well. + + .. 
code:: python + + from elasticsearch_dsl import AsyncMultiSearch + + ms = AsyncMultiSearch(index='blogs') + + ms = ms.add(AsyncSearch().filter('term', tags='python')) + ms = ms.add(AsyncSearch().filter('term', tags='elasticsearch')) + + responses = await ms.execute() + + for response in responses: + print("Results for query %r." % response.search.query) + for hit in response: + print(hit.title) + +Asynchronous Documents, Indexes, and more +----------------------------------------- + +The ``Document``, ``Index``, ``IndexTemplate``, ``Mapping``, ``UpdateByQuery`` and +``FacetedSearch`` classes all have asynchronous versions that use the same name +with an ``Async`` prefix. These classes expose the same interfaces as the +synchronous versions, but any methods that perform I/O are defined as coroutines. + +Auxiliary classes that do not perform I/O do not have asynchronous versions. The +same classes can be used in synchronous and asynchronous applications. + +When using a :ref:`custom analyzer ` in an asynchronous application, use +the ``async_simulate()`` method to invoke the Analyze API on it. + +Consult the :ref:`api` section for details about each specific method. diff --git a/docs/index.rst b/docs/index.rst index 76f77a56a..2592d0df0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,7 +8,29 @@ low-level client (`elasticsearch-py `_:: pip install elasticsearch-dsl +For asynchronous applications, install with the ``async`` extra:: + + pip install elasticsearch-dsl[async] + +Read more about :ref:`how to use asyncio with this project <asyncio>`. + Examples -------- @@ -73,256 +101,6 @@ The recommended way to set your requirements in your `setup.py` or The development is happening on ``main``, older branches only get bugfix releases -Search Example --------------- - -Let's have a typical search request written directly as a ``dict``: - -.. 
code:: python - - from elasticsearch import Elasticsearch - client = Elasticsearch("https://localhost:9200") - - response = client.search( - index="my-index", - body={ - "query": { - "bool": { - "must": [{"match": {"title": "python"}}], - "must_not": [{"match": {"description": "beta"}}], - "filter": [{"term": {"category": "search"}}] - } - }, - "aggs" : { - "per_tag": { - "terms": {"field": "tags"}, - "aggs": { - "max_lines": {"max": {"field": "lines"}} - } - } - } - } - ) - - for hit in response['hits']['hits']: - print(hit['_score'], hit['_source']['title']) - - for tag in response['aggregations']['per_tag']['buckets']: - print(tag['key'], tag['max_lines']['value']) - - - -The problem with this approach is that it is very verbose, prone to syntax -mistakes like incorrect nesting, hard to modify (eg. adding another filter) and -definitely not fun to write. - -Let's rewrite the example using the Python DSL: - -.. code:: python - - from elasticsearch import Elasticsearch - from elasticsearch_dsl import Search - - client = Elasticsearch("https://localhost:9200") - - s = Search(using=client, index="my-index") \ - .filter("term", category="search") \ - .query("match", title="python") \ - .exclude("match", description="beta") - - s.aggs.bucket('per_tag', 'terms', field='tags') \ - .metric('max_lines', 'max', field='lines') - - response = s.execute() - - for hit in response: - print(hit.meta.score, hit.title) - - for tag in response.aggregations.per_tag.buckets: - print(tag.key, tag.max_lines.value) - -As you see, the library took care of: - -- creating appropriate ``Query`` objects by name (eq. "match") -- composing queries into a compound ``bool`` query -- putting the ``term`` query in a filter context of the ``bool`` query -- providing a convenient access to response data -- no curly or square brackets everywhere - - -Persistence Example -------------------- - -Let's have a simple Python class representing an article in a blogging system: - -.. 
code:: python - - from datetime import datetime - from elasticsearch_dsl import Document, Date, Integer, Keyword, Text, connections - - # Define a default Elasticsearch client - connections.create_connection(hosts="https://localhost:9200") - - class Article(Document): - title = Text(analyzer='snowball', fields={'raw': Keyword()}) - body = Text(analyzer='snowball') - tags = Keyword() - published_from = Date() - lines = Integer() - - class Index: - name = 'blog' - settings = { - "number_of_shards": 2, - } - - def save(self, ** kwargs): - self.lines = len(self.body.split()) - return super(Article, self).save(** kwargs) - - def is_published(self): - return datetime.now() > self.published_from - - # create the mappings in elasticsearch - Article.init() - - # create and save and article - article = Article(meta={'id': 42}, title='Hello world!', tags=['test']) - article.body = ''' looong text ''' - article.published_from = datetime.now() - article.save() - - article = Article.get(id=42) - print(article.is_published()) - - # Display cluster health - print(connections.get_connection().cluster.health()) - - -In this example you can see: - -- providing a default connection -- defining fields with mapping configuration -- setting index name -- defining custom methods -- overriding the built-in ``.save()`` method to hook into the persistence - life cycle -- retrieving and saving the object into Elasticsearch -- accessing the underlying client for other APIs - -You can see more in the :ref:`persistence` chapter. - - -Pre-built Faceted Search ------------------------- - -If you have your ``Document``\ s defined you can very easily create a faceted -search class to simplify searching and filtering. - -.. note:: - - This feature is experimental and may be subject to change. - -.. 
code:: python - - from elasticsearch_dsl import FacetedSearch, TermsFacet, DateHistogramFacet - - class BlogSearch(FacetedSearch): - doc_types = [Article, ] - # fields that should be searched - fields = ['tags', 'title', 'body'] - - facets = { - # use bucket aggregations to define facets - 'tags': TermsFacet(field='tags'), - 'publishing_frequency': DateHistogramFacet(field='published_from', interval='month') - } - - # empty search - bs = BlogSearch() - response = bs.execute() - - for hit in response: - print(hit.meta.score, hit.title) - - for (tag, count, selected) in response.facets.tags: - print(tag, ' (SELECTED):' if selected else ':', count) - - for (month, count, selected) in response.facets.publishing_frequency: - print(month.strftime('%B %Y'), ' (SELECTED):' if selected else ':', count) - -You can find more details in the :ref:`faceted_search` chapter. - - -Update By Query Example ------------------------- - -Let's resume the simple example of articles on a blog, and let's assume that each article has a number of likes. -For this example, imagine we want to increment the number of likes by 1 for all articles that match a certain tag and do not match a certain description. -Writing this as a ``dict``, we would have the following code: - -.. code:: python - - from elasticsearch import Elasticsearch - client = Elasticsearch() - - response = client.update_by_query( - index="my-index", - body={ - "query": { - "bool": { - "must": [{"match": {"tag": "python"}}], - "must_not": [{"match": {"description": "beta"}}] - } - }, - "script"={ - "source": "ctx._source.likes++", - "lang": "painless" - } - }, - ) - -Using the DSL, we can now express this query as such: - -.. 
code:: python - - from elasticsearch import Elasticsearch - from elasticsearch_dsl import Search, UpdateByQuery - - client = Elasticsearch() - ubq = UpdateByQuery(using=client, index="my-index") \ - .query("match", title="python") \ - .exclude("match", description="beta") \ - .script(source="ctx._source.likes++", lang="painless") - - response = ubq.execute() - -As you can see, the ``Update By Query`` object provides many of the savings offered -by the ``Search`` object, and additionally allows one to update the results of the search -based on a script assigned in the same manner. - -Migration from ``elasticsearch-py`` ------------------------------------ - -You don't have to port your entire application to get the benefits of the -Python DSL, you can start gradually by creating a ``Search`` object from your -existing ``dict``, modifying it using the API and serializing it back to a -``dict``: - -.. code:: python - - body = {...} # insert complicated query here - - # Convert to Search object - s = Search.from_dict(body) - - # Add some filters, aggregations, queries, ... - s.filter("term", tags="python") - - # Convert back to dict to plug back into existing code - body = s.to_dict() - - License ------- @@ -344,13 +122,38 @@ Contents -------- .. toctree:: + :caption: About :maxdepth: 2 + self configuration + +.. toctree:: + :caption: Tutorials + :maxdepth: 2 + + tutorials + +.. toctree:: + :caption: How-To Guides + :maxdepth: 2 + search_dsl persistence faceted_search update_by_query + asyncio + +.. toctree:: + :caption: Reference + :maxdepth: 2 + api + async_api + +.. toctree:: + :caption: Community + :maxdepth: 2 + CONTRIBUTING Changelog diff --git a/docs/tutorials.rst b/docs/tutorials.rst new file mode 100644 index 000000000..9b62ae905 --- /dev/null +++ b/docs/tutorials.rst @@ -0,0 +1,248 @@ +Search +------ + +Let's have a typical search request written directly as a ``dict``: + +.. 
code:: python + + from elasticsearch import Elasticsearch + client = Elasticsearch("https://localhost:9200") + + response = client.search( + index="my-index", + body={ + "query": { + "bool": { + "must": [{"match": {"title": "python"}}], + "must_not": [{"match": {"description": "beta"}}], + "filter": [{"term": {"category": "search"}}] + } + }, + "aggs" : { + "per_tag": { + "terms": {"field": "tags"}, + "aggs": { + "max_lines": {"max": {"field": "lines"}} + } + } + } + } + ) + + for hit in response['hits']['hits']: + print(hit['_score'], hit['_source']['title']) + + for tag in response['aggregations']['per_tag']['buckets']: + print(tag['key'], tag['max_lines']['value']) + + + +The problem with this approach is that it is very verbose, prone to syntax +mistakes like incorrect nesting, hard to modify (e.g. adding another filter) and +definitely not fun to write. + +Let's rewrite the example using the Python DSL: + +.. code:: python + + from elasticsearch import Elasticsearch + from elasticsearch_dsl import Search + + client = Elasticsearch("https://localhost:9200") + + s = Search(using=client, index="my-index") \ + .filter("term", category="search") \ + .query("match", title="python") \ + .exclude("match", description="beta") + + s.aggs.bucket('per_tag', 'terms', field='tags') \ + .metric('max_lines', 'max', field='lines') + + response = s.execute() + + for hit in response: + print(hit.meta.score, hit.title) + + for tag in response.aggregations.per_tag.buckets: + print(tag.key, tag.max_lines.value) + +As you see, the library took care of: + +- creating appropriate ``Query`` objects by name (e.g. "match") +- composing queries into a compound ``bool`` query +- putting the ``term`` query in a filter context of the ``bool`` query +- providing a convenient access to response data +- no curly or square brackets everywhere + + +Persistence +----------- + +Let's have a simple Python class representing an article in a blogging system: + +.. 
code:: python + + from datetime import datetime + from elasticsearch_dsl import Document, Date, Integer, Keyword, Text, connections + + # Define a default Elasticsearch client + connections.create_connection(hosts="https://localhost:9200") + + class Article(Document): + title = Text(analyzer='snowball', fields={'raw': Keyword()}) + body = Text(analyzer='snowball') + tags = Keyword() + published_from = Date() + lines = Integer() + + class Index: + name = 'blog' + settings = { + "number_of_shards": 2, + } + + def save(self, ** kwargs): + self.lines = len(self.body.split()) + return super(Article, self).save(** kwargs) + + def is_published(self): + return datetime.now() > self.published_from + + # create the mappings in elasticsearch + Article.init() + + # create and save an article + article = Article(meta={'id': 42}, title='Hello world!', tags=['test']) + article.body = ''' looong text ''' + article.published_from = datetime.now() + article.save() + + article = Article.get(id=42) + print(article.is_published()) + + # Display cluster health + print(connections.get_connection().cluster.health()) + + +In this example you can see: + +- providing a default connection +- defining fields with mapping configuration +- setting index name +- defining custom methods +- overriding the built-in ``.save()`` method to hook into the persistence + life cycle +- retrieving and saving the object into Elasticsearch +- accessing the underlying client for other APIs + +You can see more in the :ref:`persistence` chapter. + + +Pre-built Faceted Search +------------------------ + +If you have your ``Document``\ s defined you can very easily create a faceted +search class to simplify searching and filtering. + +.. note:: + + This feature is experimental and may be subject to change. + +.. 
code:: python + + from elasticsearch_dsl import FacetedSearch, TermsFacet, DateHistogramFacet + + class BlogSearch(FacetedSearch): + doc_types = [Article, ] + # fields that should be searched + fields = ['tags', 'title', 'body'] + + facets = { + # use bucket aggregations to define facets + 'tags': TermsFacet(field='tags'), + 'publishing_frequency': DateHistogramFacet(field='published_from', interval='month') + } + + # empty search + bs = BlogSearch() + response = bs.execute() + + for hit in response: + print(hit.meta.score, hit.title) + + for (tag, count, selected) in response.facets.tags: + print(tag, ' (SELECTED):' if selected else ':', count) + + for (month, count, selected) in response.facets.publishing_frequency: + print(month.strftime('%B %Y'), ' (SELECTED):' if selected else ':', count) + +You can find more details in the :ref:`faceted_search` chapter. + + +Update By Query +--------------- + +Let's resume the simple example of articles on a blog, and let's assume that each article has a number of likes. +For this example, imagine we want to increment the number of likes by 1 for all articles that match a certain tag and do not match a certain description. +Writing this as a ``dict``, we would have the following code: + +.. code:: python + + from elasticsearch import Elasticsearch + client = Elasticsearch() + + response = client.update_by_query( + index="my-index", + body={ + "query": { + "bool": { + "must": [{"match": {"tag": "python"}}], + "must_not": [{"match": {"description": "beta"}}] + } + }, + "script": { + "source": "ctx._source.likes++", + "lang": "painless" + } + }, + ) + +Using the DSL, we can now express this query as such: + +.. 
code:: python + + from elasticsearch import Elasticsearch + from elasticsearch_dsl import Search, UpdateByQuery + + client = Elasticsearch() + ubq = UpdateByQuery(using=client, index="my-index") \ + .query("match", title="python") \ + .exclude("match", description="beta") \ + .script(source="ctx._source.likes++", lang="painless") + + response = ubq.execute() + +As you can see, the ``Update By Query`` object provides many of the savings offered +by the ``Search`` object, and additionally allows one to update the results of the search +based on a script assigned in the same manner. + +Migration from ``elasticsearch-py`` +----------------------------------- + +You don't have to port your entire application to get the benefits of the +Python DSL, you can start gradually by creating a ``Search`` object from your +existing ``dict``, modifying it using the API and serializing it back to a +``dict``: + +.. code:: python + + body = {...} # insert complicated query here + + # Convert to Search object + s = Search.from_dict(body) + + # Add some filters, aggregations, queries, ... + s.filter("term", tags="python") + + # Convert back to dict to plug back into existing code + body = s.to_dict() diff --git a/elasticsearch_dsl/__init__.py b/elasticsearch_dsl/__init__.py index afb5b275f..42e499fee 100644 --- a/elasticsearch_dsl/__init__.py +++ b/elasticsearch_dsl/__init__.py @@ -18,7 +18,8 @@ from . 
import connections from .aggs import A from .analysis import analyzer, char_filter, normalizer, token_filter, tokenizer -from .document import Document, InnerDoc, MetaField +from .document import AsyncDocument, Document +from .document_base import InnerDoc, MetaField from .exceptions import ( ElasticsearchDslException, IllegalOperation, @@ -26,6 +27,7 @@ ValidationException, ) from .faceted_search import ( + AsyncFacetedSearch, DateHistogramFacet, Facet, FacetedResponse, @@ -76,11 +78,11 @@ construct_field, ) from .function import SF -from .index import Index, IndexTemplate -from .mapping import Mapping +from .index import AsyncIndex, AsyncIndexTemplate, Index, IndexTemplate +from .mapping import AsyncMapping, Mapping from .query import Q -from .search import MultiSearch, Search -from .update_by_query import UpdateByQuery +from .search import AsyncMultiSearch, AsyncSearch, MultiSearch, Search +from .update_by_query import AsyncUpdateByQuery, UpdateByQuery from .utils import AttrDict, AttrList, DslBase from .wrappers import Range @@ -89,6 +91,14 @@ __versionstr__ = ".".join(map(str, VERSION)) __all__ = [ "A", + "AsyncDocument", + "AsyncFacetedSearch", + "AsyncIndex", + "AsyncIndexTemplate", + "AsyncMapping", + "AsyncMultiSearch", + "AsyncSearch", + "AsyncUpdateByQuery", "AttrDict", "AttrList", "Binary", diff --git a/elasticsearch_dsl/_async/__init__.py b/elasticsearch_dsl/_async/__init__.py new file mode 100644 index 000000000..2a87d183f --- /dev/null +++ b/elasticsearch_dsl/_async/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/elasticsearch_dsl/_async/document.py b/elasticsearch_dsl/_async/document.py new file mode 100644 index 000000000..8f1b50a2a --- /dev/null +++ b/elasticsearch_dsl/_async/document.py @@ -0,0 +1,374 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import collections.abc + +from elasticsearch.exceptions import NotFoundError, RequestError + +from .._async.index import AsyncIndex +from ..async_connections import get_connection +from ..document_base import DocumentBase, DocumentMeta +from ..exceptions import IllegalOperation +from ..utils import DOC_META_FIELDS, META_FIELDS, merge +from .search import AsyncSearch + + +class AsyncIndexMeta(DocumentMeta): + # global flag to guard us from associating an Index with the base Document + # class, only user defined subclasses should have an _index attr + _document_initialized = False + + def __new__(cls, name, bases, attrs): + new_cls = super().__new__(cls, name, bases, attrs) + if cls._document_initialized: + index_opts = attrs.pop("Index", None) + index = cls.construct_index(index_opts, bases) + new_cls._index = index + index.document(new_cls) + cls._document_initialized = True + return new_cls + + @classmethod + def construct_index(cls, opts, bases): + if opts is None: + for b in bases: + if hasattr(b, "_index"): + return b._index + + # Set None as Index name so it will set _all while making the query + return AsyncIndex(name=None) + + i = AsyncIndex( + getattr(opts, "name", "*"), using=getattr(opts, "using", "default") + ) + i.settings(**getattr(opts, "settings", {})) + i.aliases(**getattr(opts, "aliases", {})) + for a in getattr(opts, "analyzers", ()): + i.analyzer(a) + return i + + +class AsyncDocument(DocumentBase, metaclass=AsyncIndexMeta): + """ + Model-like class for persisting documents in elasticsearch. + """ + + @classmethod + def _get_connection(cls, using=None): + return get_connection(cls._get_using(using)) + + @classmethod + async def init(cls, index=None, using=None): + """ + Create the index and populate the mappings in elasticsearch. 
+ """ + i = cls._index + if index: + i = i.clone(name=index) + await i.save(using=using) + + @classmethod + def search(cls, using=None, index=None): + """ + Create an :class:`~elasticsearch_dsl.Search` instance that will search + over this ``Document``. + """ + return AsyncSearch( + using=cls._get_using(using), index=cls._default_index(index), doc_type=[cls] + ) + + @classmethod + async def get(cls, id, using=None, index=None, **kwargs): + """ + Retrieve a single document from elasticsearch using its ``id``. + + :arg id: ``id`` of the document to be retrieved + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + + Any additional keyword arguments will be passed to + ``Elasticsearch.get`` unchanged. + """ + es = cls._get_connection(using) + doc = await es.get(index=cls._default_index(index), id=id, **kwargs) + if not doc.get("found", False): + return None + return cls.from_es(doc) + + @classmethod + async def exists(cls, id, using=None, index=None, **kwargs): + """ + check if exists a single document from elasticsearch using its ``id``. + + :arg id: ``id`` of the document to check if exists + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + + Any additional keyword arguments will be passed to + ``Elasticsearch.exists`` unchanged. + """ + es = cls._get_connection(using) + return await es.exists(index=cls._default_index(index), id=id, **kwargs) + + @classmethod + async def mget( + cls, docs, using=None, index=None, raise_on_error=True, missing="none", **kwargs + ): + r""" + Retrieve multiple document by their ``id``\s. Returns a list of instances + in the same order as requested. 
+ + :arg docs: list of ``id``\s of the documents to be retrieved or a list + of document specifications as per + https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-multi-get.html + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg missing: what to do when one of the documents requested is not + found. Valid options are ``'none'`` (use ``None``), ``'raise'`` (raise + ``NotFoundError``) or ``'skip'`` (ignore the missing document). + + Any additional keyword arguments will be passed to + ``Elasticsearch.mget`` unchanged. + """ + if missing not in ("raise", "skip", "none"): + raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") + es = cls._get_connection(using) + body = { + "docs": [ + doc if isinstance(doc, collections.abc.Mapping) else {"_id": doc} + for doc in docs + ] + } + results = await es.mget(index=cls._default_index(index), body=body, **kwargs) + + objs, error_docs, missing_docs = [], [], [] + for doc in results["docs"]: + if doc.get("found"): + if error_docs or missing_docs: + # We're going to raise an exception anyway, so avoid an + # expensive call to cls.from_es(). + continue + + objs.append(cls.from_es(doc)) + + elif doc.get("error"): + if raise_on_error: + error_docs.append(doc) + if missing == "none": + objs.append(None) + + # The doc didn't cause an error, but the doc also wasn't found. + elif missing == "raise": + missing_docs.append(doc) + elif missing == "none": + objs.append(None) + + if error_docs: + error_ids = [doc["_id"] for doc in error_docs] + message = "Required routing not provided for documents %s." + message %= ", ".join(error_ids) + raise RequestError(400, message, error_docs) + if missing_docs: + missing_ids = [doc["_id"] for doc in missing_docs] + message = f"Documents {', '.join(missing_ids)} not found." 
+ raise NotFoundError(404, message, {"docs": missing_docs}) + return objs + + async def delete(self, using=None, index=None, **kwargs): + """ + Delete the instance in elasticsearch. + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + + Any additional keyword arguments will be passed to + ``Elasticsearch.delete`` unchanged. + """ + es = self._get_connection(using) + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + doc_meta.update(kwargs) + await es.delete(index=self._get_index(index), **doc_meta) + + async def update( + self, + using=None, + index=None, + detect_noop=True, + doc_as_upsert=False, + refresh=False, + retry_on_conflict=None, + script=None, + script_id=None, + scripted_upsert=False, + upsert=None, + return_doc_meta=False, + **fields, + ): + """ + Partial update of the document, specify fields you wish to update and + both the instance and the document in elasticsearch will be updated:: + + doc = MyDocument(title='Document Title!') + doc.save() + doc.update(title='New Document Title!') + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg detect_noop: Set to ``False`` to disable noop detection. + :arg refresh: Control when the changes made by this request are visible + to search. Set to ``True`` for immediate effect. + :arg retry_on_conflict: In between the get and indexing phases of the + update, it is possible that another process might have already + updated the same document. 
By default, the update will fail with a + version conflict exception. The retry_on_conflict parameter + controls how many times to retry the update before finally throwing + an exception. + :arg doc_as_upsert: Instead of sending a partial doc plus an upsert + doc, setting doc_as_upsert to true will use the contents of doc as + the upsert value + :arg return_doc_meta: set to ``True`` to return all metadata from the + index API call instead of only the operation result + + :return operation result noop/updated + """ + body = { + "doc_as_upsert": doc_as_upsert, + "detect_noop": detect_noop, + } + + # scripted update + if script or script_id: + if upsert is not None: + body["upsert"] = upsert + + if script: + script = {"source": script} + else: + script = {"id": script_id} + + script["params"] = fields + + body["script"] = script + body["scripted_upsert"] = scripted_upsert + + # partial document update + else: + if not fields: + raise IllegalOperation( + "You cannot call update() without updating individual fields or a script. " + "If you wish to update the entire object use save()." 
+ ) + + # update given fields locally + merge(self, fields) + + # prepare data for ES + values = self.to_dict() + + # if fields were given: partial update + body["doc"] = {k: values.get(k) for k in fields.keys()} + + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + if retry_on_conflict is not None: + doc_meta["retry_on_conflict"] = retry_on_conflict + + # Optimistic concurrency control + if ( + retry_on_conflict in (None, 0) + and "seq_no" in self.meta + and "primary_term" in self.meta + ): + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + meta = await self._get_connection(using).update( + index=self._get_index(index), body=body, refresh=refresh, **doc_meta + ) + # update meta information from ES + for k in META_FIELDS: + if "_" + k in meta: + setattr(self.meta, k, meta["_" + k]) + + return meta if return_doc_meta else meta["result"] + + async def save( + self, + using=None, + index=None, + validate=True, + skip_empty=True, + return_doc_meta=False, + **kwargs, + ): + """ + Save the document into elasticsearch. If the document doesn't exist it + is created, it is overwritten otherwise. Returns ``True`` if this + operations resulted in new document being created. + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg validate: set to ``False`` to skip validating the document + :arg skip_empty: if set to ``False`` will cause empty values (``None``, + ``[]``, ``{}``) to be left on the document. Those values will be + stripped out otherwise as they make no difference in elasticsearch. + :arg return_doc_meta: set to ``True`` to return all metadata from the + update API call instead of only the operation result + + Any additional keyword arguments will be passed to + ``Elasticsearch.index`` unchanged. 
+ + :return operation result created/updated + """ + if validate: + self.full_clean() + + es = self._get_connection(using) + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + doc_meta.update(kwargs) + meta = await es.index( + index=self._get_index(index), + body=self.to_dict(skip_empty=skip_empty), + **doc_meta, + ) + # update meta information from ES + for k in META_FIELDS: + if "_" + k in meta: + setattr(self.meta, k, meta["_" + k]) + + return meta if return_doc_meta else meta["result"] diff --git a/elasticsearch_dsl/_async/faceted_search.py b/elasticsearch_dsl/_async/faceted_search.py new file mode 100644 index 000000000..a446a70bb --- /dev/null +++ b/elasticsearch_dsl/_async/faceted_search.py @@ -0,0 +1,40 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
class AsyncFacetedSearch(FacetedSearchBase):
    """
    Faceted search whose underlying search object is an ``AsyncSearch`` and
    whose ``execute`` is a coroutine.
    """

    def search(self):
        """
        Build the base search object that the facets get attached to.

        Override this method and return a modified search object to
        customize the query.
        """
        base_search = AsyncSearch(
            doc_type=self.doc_types, index=self.index, using=self.using
        )
        # responses are wrapped in FacetedResponse instead of the default class
        return base_search.response_class(FacetedResponse)

    async def execute(self):
        """
        Run the search and return its response, with a back-reference to
        this faceted search stored on the response.
        """
        response = await self._s.execute()
        response._faceted_search = self
        return response
class AsyncIndexTemplate:
    """
    Index template backed by an ``AsyncIndex`` definition.

    Attributes that are not found on the template itself are looked up on the
    wrapped index, so the index configuration methods can be called directly
    on the template.
    """

    def __init__(self, name, template, index=None, order=None, **kwargs):
        """
        :arg name: name under which the template is stored
        :arg template: index pattern the template applies to
        :arg index: optional index instance to clone the definition from;
            when omitted a new ``AsyncIndex`` is created from ``kwargs``
        :arg order: optional template order, only serialized when not ``None``

        :raises ValueError: if both ``index`` and extra keyword arguments
            are given
        """
        if index is None:
            self._index = AsyncIndex(template, **kwargs)
        else:
            if kwargs:
                raise ValueError(
                    "You cannot specify options for Index when"
                    " passing an Index instance."
                )
            # work on a copy so the caller's index keeps its own name
            self._index = index.clone()
            self._index._name = template
        self._template_name = name
        self.order = order

    def __getattr__(self, attr_name):
        # __getattr__ only runs after normal lookup failed. If ``_index``
        # itself is missing (instance created without __init__, e.g. by
        # copy/pickle, or __init__ raised early) reading ``self._index`` would
        # re-enter this method forever, so fetch it from the instance dict.
        try:
            index = self.__dict__["_index"]
        except KeyError:
            raise AttributeError(attr_name) from None
        return getattr(index, attr_name)

    def to_dict(self):
        """
        Serialize the wrapped index definition and add ``index_patterns``
        (and ``order`` when set).
        """
        d = self._index.to_dict()
        d["index_patterns"] = [self._index._name]
        if self.order is not None:
            d["order"] = self.order
        return d

    async def save(self, using=None):
        """
        Store the template through ``indices.put_template``.

        :arg using: connection alias to use, defaults to the alias of the
            wrapped index
        """
        es = get_connection(using or self._index._using)
        return await es.indices.put_template(
            name=self._template_name, body=self.to_dict()
        )
class AsyncIndex(IndexBase):
    def __init__(self, name, using="default"):
        """
        :arg name: name of the index
        :arg using: connection alias to use, defaults to ``'default'``
        """
        super().__init__(name, AsyncMapping, using=using)

    def _get_connection(self, using=None):
        """
        Return the connection for ``using``, falling back to the alias
        configured on the index.

        :raises ValueError: if the index has no name
        """
        if self._name is None:
            raise ValueError("You cannot perform API calls on the default index.")
        return get_connection(using or self._using)

    connection = property(_get_connection)

    def as_template(self, template_name, pattern=None, order=None):
        """
        Return an ``AsyncIndexTemplate`` built from this index definition.

        :arg template_name: name of the template
        :arg pattern: index pattern, defaults to the name of this index
        :arg order: optional template order
        """
        # TODO: should we allow pattern to be a top-level arg?
        # or maybe have an IndexPattern that allows for it and have
        # Document._index be that?
        return AsyncIndexTemplate(
            template_name, pattern or self._name, index=self, order=order
        )

    async def load_mappings(self, using=None):
        """
        Update the mapping of this index object from the mappings stored in
        elasticsearch.
        """
        await self.get_or_create_mapping().update_from_es(
            self._name, using=using or self._using
        )

    def clone(self, name=None, using=None):
        """
        Create a copy of the instance with another name or connection alias.
        Useful for creating multiple indices with shared configuration::

            i = AsyncIndex('base-index')
            i.settings(number_of_shards=1)
            await i.create()

            i2 = i.clone('other-index')
            await i2.create()

        :arg name: name of the index, defaults to the name of this index
        :arg using: connection alias to use, defaults to the alias of this
            index
        """
        i = AsyncIndex(name or self._name, using=using or self._using)
        i._settings = self._settings.copy()
        i._aliases = self._aliases.copy()
        i._analysis = self._analysis.copy()
        i._doc_types = self._doc_types[:]
        if self._mapping is not None:
            i._mapping = self._mapping._clone()
        return i

    def search(self, using=None):
        """
        Return a :class:`~elasticsearch_dsl.AsyncSearch` object searching over
        the index (or all the indices belonging to this template) and its
        ``Document``\\s.
        """
        return AsyncSearch(
            using=using or self._using, index=self._name, doc_type=self._doc_types
        )

    def updateByQuery(self, using=None):
        """
        Return a :class:`~elasticsearch_dsl.AsyncUpdateByQuery` object searching
        over the index (or all the indices belonging to this template) and
        updating Documents that match the search criteria.

        For more information, see here:
        https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-update-by-query.html
        """
        return AsyncUpdateByQuery(
            using=using or self._using,
            index=self._name,
        )

    async def create(self, using=None, **kwargs):
        """
        Creates the index in elasticsearch.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.create`` unchanged.
        """
        return await self._get_connection(using).indices.create(
            index=self._name, body=self.to_dict(), **kwargs
        )

    async def is_closed(self, using=None):
        """
        Return ``True`` if the cluster state metadata reports the index state
        as ``close``.
        """
        state = await self._get_connection(using).cluster.state(
            index=self._name, metric="metadata"
        )
        return state["metadata"]["indices"][self._name]["state"] == "close"

    async def save(self, using=None):
        """
        Sync the index definition with elasticsearch, creating the index if it
        doesn't exist and updating its settings and mappings if it does.

        Note some settings and mapping changes cannot be done on an open
        index (or at all on an existing index) and for those this method will
        fail with the underlying exception.

        :raises IllegalOperation: if the analysis configuration differs from
            the one stored in elasticsearch and the index is open
        """
        if not await self.exists(using=using):
            return await self.create(using=using)

        body = self.to_dict()
        settings = body.pop("settings", {})
        analysis = settings.pop("analysis", None)
        current_settings = (await self.get_settings(using=using))[self._name][
            "settings"
        ]["index"]
        if analysis:
            if await self.is_closed(using=using):
                # closed index, update away
                settings["analysis"] = analysis
            else:
                # compare analysis definition, if all analysis objects are
                # already defined as requested, skip analysis update and
                # proceed, otherwise raise IllegalOperation
                existing_analysis = current_settings.get("analysis", {})
                if any(
                    existing_analysis.get(section, {}).get(k, None)
                    != analysis[section][k]
                    for section in analysis
                    for k in analysis[section]
                ):
                    raise IllegalOperation(
                        "You cannot update analysis configuration on an open index, "
                        "you need to close index %s first." % self._name
                    )

        # try and update the settings
        if settings:
            # keep only the settings whose stored value differs; stored values
            # are compared against the string form of the requested value
            settings = {
                k: v
                for k, v in settings.items()
                if not (k in current_settings and current_settings[k] == str(v))
            }

            if settings:
                await self.put_settings(using=using, body=settings)

        # update the mappings, any conflict in the mappings will result in an
        # exception
        mappings = body.pop("mappings", {})
        if mappings:
            await self.put_mapping(using=using, body=mappings)

    async def analyze(self, using=None, **kwargs):
        """
        Perform the analysis process on a text and return the tokens breakdown
        of the text.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.analyze`` unchanged.
        """
        return await self._get_connection(using).indices.analyze(
            index=self._name, **kwargs
        )

    async def refresh(self, using=None, **kwargs):
        """
        Performs a refresh operation on the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.refresh`` unchanged.
        """
        return await self._get_connection(using).indices.refresh(
            index=self._name, **kwargs
        )

    async def flush(self, using=None, **kwargs):
        """
        Performs a flush operation on the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.flush`` unchanged.
        """
        return await self._get_connection(using).indices.flush(
            index=self._name, **kwargs
        )

    async def get(self, using=None, **kwargs):
        """
        The get index API allows to retrieve information about the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get`` unchanged.
        """
        return await self._get_connection(using).indices.get(index=self._name, **kwargs)

    async def open(self, using=None, **kwargs):
        """
        Opens the index in elasticsearch.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.open`` unchanged.
        """
        return await self._get_connection(using).indices.open(
            index=self._name, **kwargs
        )

    async def close(self, using=None, **kwargs):
        """
        Closes the index in elasticsearch.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.close`` unchanged.
        """
        return await self._get_connection(using).indices.close(
            index=self._name, **kwargs
        )

    async def delete(self, using=None, **kwargs):
        """
        Deletes the index in elasticsearch.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.delete`` unchanged.
        """
        return await self._get_connection(using).indices.delete(
            index=self._name, **kwargs
        )

    async def exists(self, using=None, **kwargs):
        """
        Returns ``True`` if the index already exists in elasticsearch.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.exists`` unchanged.
        """
        return await self._get_connection(using).indices.exists(
            index=self._name, **kwargs
        )

    async def exists_type(self, using=None, **kwargs):
        """
        Check if a type/types exists in the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.exists_type`` unchanged.
        """
        # NOTE(review): newer client releases may no longer expose
        # ``indices.exists_type`` -- confirm against the supported version.
        return await self._get_connection(using).indices.exists_type(
            index=self._name, **kwargs
        )

    async def put_mapping(self, using=None, **kwargs):
        """
        Register specific mapping definition for a specific type.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.put_mapping`` unchanged.
        """
        return await self._get_connection(using).indices.put_mapping(
            index=self._name, **kwargs
        )

    async def get_mapping(self, using=None, **kwargs):
        """
        Retrieve specific mapping definition for a specific type.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get_mapping`` unchanged.
        """
        return await self._get_connection(using).indices.get_mapping(
            index=self._name, **kwargs
        )

    async def get_field_mapping(self, using=None, **kwargs):
        """
        Retrieve mapping definition of a specific field.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get_field_mapping`` unchanged.
        """
        return await self._get_connection(using).indices.get_field_mapping(
            index=self._name, **kwargs
        )

    async def put_alias(self, using=None, **kwargs):
        """
        Create an alias for the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.put_alias`` unchanged.
        """
        return await self._get_connection(using).indices.put_alias(
            index=self._name, **kwargs
        )

    async def exists_alias(self, using=None, **kwargs):
        """
        Return a boolean indicating whether given alias exists for this index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.exists_alias`` unchanged.
        """
        return await self._get_connection(using).indices.exists_alias(
            index=self._name, **kwargs
        )

    async def get_alias(self, using=None, **kwargs):
        """
        Retrieve a specified alias.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get_alias`` unchanged.
        """
        return await self._get_connection(using).indices.get_alias(
            index=self._name, **kwargs
        )

    async def delete_alias(self, using=None, **kwargs):
        """
        Delete specific alias.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.delete_alias`` unchanged.
        """
        return await self._get_connection(using).indices.delete_alias(
            index=self._name, **kwargs
        )

    async def get_settings(self, using=None, **kwargs):
        """
        Retrieve settings for the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get_settings`` unchanged.
        """
        return await self._get_connection(using).indices.get_settings(
            index=self._name, **kwargs
        )

    async def put_settings(self, using=None, **kwargs):
        """
        Change specific index level settings in real time.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.put_settings`` unchanged.
        """
        return await self._get_connection(using).indices.put_settings(
            index=self._name, **kwargs
        )

    async def stats(self, using=None, **kwargs):
        """
        Retrieve statistics on different operations happening on the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.stats`` unchanged.
        """
        return await self._get_connection(using).indices.stats(
            index=self._name, **kwargs
        )

    async def segments(self, using=None, **kwargs):
        """
        Provide low level segments information that a Lucene index (shard
        level) is built with.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.segments`` unchanged.
        """
        return await self._get_connection(using).indices.segments(
            index=self._name, **kwargs
        )

    async def validate_query(self, using=None, **kwargs):
        """
        Validate a potentially expensive query without executing it.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.validate_query`` unchanged.
        """
        return await self._get_connection(using).indices.validate_query(
            index=self._name, **kwargs
        )

    async def clear_cache(self, using=None, **kwargs):
        """
        Clear all caches or specific caches associated with the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.clear_cache`` unchanged.
        """
        return await self._get_connection(using).indices.clear_cache(
            index=self._name, **kwargs
        )

    async def recovery(self, using=None, **kwargs):
        """
        The indices recovery API provides insight into on-going shard
        recoveries for the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.recovery`` unchanged.
        """
        return await self._get_connection(using).indices.recovery(
            index=self._name, **kwargs
        )

    async def upgrade(self, using=None, **kwargs):
        """
        Upgrade the index to the latest format.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.upgrade`` unchanged.
        """
        # NOTE(review): newer client releases may no longer expose
        # ``indices.upgrade`` -- confirm against the supported version.
        return await self._get_connection(using).indices.upgrade(
            index=self._name, **kwargs
        )

    async def get_upgrade(self, using=None, **kwargs):
        """
        Monitor how much of the index is upgraded.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get_upgrade`` unchanged.
        """
        # NOTE(review): newer client releases may no longer expose
        # ``indices.get_upgrade`` -- confirm against the supported version.
        return await self._get_connection(using).indices.get_upgrade(
            index=self._name, **kwargs
        )

    async def flush_synced(self, using=None, **kwargs):
        """
        Perform a normal flush, then add a generated unique marker (sync_id) to
        all shards.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.flush_synced`` unchanged.
        """
        # NOTE(review): newer client releases may no longer expose
        # ``indices.flush_synced`` -- confirm against the supported version.
        return await self._get_connection(using).indices.flush_synced(
            index=self._name, **kwargs
        )

    async def shard_stores(self, using=None, **kwargs):
        """
        Provides store information for shard copies of the index. Store
        information reports on which nodes shard copies exist, the shard copy
        version, indicating how recent they are, and any exceptions encountered
        while opening the shard index or from earlier engine failure.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.shard_stores`` unchanged.
        """
        return await self._get_connection(using).indices.shard_stores(
            index=self._name, **kwargs
        )

    async def forcemerge(self, using=None, **kwargs):
        """
        The force merge API allows to force merging of the index through an
        API. The merge relates to the number of segments a Lucene index holds
        within each shard. The force merge operation allows to reduce the
        number of segments by merging them.

        This call will block until the merge is complete. If the http
        connection is lost, the request will continue in the background, and
        any new requests will block until the previous force merge is complete.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.forcemerge`` unchanged.
        """
        return await self._get_connection(using).indices.forcemerge(
            index=self._name, **kwargs
        )

    async def shrink(self, using=None, **kwargs):
        """
        The shrink index API allows you to shrink an existing index into a new
        index with fewer primary shards. The number of primary shards in the
        target index must be a factor of the shards in the source index. For
        example an index with 8 primary shards can be shrunk into 4, 2 or 1
        primary shards or an index with 15 primary shards can be shrunk into 5,
        3 or 1. If the number of shards in the index is a prime number it can
        only be shrunk into a single primary shard. Before shrinking, a
        (primary or replica) copy of every shard in the index must be present
        on the same node.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.shrink`` unchanged.
        """
        return await self._get_connection(using).indices.shrink(
            index=self._name, **kwargs
        )
class AsyncMapping(MappingBase):
    """
    Mapping whose elasticsearch round-trips are coroutines.
    """

    @classmethod
    async def from_es(cls, index, using="default"):
        """
        Create a mapping and populate it from the mapping stored in
        elasticsearch for ``index``.

        :arg index: name of the index to read the mapping from
        :arg using: connection alias to use, defaults to ``'default'``
        """
        mapping = cls()
        await mapping.update_from_es(index, using)
        return mapping

    async def update_from_es(self, index, using="default"):
        """
        Fetch the mapping of ``index`` and merge it into this instance.

        :arg index: name of the index to read the mapping from
        :arg using: connection alias to use, defaults to ``'default'``
        """
        client = get_connection(using)
        response = await client.indices.get_mapping(index=index)
        # the response is keyed by index name; take one entry from it
        _, index_info = response.popitem()
        self._update_from_dict(index_info["mappings"])

    async def save(self, index, using="default"):
        """
        Attach this mapping to an ``AsyncIndex`` named ``index`` and save
        that index.

        :arg index: name of the index to save the mapping to
        :arg using: connection alias to use, defaults to ``'default'``
        """
        # imported here rather than at module level
        from .index import AsyncIndex

        target = AsyncIndex(index, using=using)
        target.mapping(self)
        return await target.save()
class AsyncSearch(SearchBase):
    def __aiter__(self):
        """
        Iterate over the hits.

        The search is executed lazily, on the first item requested from the
        returned iterator.
        """

        class ResultsIterator:
            def __init__(self, search):
                self.search = search
                self.iterator = None

            def __aiter__(self):
                # asynchronous iterators return themselves from __aiter__ so
                # the iterator can itself be used in ``async for``
                return self

            async def __anext__(self):
                if self.iterator is None:
                    self.iterator = iter(await self.search.execute())
                try:
                    return next(self.iterator)
                except StopIteration:
                    raise StopAsyncIteration()

        return ResultsIterator(self)

    async def count(self):
        """
        Return the number of hits matching the query and filters. Note that
        only the actual number is returned.

        If a response is already cached and its total is exact (relation
        ``eq``), that value is returned without another request.
        """
        if hasattr(self, "_response") and self._response.hits.total.relation == "eq":
            return self._response.hits.total.value

        es = get_connection(self._using)

        d = self.to_dict(count=True)
        # TODO: failed shards detection
        resp = await es.count(
            index=self._index, query=d.get("query", None), **self._params
        )
        return resp["count"]

    async def execute(self, ignore_cache=False):
        """
        Execute the search and return an instance of ``Response`` wrapping all
        the data.

        :arg ignore_cache: if set to ``True``, consecutive calls will hit
            ES, while cached result will be ignored. Defaults to `False`
        """
        if ignore_cache or not hasattr(self, "_response"):
            es = get_connection(self._using)

            self._response = self._response_class(
                self,
                (
                    await es.search(
                        index=self._index, body=self.to_dict(), **self._params
                    )
                ).body,
            )
        return self._response

    async def scan(self):
        """
        Turn the search into a scan search and return an asynchronous
        generator that will iterate over all the documents matching the query.

        Use ``params`` method to specify any additional arguments you wish to
        pass to the underlying ``async_scan`` helper from ``elasticsearch-py`` -
        https://elasticsearch-py.readthedocs.io/en/master/helpers.html#elasticsearch.helpers.scan

        """
        es = get_connection(self._using)

        async for hit in async_scan(
            es, query=self.to_dict(), index=self._index, **self._params
        ):
            yield self._get_result(hit)

    async def delete(self):
        """
        delete() executes the query by delegating to delete_by_query()
        """

        es = get_connection(self._using)

        return AttrDict(
            await es.delete_by_query(
                index=self._index, body=self.to_dict(), **self._params
            )
        )


class AsyncMultiSearch(MultiSearchBase):
    """
    Combine multiple :class:`~elasticsearch_dsl.AsyncSearch` objects into a
    single request.
    """

    async def execute(self, ignore_cache=False, raise_on_error=True):
        """
        Execute the multi search request and return a list of search results.

        :arg ignore_cache: if set to ``True`` the request is sent again even
            when a previous result is cached
        :arg raise_on_error: if ``True`` an ``ApiError`` is raised for the
            first search that failed; otherwise the failed search is
            represented by ``None`` in the returned list
        """
        if ignore_cache or not hasattr(self, "_response"):
            es = get_connection(self._using)

            responses = await es.msearch(
                index=self._index, body=self.to_dict(), **self._params
            )

            out = []
            for s, r in zip(self._searches, responses["responses"]):
                if r.get("error", False):
                    if raise_on_error:
                        raise ApiError("N/A", meta=responses.meta, body=r)
                    r = None
                else:
                    r = Response(s, r)
                out.append(r)

            self._response = out

        return self._response
class AsyncUpdateByQuery(UpdateByQueryBase):
    """
    Update-by-query request whose ``execute`` is a coroutine.
    """

    async def execute(self):
        """
        Send the update-by-query request, wrap the result in the configured
        response class, store it on the instance and return it.
        """
        client = get_connection(self._using)
        # the serialized request and the extra params are both passed as
        # keyword arguments
        raw_result = await client.update_by_query(
            index=self._index, **self.to_dict(), **self._params
        )
        self._response = self._response_class(self, raw_result)
        return self._response
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/elasticsearch_dsl/_sync/document.py b/elasticsearch_dsl/_sync/document.py new file mode 100644 index 000000000..05df47536 --- /dev/null +++ b/elasticsearch_dsl/_sync/document.py @@ -0,0 +1,372 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
import collections.abc

from elasticsearch.exceptions import NotFoundError, RequestError

from .._sync.index import Index
from ..connections import get_connection
from ..document_base import DocumentBase, DocumentMeta
from ..exceptions import IllegalOperation
from ..utils import DOC_META_FIELDS, META_FIELDS, merge
from .search import Search


class IndexMeta(DocumentMeta):
    """
    Metaclass that attaches an ``Index`` object (as ``_index``) to every
    class created after the first one, and registers the new class with
    that index.
    """

    # global flag to guard us from associating an Index with the base Document
    # class, only user defined subclasses should have an _index attr
    _document_initialized = False

    def __new__(cls, name, bases, attrs):
        """
        Create the class; for every class after the first, build its index
        from the optional inner ``Index`` options class and register the new
        class on that index.
        """
        new_cls = super().__new__(cls, name, bases, attrs)
        if cls._document_initialized:
            index_opts = attrs.pop("Index", None)
            index = cls.construct_index(index_opts, bases)
            new_cls._index = index
            index.document(new_cls)
        # the first class to pass through here only flips the flag
        cls._document_initialized = True
        return new_cls

    @classmethod
    def construct_index(cls, opts, bases):
        """
        Return the ``Index`` to associate with a new document class.

        Without ``opts`` the ``_index`` of the first base that has one is
        reused, falling back to an ``Index`` with no name. With ``opts`` a
        new ``Index`` is built from its ``name`` (default ``"*"``), ``using``
        (default ``"default"``), ``settings``, ``aliases`` and ``analyzers``
        attributes.
        """
        if opts is None:
            for b in bases:
                if hasattr(b, "_index"):
                    return b._index

            # Set None as Index name so it will set _all while making the query
            return Index(name=None)

        i = Index(getattr(opts, "name", "*"), using=getattr(opts, "using", "default"))
        i.settings(**getattr(opts, "settings", {}))
        i.aliases(**getattr(opts, "aliases", {}))
        for a in getattr(opts, "analyzers", ()):
            i.analyzer(a)
        return i


class Document(DocumentBase, metaclass=IndexMeta):
    """
    Model-like class for persisting documents in elasticsearch.
    """

    @classmethod
    def _get_connection(cls, using=None):
        """
        Return the connection for the alias resolved by ``_get_using(using)``.
        """
        return get_connection(cls._get_using(using))

    @classmethod
    def init(cls, index=None, using=None):
        """
        Create the index and populate the mappings in elasticsearch.

        :arg index: if given, a clone of the associated index with this name
            is saved instead of the associated index itself
        :arg using: connection alias passed to ``Index.save``
        """
        i = cls._index
        if index:
            i = i.clone(name=index)
        i.save(using=using)

    @classmethod
    def search(cls, using=None, index=None):
        """
        Create an :class:`~elasticsearch_dsl.Search` instance that will search
        over this ``Document``.
        """
        return Search(
            using=cls._get_using(using), index=cls._default_index(index), doc_type=[cls]
        )

    @classmethod
    def get(cls, id, using=None, index=None, **kwargs):
        """
        Retrieve a single document from elasticsearch using its ``id``.

        :arg id: ``id`` of the document to be retrieved
        :arg index: elasticsearch index to use, if the ``Document`` is
            associated with an index this can be omitted.
        :arg using: connection alias to use, defaults to ``'default'``

        Any additional keyword arguments will be passed to
        ``Elasticsearch.get`` unchanged.

        :return: the document built with ``from_es``, or ``None`` when the
            response is not marked as ``found``
        """
        es = cls._get_connection(using)
        doc = es.get(index=cls._default_index(index), id=id, **kwargs)
        if not doc.get("found", False):
            return None
        return cls.from_es(doc)

    @classmethod
    def exists(cls, id, using=None, index=None, **kwargs):
        """
        Check whether a single document exists in elasticsearch using its
        ``id``.

        :arg id: ``id`` of the document to check if exists
        :arg index: elasticsearch index to use, if the ``Document`` is
            associated with an index this can be omitted.
        :arg using: connection alias to use, defaults to ``'default'``

        Any additional keyword arguments will be passed to
        ``Elasticsearch.exists`` unchanged.
        """
        es = cls._get_connection(using)
        return es.exists(index=cls._default_index(index), id=id, **kwargs)

    @classmethod
    def mget(
        cls, docs, using=None, index=None, raise_on_error=True, missing="none", **kwargs
    ):
        r"""
        Retrieve multiple document by their ``id``\s. Returns a list of instances
        in the same order as requested.

        :arg docs: list of ``id``\s of the documents to be retrieved or a list
            of document specifications as per
            https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-multi-get.html
        :arg index: elasticsearch index to use, if the ``Document`` is
            associated with an index this can be omitted.
        :arg using: connection alias to use, defaults to ``'default'``
        :arg raise_on_error: if ``True`` (default) a ``RequestError`` is
            raised when any returned entry carries an ``error``. If ``False``
            such entries become ``None`` when ``missing`` is ``'none'`` and
            are left out otherwise.
        :arg missing: what to do when one of the documents requested is not
            found. Valid options are ``'none'`` (use ``None``), ``'raise'`` (raise
            ``NotFoundError``) or ``'skip'`` (ignore the missing document).

        Any additional keyword arguments will be passed to
        ``Elasticsearch.mget`` unchanged.

        :raises ValueError: if ``missing`` is not one of the valid options
        """
        if missing not in ("raise", "skip", "none"):
            raise ValueError("'missing' must be 'raise', 'skip', or 'none'.")
        es = cls._get_connection(using)
        # plain ids are wrapped into the ``{"_id": ...}`` form, mappings are
        # sent as given
        body = {
            "docs": [
                doc if isinstance(doc, collections.abc.Mapping) else {"_id": doc}
                for doc in docs
            ]
        }
        results = es.mget(index=cls._default_index(index), body=body, **kwargs)

        objs, error_docs, missing_docs = [], [], []
        for doc in results["docs"]:
            if doc.get("found"):
                if error_docs or missing_docs:
                    # We're going to raise an exception anyway, so avoid an
                    # expensive call to cls.from_es().
                    continue

                objs.append(cls.from_es(doc))

            elif doc.get("error"):
                if raise_on_error:
                    error_docs.append(doc)
                if missing == "none":
                    objs.append(None)

            # The doc didn't cause an error, but the doc also wasn't found.
            elif missing == "raise":
                missing_docs.append(doc)
            elif missing == "none":
                objs.append(None)

        # errors take precedence over missing documents
        if error_docs:
            error_ids = [doc["_id"] for doc in error_docs]
            message = "Required routing not provided for documents %s."
            message %= ", ".join(error_ids)
            raise RequestError(400, message, error_docs)
        if missing_docs:
            missing_ids = [doc["_id"] for doc in missing_docs]
            message = f"Documents {', '.join(missing_ids)} not found."
            raise NotFoundError(404, message, {"docs": missing_docs})
        return objs

    def delete(self, using=None, index=None, **kwargs):
        """
        Delete the instance in elasticsearch.

        :arg index: elasticsearch index to use, if the ``Document`` is
            associated with an index this can be omitted.
        :arg using: connection alias to use, defaults to ``'default'``

        Any additional keyword arguments will be passed to
        ``Elasticsearch.delete`` unchanged. They are applied last, so they
        override values taken from ``self.meta``.
        """
        es = self._get_connection(using)
        # extract routing etc from meta
        doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta}

        # Optimistic concurrency control
        if "seq_no" in self.meta and "primary_term" in self.meta:
            doc_meta["if_seq_no"] = self.meta["seq_no"]
            doc_meta["if_primary_term"] = self.meta["primary_term"]

        doc_meta.update(kwargs)
        es.delete(index=self._get_index(index), **doc_meta)

    def update(
        self,
        using=None,
        index=None,
        detect_noop=True,
        doc_as_upsert=False,
        refresh=False,
        retry_on_conflict=None,
        script=None,
        script_id=None,
        scripted_upsert=False,
        upsert=None,
        return_doc_meta=False,
        **fields,
    ):
        """
        Partial update of the document, specify fields you wish to update and
        both the instance and the document in elasticsearch will be updated::

            doc = MyDocument(title='Document Title!')
            doc.save()
            doc.update(title='New Document Title!')

        When ``script`` or ``script_id`` is given the update is scripted
        instead: ``fields`` are sent as the script ``params`` and the local
        instance's fields are not modified.

        :arg index: elasticsearch index to use, if the ``Document`` is
            associated with an index this can be omitted.
        :arg using: connection alias to use, defaults to ``'default'``
        :arg detect_noop: Set to ``False`` to disable noop detection.
        :arg refresh: Control when the changes made by this request are visible
            to search. Set to ``True`` for immediate effect.
        :arg retry_on_conflict: In between the get and indexing phases of the
            update, it is possible that another process might have already
            updated the same document. By default, the update will fail with a
            version conflict exception. The retry_on_conflict parameter
            controls how many times to retry the update before finally throwing
            an exception.
        :arg doc_as_upsert: Instead of sending a partial doc plus an upsert
            doc, setting doc_as_upsert to true will use the contents of doc as
            the upsert value
        :arg script: source of the script to run; takes precedence over
            ``script_id``
        :arg script_id: id of a stored script to run
        :arg scripted_upsert: sent as ``scripted_upsert`` in the request body
            of a scripted update
        :arg upsert: sent as ``upsert`` in the request body of a scripted
            update when not ``None``
        :arg return_doc_meta: set to ``True`` to return all metadata from the
            update API call instead of only the operation result

        :raises IllegalOperation: if neither fields nor a script are given

        :return operation result noop/updated
        """
        body = {
            "doc_as_upsert": doc_as_upsert,
            "detect_noop": detect_noop,
        }

        # scripted update
        if script or script_id:
            if upsert is not None:
                body["upsert"] = upsert

            if script:
                script = {"source": script}
            else:
                script = {"id": script_id}

            script["params"] = fields

            body["script"] = script
            body["scripted_upsert"] = scripted_upsert

        # partial document update
        else:
            if not fields:
                raise IllegalOperation(
                    "You cannot call update() without updating individual fields or a script. "
                    "If you wish to update the entire object use save()."
                )

            # update given fields locally
            merge(self, fields)

            # prepare data for ES
            values = self.to_dict()

            # if fields were given: partial update
            body["doc"] = {k: values.get(k) for k in fields.keys()}

        # extract routing etc from meta
        doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta}

        if retry_on_conflict is not None:
            doc_meta["retry_on_conflict"] = retry_on_conflict

        # Optimistic concurrency control
        # (only sent when no retries were requested)
        if (
            retry_on_conflict in (None, 0)
            and "seq_no" in self.meta
            and "primary_term" in self.meta
        ):
            doc_meta["if_seq_no"] = self.meta["seq_no"]
            doc_meta["if_primary_term"] = self.meta["primary_term"]

        meta = self._get_connection(using).update(
            index=self._get_index(index), body=body, refresh=refresh, **doc_meta
        )
        # update meta information from ES
        for k in META_FIELDS:
            if "_" + k in meta:
                setattr(self.meta, k, meta["_" + k])

        return meta if return_doc_meta else meta["result"]

    def save(
        self,
        using=None,
        index=None,
        validate=True,
        skip_empty=True,
        return_doc_meta=False,
        **kwargs,
    ):
        """
        Save the document into elasticsearch. If the document doesn't exist it
        is created, it is overwritten otherwise. Returns the ``result`` value
        of the index API response, or the whole response when
        ``return_doc_meta`` is set.

        :arg index: elasticsearch index to use, if the ``Document`` is
            associated with an index this can be omitted.
        :arg using: connection alias to use, defaults to ``'default'``
        :arg validate: set to ``False`` to skip validating the document
        :arg skip_empty: if set to ``False`` will cause empty values (``None``,
            ``[]``, ``{}``) to be left on the document. Those values will be
            stripped out otherwise as they make no difference in elasticsearch.
        :arg return_doc_meta: set to ``True`` to return all metadata from the
            index API call instead of only the operation result

        Any additional keyword arguments will be passed to
        ``Elasticsearch.index`` unchanged. They are applied last, so they
        override values taken from ``self.meta``.

        :return operation result created/updated
        """
        if validate:
            self.full_clean()

        es = self._get_connection(using)
        # extract routing etc from meta
        doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta}

        # Optimistic concurrency control
        if "seq_no" in self.meta and "primary_term" in self.meta:
            doc_meta["if_seq_no"] = self.meta["seq_no"]
            doc_meta["if_primary_term"] = self.meta["primary_term"]

        doc_meta.update(kwargs)
        meta = es.index(
            index=self._get_index(index),
            body=self.to_dict(skip_empty=skip_empty),
            **doc_meta,
        )
        # update meta information from ES
        for k in META_FIELDS:
            if "_" + k in meta:
                setattr(self.meta, k, meta["_" + k])

        return meta if return_doc_meta else meta["result"]
from elasticsearch_dsl.faceted_search_base import FacetedResponse, FacetedSearchBase

from .search import Search


class FacetedSearch(FacetedSearchBase):
    def search(self):
        """
        Build the base Search object that the facets get attached to.

        Override this method and return a modified search object to
        customize the query.
        """
        base_search = Search(
            doc_type=self.doc_types, index=self.index, using=self.using
        )
        return base_search.response_class(FacetedResponse)

    def execute(self):
        """
        Run the underlying search and return its response, with this faceted
        search stored on it as ``_faceted_search``.
        """
        response = self._s.execute()
        response._faceted_search = self
        return response
from ..connections import get_connection
from ..exceptions import IllegalOperation
from ..index_base import IndexBase
from .mapping import Mapping
from .search import Search
from .update_by_query import UpdateByQuery


class IndexTemplate:
    """
    Index template definition backed by an ``Index`` object whose name is
    used as the template's index pattern.
    """

    def __init__(self, name, template, index=None, order=None, **kwargs):
        """
        :arg name: name under which the template is saved
        :arg template: index pattern; becomes the name of the wrapped index
        :arg index: optional ``Index`` to clone; when omitted a new ``Index``
            is created from ``template`` and ``kwargs``
        :arg order: optional value emitted as ``order`` by ``to_dict``

        :raises ValueError: if both ``index`` and extra ``kwargs`` are given
        """
        if index is None:
            self._index = Index(template, **kwargs)
        else:
            if kwargs:
                raise ValueError(
                    "You cannot specify options for Index when"
                    " passing an Index instance."
                )
            # work on a copy so the caller's index keeps its own name
            self._index = index.clone()
            self._index._name = template
        self._template_name = name
        self.order = order

    def __getattr__(self, attr_name):
        # anything not defined here is looked up on the wrapped index
        return getattr(self._index, attr_name)

    def to_dict(self):
        """
        Return the wrapped index's ``to_dict()`` extended with
        ``index_patterns`` and, when set, ``order``.
        """
        d = self._index.to_dict()
        d["index_patterns"] = [self._index._name]
        if self.order is not None:
            d["order"] = self.order
        return d

    def save(self, using=None):
        """
        Send the template to elasticsearch via ``indices.put_template``.

        :arg using: connection alias to use, defaults to the alias of the
            wrapped index
        """
        es = get_connection(using or self._index._using)
        return es.indices.put_template(name=self._template_name, body=self.to_dict())


class Index(IndexBase):
    def __init__(self, name, using="default"):
        """
        :arg name: name of the index
        :arg using: connection alias to use, defaults to ``'default'``
        """
        super().__init__(name, Mapping, using=using)

    def _get_connection(self, using=None):
        """
        Return the connection for ``using``, falling back to this index's
        own alias.

        :raises ValueError: if the index has no name
        """
        if self._name is None:
            raise ValueError("You cannot perform API calls on the default index.")
        return get_connection(using or self._using)

    # attribute-style access to the connection for this index's own alias
    connection = property(_get_connection)

    def as_template(self, template_name, pattern=None, order=None):
        """
        Return an ``IndexTemplate`` named ``template_name`` based on this
        index, using ``pattern`` (or this index's name when omitted) as the
        index pattern.
        """
        # TODO: should we allow pattern to be a top-level arg?
        # or maybe have an IndexPattern that allows for it and have
        # Document._index be that?
        return IndexTemplate(
            template_name, pattern or self._name, index=self, order=order
        )

    def load_mappings(self, using=None):
        """
        Update this index's mapping (creating it if needed) from the mapping
        stored in elasticsearch for this index name.
        """
        self.get_or_create_mapping().update_from_es(
            self._name, using=using or self._using
        )

    def clone(self, name=None, using=None):
        """
        Create a copy of the instance with another name or connection alias.
        Useful for creating multiple indices with shared configuration::

            i = Index('base-index')
            i.settings(number_of_shards=1)
            i.create()

            i2 = i.clone('other-index')
            i2.create()

        :arg name: name of the index, defaults to the name of this instance
        :arg using: connection alias to use, defaults to the alias of this
            instance
        """
        i = Index(name or self._name, using=using or self._using)
        i._settings = self._settings.copy()
        i._aliases = self._aliases.copy()
        i._analysis = self._analysis.copy()
        i._doc_types = self._doc_types[:]
        if self._mapping is not None:
            i._mapping = self._mapping._clone()
        return i

    def search(self, using=None):
        """
        Return a :class:`~elasticsearch_dsl.Search` object searching over the
        index (or all the indices belonging to this template) and its
        ``Document``\\s.
        """
        return Search(
            using=using or self._using, index=self._name, doc_type=self._doc_types
        )

    def updateByQuery(self, using=None):
        """
        Return a :class:`~elasticsearch_dsl.UpdateByQuery` object searching over the index
        (or all the indices belonging to this template) and updating Documents that match
        the search criteria.

        For more information, see here:
        https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-update-by-query.html
        """
        return UpdateByQuery(
            using=using or self._using,
            index=self._name,
        )

    def create(self, using=None, **kwargs):
        """
        Creates the index in elasticsearch.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.create`` unchanged.
        """
        return self._get_connection(using).indices.create(
            index=self._name, body=self.to_dict(), **kwargs
        )

    def is_closed(self, using=None):
        """
        Return ``True`` if the cluster state metadata reports this index's
        state as ``"close"``.
        """
        state = self._get_connection(using).cluster.state(
            index=self._name, metric="metadata"
        )
        return state["metadata"]["indices"][self._name]["state"] == "close"

    def save(self, using=None):
        """
        Sync the index definition with elasticsearch, creating the index if it
        doesn't exist and updating its settings and mappings if it does.

        Note some settings and mapping changes cannot be done on an open
        index (or at all on an existing index) and for those this method will
        fail with the underlying exception.

        :raises IllegalOperation: if the analysis configuration differs from
            the existing one while the index is open
        """
        if not self.exists(using=using):
            return self.create(using=using)

        body = self.to_dict()
        settings = body.pop("settings", {})
        analysis = settings.pop("analysis", None)
        current_settings = (self.get_settings(using=using))[self._name]["settings"][
            "index"
        ]
        if analysis:
            if self.is_closed(using=using):
                # closed index, update away
                settings["analysis"] = analysis
            else:
                # compare analysis definition, if all analysis objects are
                # already defined as requested, skip analysis update and
                # proceed, otherwise raise IllegalOperation
                existing_analysis = current_settings.get("analysis", {})
                if any(
                    existing_analysis.get(section, {}).get(k, None)
                    != analysis[section][k]
                    for section in analysis
                    for k in analysis[section]
                ):
                    raise IllegalOperation(
                        "You cannot update analysis configuration on an open index, "
                        "you need to close index %s first." % self._name
                    )

        # try and update the settings
        if settings:
            settings = settings.copy()
            # drop settings whose current value already equals the requested
            # one (compared as strings)
            for k, v in list(settings.items()):
                if k in current_settings and current_settings[k] == str(v):
                    del settings[k]

            if settings:
                self.put_settings(using=using, body=settings)

        # update the mappings, any conflict in the mappings will result in an
        # exception
        mappings = body.pop("mappings", {})
        if mappings:
            self.put_mapping(using=using, body=mappings)

    def analyze(self, using=None, **kwargs):
        """
        Perform the analysis process on a text and return the tokens breakdown
        of the text.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.analyze`` unchanged.
        """
        return self._get_connection(using).indices.analyze(index=self._name, **kwargs)

    def refresh(self, using=None, **kwargs):
        """
        Performs a refresh operation on the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.refresh`` unchanged.
        """
        return self._get_connection(using).indices.refresh(index=self._name, **kwargs)

    def flush(self, using=None, **kwargs):
        """
        Performs a flush operation on the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.flush`` unchanged.
        """
        return self._get_connection(using).indices.flush(index=self._name, **kwargs)

    def get(self, using=None, **kwargs):
        """
        The get index API allows to retrieve information about the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get`` unchanged.
        """
        return self._get_connection(using).indices.get(index=self._name, **kwargs)

    def open(self, using=None, **kwargs):
        """
        Opens the index in elasticsearch.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.open`` unchanged.
        """
        return self._get_connection(using).indices.open(index=self._name, **kwargs)

    def close(self, using=None, **kwargs):
        """
        Closes the index in elasticsearch.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.close`` unchanged.
        """
        return self._get_connection(using).indices.close(index=self._name, **kwargs)

    def delete(self, using=None, **kwargs):
        """
        Deletes the index in elasticsearch.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.delete`` unchanged.
        """
        return self._get_connection(using).indices.delete(index=self._name, **kwargs)

    def exists(self, using=None, **kwargs):
        """
        Returns ``True`` if the index already exists in elasticsearch.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.exists`` unchanged.
        """
        return self._get_connection(using).indices.exists(index=self._name, **kwargs)

    def exists_type(self, using=None, **kwargs):
        """
        Check if a type/types exists in the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.exists_type`` unchanged.
        """
        return self._get_connection(using).indices.exists_type(
            index=self._name, **kwargs
        )

    def put_mapping(self, using=None, **kwargs):
        """
        Register specific mapping definition for a specific type.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.put_mapping`` unchanged.
        """
        return self._get_connection(using).indices.put_mapping(
            index=self._name, **kwargs
        )

    def get_mapping(self, using=None, **kwargs):
        """
        Retrieve specific mapping definition for a specific type.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get_mapping`` unchanged.
        """
        return self._get_connection(using).indices.get_mapping(
            index=self._name, **kwargs
        )

    def get_field_mapping(self, using=None, **kwargs):
        """
        Retrieve mapping definition of a specific field.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get_field_mapping`` unchanged.
        """
        return self._get_connection(using).indices.get_field_mapping(
            index=self._name, **kwargs
        )

    def put_alias(self, using=None, **kwargs):
        """
        Create an alias for the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.put_alias`` unchanged.
        """
        return self._get_connection(using).indices.put_alias(index=self._name, **kwargs)

    def exists_alias(self, using=None, **kwargs):
        """
        Return a boolean indicating whether given alias exists for this index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.exists_alias`` unchanged.
        """
        return self._get_connection(using).indices.exists_alias(
            index=self._name, **kwargs
        )

    def get_alias(self, using=None, **kwargs):
        """
        Retrieve a specified alias.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get_alias`` unchanged.
        """
        return self._get_connection(using).indices.get_alias(index=self._name, **kwargs)

    def delete_alias(self, using=None, **kwargs):
        """
        Delete specific alias.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.delete_alias`` unchanged.
        """
        return self._get_connection(using).indices.delete_alias(
            index=self._name, **kwargs
        )

    def get_settings(self, using=None, **kwargs):
        """
        Retrieve settings for the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get_settings`` unchanged.
        """
        return self._get_connection(using).indices.get_settings(
            index=self._name, **kwargs
        )

    def put_settings(self, using=None, **kwargs):
        """
        Change specific index level settings in real time.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.put_settings`` unchanged.
        """
        return self._get_connection(using).indices.put_settings(
            index=self._name, **kwargs
        )

    def stats(self, using=None, **kwargs):
        """
        Retrieve statistics on different operations happening on the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.stats`` unchanged.
        """
        return self._get_connection(using).indices.stats(index=self._name, **kwargs)

    def segments(self, using=None, **kwargs):
        """
        Provide low level segments information that a Lucene index (shard
        level) is built with.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.segments`` unchanged.
        """
        return self._get_connection(using).indices.segments(index=self._name, **kwargs)

    def validate_query(self, using=None, **kwargs):
        """
        Validate a potentially expensive query without executing it.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.validate_query`` unchanged.
        """
        return self._get_connection(using).indices.validate_query(
            index=self._name, **kwargs
        )

    def clear_cache(self, using=None, **kwargs):
        """
        Clear all caches or specific caches associated with the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.clear_cache`` unchanged.
        """
        return self._get_connection(using).indices.clear_cache(
            index=self._name, **kwargs
        )

    def recovery(self, using=None, **kwargs):
        """
        The indices recovery API provides insight into on-going shard
        recoveries for the index.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.recovery`` unchanged.
        """
        return self._get_connection(using).indices.recovery(index=self._name, **kwargs)

    def upgrade(self, using=None, **kwargs):
        """
        Upgrade the index to the latest format.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.upgrade`` unchanged.
        """
        return self._get_connection(using).indices.upgrade(index=self._name, **kwargs)

    def get_upgrade(self, using=None, **kwargs):
        """
        Monitor how much of the index is upgraded.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.get_upgrade`` unchanged.
        """
        return self._get_connection(using).indices.get_upgrade(
            index=self._name, **kwargs
        )

    def flush_synced(self, using=None, **kwargs):
        """
        Perform a normal flush, then add a generated unique marker (sync_id) to
        all shards.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.flush_synced`` unchanged.
        """
        return self._get_connection(using).indices.flush_synced(
            index=self._name, **kwargs
        )

    def shard_stores(self, using=None, **kwargs):
        """
        Provides store information for shard copies of the index. Store
        information reports on which nodes shard copies exist, the shard copy
        version, indicating how recent they are, and any exceptions encountered
        while opening the shard index or from earlier engine failure.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.shard_stores`` unchanged.
        """
        return self._get_connection(using).indices.shard_stores(
            index=self._name, **kwargs
        )

    def forcemerge(self, using=None, **kwargs):
        """
        The force merge API allows to force merging of the index through an
        API. The merge relates to the number of segments a Lucene index holds
        within each shard. The force merge operation allows to reduce the
        number of segments by merging them.

        This call will block until the merge is complete. If the http
        connection is lost, the request will continue in the background, and
        any new requests will block until the previous force merge is complete.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.forcemerge`` unchanged.
        """
        return self._get_connection(using).indices.forcemerge(
            index=self._name, **kwargs
        )

    def shrink(self, using=None, **kwargs):
        """
        The shrink index API allows you to shrink an existing index into a new
        index with fewer primary shards. The number of primary shards in the
        target index must be a factor of the shards in the source index. For
        example an index with 8 primary shards can be shrunk into 4, 2 or 1
        primary shards or an index with 15 primary shards can be shrunk into 5,
        3 or 1. If the number of shards in the index is a prime number it can
        only be shrunk into a single primary shard. Before shrinking, a
        (primary or replica) copy of every shard in the index must be present
        on the same node.

        Any additional keyword arguments will be passed to
        ``Elasticsearch.indices.shrink`` unchanged.
        """
        return self._get_connection(using).indices.shrink(index=self._name, **kwargs)
+ +from ..connections import get_connection +from ..mapping_base import MappingBase + + +class Mapping(MappingBase): + @classmethod + def from_es(cls, index, using="default"): + m = cls() + m.update_from_es(index, using) + return m + + def update_from_es(self, index, using="default"): + es = get_connection(using) + raw = es.indices.get_mapping(index=index) + _, raw = raw.popitem() + self._update_from_dict(raw["mappings"]) + + def save(self, index, using="default"): + from .index import Index + + index = Index(index, using=using) + index.mapping(self) + return index.save() diff --git a/elasticsearch_dsl/_sync/search.py b/elasticsearch_dsl/_sync/search.py new file mode 100644 index 000000000..a15a08aa0 --- /dev/null +++ b/elasticsearch_dsl/_sync/search.py @@ -0,0 +1,138 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from elasticsearch.exceptions import ApiError +from elasticsearch.helpers import scan + +from ..connections import get_connection +from ..response import Response +from ..search_base import MultiSearchBase, SearchBase +from ..utils import AttrDict + + +class Search(SearchBase): + def __iter__(self): + """ + Iterate over the hits. 
+ """ + + class ResultsIterator: + def __init__(self, search): + self.search = search + self.iterator = None + + def __next__(self): + if self.iterator is None: + self.iterator = iter(self.search.execute()) + try: + return next(self.iterator) + except StopIteration: + raise StopIteration() + + return ResultsIterator(self) + + def count(self): + """ + Return the number of hits matching the query and filters. Note that + only the actual number is returned. + """ + if hasattr(self, "_response") and self._response.hits.total.relation == "eq": + return self._response.hits.total.value + + es = get_connection(self._using) + + d = self.to_dict(count=True) + # TODO: failed shards detection + resp = es.count(index=self._index, query=d.get("query", None), **self._params) + return resp["count"] + + def execute(self, ignore_cache=False): + """ + Execute the search and return an instance of ``Response`` wrapping all + the data. + + :arg ignore_cache: if set to ``True``, consecutive calls will hit + ES, while cached result will be ignored. Defaults to `False` + """ + if ignore_cache or not hasattr(self, "_response"): + es = get_connection(self._using) + + self._response = self._response_class( + self, + ( + es.search(index=self._index, body=self.to_dict(), **self._params) + ).body, + ) + return self._response + + def scan(self): + """ + Turn the search into a scan search and return a generator that will + iterate over all the documents matching the query. 
+ + Use ``params`` method to specify any additional arguments you wish to + pass to the underlying ``scan`` helper from ``elasticsearch-py`` - + https://elasticsearch-py.readthedocs.io/en/master/helpers.html#elasticsearch.helpers.scan + + """ + es = get_connection(self._using) + + for hit in scan(es, query=self.to_dict(), index=self._index, **self._params): + yield self._get_result(hit) + + def delete(self): + """ + delete() executes the query by delegating to delete_by_query() + """ + + es = get_connection(self._using) + + return AttrDict( + es.delete_by_query(index=self._index, body=self.to_dict(), **self._params) + ) + + +class MultiSearch(MultiSearchBase): + """ + Combine multiple :class:`~elasticsearch_dsl.Search` objects into a single + request. + """ + + def execute(self, ignore_cache=False, raise_on_error=True): + """ + Execute the multi search request and return a list of search results. + """ + if ignore_cache or not hasattr(self, "_response"): + es = get_connection(self._using) + + responses = es.msearch( + index=self._index, body=self.to_dict(), **self._params + ) + + out = [] + for s, r in zip(self._searches, responses["responses"]): + if r.get("error", False): + if raise_on_error: + raise ApiError("N/A", meta=responses.meta, body=r) + r = None + else: + r = Response(s, r) + out.append(r) + + self._response = out + + return self._response diff --git a/elasticsearch_dsl/_sync/update_by_query.py b/elasticsearch_dsl/_sync/update_by_query.py new file mode 100644 index 000000000..9e0b4780c --- /dev/null +++ b/elasticsearch_dsl/_sync/update_by_query.py @@ -0,0 +1,34 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from ..connections import get_connection +from ..update_by_query_base import UpdateByQueryBase + + +class UpdateByQuery(UpdateByQueryBase): + def execute(self): + """ + Execute the search and return an instance of ``Response`` wrapping all + the data. + """ + es = get_connection(self._using) + + self._response = self._response_class( + self, + es.update_by_query(index=self._index, **self.to_dict(), **self._params), + ) + return self._response diff --git a/elasticsearch_dsl/analysis.py b/elasticsearch_dsl/analysis.py index 0d7a1a425..96321337f 100644 --- a/elasticsearch_dsl/analysis.py +++ b/elasticsearch_dsl/analysis.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from .connections import get_connection +from . import async_connections, connections from .utils import AttrDict, DslBase, merge __all__ = ["tokenizer", "analyzer", "char_filter", "token_filter", "normalizer"] @@ -119,20 +119,7 @@ class CustomAnalyzer(CustomAnalysisDefinition, Analyzer): "tokenizer": {"type": "tokenizer"}, } - def simulate(self, text, using="default", explain=False, attributes=None): - """ - Use the Analyze API of elasticsearch to test the outcome of this analyzer. - - :arg text: Text to be analyzed - :arg using: connection alias to use, defaults to ``'default'`` - :arg explain: will output all token attributes for each token. You can - filter token attributes you want to output by setting ``attributes`` - option. - :arg attributes: if ``explain`` is specified, filter the token - attributes to return. 
- """ - es = get_connection(using) - + def _get_body(self, text, explain, attributes): body = {"text": text, "explain": explain} if attributes: body["attributes"] = attributes @@ -156,7 +143,43 @@ def simulate(self, text, using="default", explain=False, attributes=None): if self._builtin_type != "custom": body["analyzer"] = self._builtin_type - return AttrDict(es.indices.analyze(body=body)) + return body + + def simulate(self, text, using="default", explain=False, attributes=None): + """ + Use the Analyze API of elasticsearch to test the outcome of this analyzer. + + :arg text: Text to be analyzed + :arg using: connection alias to use, defaults to ``'default'`` + :arg explain: will output all token attributes for each token. You can + filter token attributes you want to output by setting ``attributes`` + option. + :arg attributes: if ``explain`` is specified, filter the token + attributes to return. + """ + es = connections.get_connection(using) + return AttrDict( + es.indices.analyze(body=self._get_body(text, explain, attributes)) + ) + + async def async_simulate( + self, text, using="default", explain=False, attributes=None + ): + """ + Use the Analyze API of elasticsearch to test the outcome of this analyzer. + + :arg text: Text to be analyzed + :arg using: connection alias to use, defaults to ``'default'`` + :arg explain: will output all token attributes for each token. You can + filter token attributes you want to output by setting ``attributes`` + option. + :arg attributes: if ``explain`` is specified, filter the token + attributes to return. 
+ """ + es = async_connections.get_connection(using) + return AttrDict( + await es.indices.analyze(body=self._get_body(text, explain, attributes)) + ) class Normalizer(AnalysisBase, DslBase): diff --git a/elasticsearch_dsl/async_connections.py b/elasticsearch_dsl/async_connections.py new file mode 100644 index 000000000..5100679d3 --- /dev/null +++ b/elasticsearch_dsl/async_connections.py @@ -0,0 +1,27 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from elasticsearch import AsyncElasticsearch + +from elasticsearch_dsl.connections import Connections + +connections = Connections(elasticsearch_class=AsyncElasticsearch) +configure = connections.configure +add_connection = connections.add_connection +remove_connection = connections.remove_connection +create_connection = connections.create_connection +get_connection = connections.get_connection diff --git a/elasticsearch_dsl/document.py b/elasticsearch_dsl/document.py index 5553821f6..304313c98 100644 --- a/elasticsearch_dsl/document.py +++ b/elasticsearch_dsl/document.py @@ -15,474 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-import collections.abc -from fnmatch import fnmatch - -from elasticsearch.exceptions import NotFoundError, RequestError - -from .connections import get_connection -from .exceptions import IllegalOperation, ValidationException -from .field import Field -from .index import Index -from .mapping import Mapping -from .search import Search -from .utils import DOC_META_FIELDS, META_FIELDS, ObjectBase, merge - - -class MetaField: - def __init__(self, *args, **kwargs): - self.args, self.kwargs = args, kwargs - - -class DocumentMeta(type): - def __new__(cls, name, bases, attrs): - # DocumentMeta filters attrs in place - attrs["_doc_type"] = DocumentOptions(name, bases, attrs) - return super().__new__(cls, name, bases, attrs) - - -class IndexMeta(DocumentMeta): - # global flag to guard us from associating an Index with the base Document - # class, only user defined subclasses should have an _index attr - _document_initialized = False - - def __new__(cls, name, bases, attrs): - new_cls = super().__new__(cls, name, bases, attrs) - if cls._document_initialized: - index_opts = attrs.pop("Index", None) - index = cls.construct_index(index_opts, bases) - new_cls._index = index - index.document(new_cls) - cls._document_initialized = True - return new_cls - - @classmethod - def construct_index(cls, opts, bases): - if opts is None: - for b in bases: - if hasattr(b, "_index"): - return b._index - - # Set None as Index name so it will set _all while making the query - return Index(name=None) - - i = Index(getattr(opts, "name", "*"), using=getattr(opts, "using", "default")) - i.settings(**getattr(opts, "settings", {})) - i.aliases(**getattr(opts, "aliases", {})) - for a in getattr(opts, "analyzers", ()): - i.analyzer(a) - return i - - -class DocumentOptions: - def __init__(self, name, bases, attrs): - meta = attrs.pop("Meta", None) - - # create the mapping instance - self.mapping = getattr(meta, "mapping", Mapping()) - - # register all declared fields into the mapping - for name, value 
in list(attrs.items()): - if isinstance(value, Field): - self.mapping.field(name, value) - del attrs[name] - - # add all the mappings for meta fields - for name in dir(meta): - if isinstance(getattr(meta, name, None), MetaField): - params = getattr(meta, name) - self.mapping.meta(name, *params.args, **params.kwargs) - - # document inheritance - include the fields from parents' mappings - for b in bases: - if hasattr(b, "_doc_type") and hasattr(b._doc_type, "mapping"): - self.mapping.update(b._doc_type.mapping, update_only=True) - - @property - def name(self): - return self.mapping.properties.name - - -class InnerDoc(ObjectBase, metaclass=DocumentMeta): - """ - Common class for inner documents like Object or Nested - """ - - @classmethod - def from_es(cls, data, data_only=False): - if data_only: - data = {"_source": data} - return super().from_es(data) - - -class Document(ObjectBase, metaclass=IndexMeta): - """ - Model-like class for persisting documents in elasticsearch. - """ - - @classmethod - def _matches(cls, hit): - if cls._index._name is None: - return True - return fnmatch(hit.get("_index", ""), cls._index._name) - - @classmethod - def _get_using(cls, using=None): - return using or cls._index._using - - @classmethod - def _get_connection(cls, using=None): - return get_connection(cls._get_using(using)) - - @classmethod - def _default_index(cls, index=None): - return index or cls._index._name - - @classmethod - def init(cls, index=None, using=None): - """ - Create the index and populate the mappings in elasticsearch. 
- """ - i = cls._index - if index: - i = i.clone(name=index) - i.save(using=using) - - def _get_index(self, index=None, required=True): - if index is None: - index = getattr(self.meta, "index", None) - if index is None: - index = getattr(self._index, "_name", None) - if index is None and required: - raise ValidationException("No index") - if index and "*" in index: - raise ValidationException("You cannot write to a wildcard index.") - return index - - def __repr__(self): - return "{}({})".format( - self.__class__.__name__, - ", ".join( - f"{key}={getattr(self.meta, key)!r}" - for key in ("index", "id") - if key in self.meta - ), - ) - - @classmethod - def search(cls, using=None, index=None): - """ - Create an :class:`~elasticsearch_dsl.Search` instance that will search - over this ``Document``. - """ - return Search( - using=cls._get_using(using), index=cls._default_index(index), doc_type=[cls] - ) - - @classmethod - def get(cls, id, using=None, index=None, **kwargs): - """ - Retrieve a single document from elasticsearch using its ``id``. - - :arg id: ``id`` of the document to be retrieved - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - - Any additional keyword arguments will be passed to - ``Elasticsearch.get`` unchanged. - """ - es = cls._get_connection(using) - doc = es.get(index=cls._default_index(index), id=id, **kwargs) - if not doc.get("found", False): - return None - return cls.from_es(doc) - - @classmethod - def exists(cls, id, using=None, index=None, **kwargs): - """ - check if exists a single document from elasticsearch using its ``id``. - - :arg id: ``id`` of the document to check if exists - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. 
- :arg using: connection alias to use, defaults to ``'default'`` - - Any additional keyword arguments will be passed to - ``Elasticsearch.exists`` unchanged. - """ - es = cls._get_connection(using) - return es.exists(index=cls._default_index(index), id=id, **kwargs) - - @classmethod - def mget( - cls, docs, using=None, index=None, raise_on_error=True, missing="none", **kwargs - ): - r""" - Retrieve multiple document by their ``id``\s. Returns a list of instances - in the same order as requested. - - :arg docs: list of ``id``\s of the documents to be retrieved or a list - of document specifications as per - https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-multi-get.html - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - :arg missing: what to do when one of the documents requested is not - found. Valid options are ``'none'`` (use ``None``), ``'raise'`` (raise - ``NotFoundError``) or ``'skip'`` (ignore the missing document). - - Any additional keyword arguments will be passed to - ``Elasticsearch.mget`` unchanged. - """ - if missing not in ("raise", "skip", "none"): - raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") - es = cls._get_connection(using) - body = { - "docs": [ - doc if isinstance(doc, collections.abc.Mapping) else {"_id": doc} - for doc in docs - ] - } - results = es.mget(index=cls._default_index(index), body=body, **kwargs) - - objs, error_docs, missing_docs = [], [], [] - for doc in results["docs"]: - if doc.get("found"): - if error_docs or missing_docs: - # We're going to raise an exception anyway, so avoid an - # expensive call to cls.from_es(). - continue - - objs.append(cls.from_es(doc)) - - elif doc.get("error"): - if raise_on_error: - error_docs.append(doc) - if missing == "none": - objs.append(None) - - # The doc didn't cause an error, but the doc also wasn't found. 
- elif missing == "raise": - missing_docs.append(doc) - elif missing == "none": - objs.append(None) - - if error_docs: - error_ids = [doc["_id"] for doc in error_docs] - message = "Required routing not provided for documents %s." - message %= ", ".join(error_ids) - raise RequestError(400, message, error_docs) - if missing_docs: - missing_ids = [doc["_id"] for doc in missing_docs] - message = f"Documents {', '.join(missing_ids)} not found." - raise NotFoundError(404, message, {"docs": missing_docs}) - return objs - - def delete(self, using=None, index=None, **kwargs): - """ - Delete the instance in elasticsearch. - - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - - Any additional keyword arguments will be passed to - ``Elasticsearch.delete`` unchanged. - """ - es = self._get_connection(using) - # extract routing etc from meta - doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} - - # Optimistic concurrency control - if "seq_no" in self.meta and "primary_term" in self.meta: - doc_meta["if_seq_no"] = self.meta["seq_no"] - doc_meta["if_primary_term"] = self.meta["primary_term"] - - doc_meta.update(kwargs) - es.delete(index=self._get_index(index), **doc_meta) - - def to_dict(self, include_meta=False, skip_empty=True): - """ - Serialize the instance into a dictionary so that it can be saved in elasticsearch. - - :arg include_meta: if set to ``True`` will include all the metadata - (``_index``, ``_id`` etc). Otherwise just the document's - data is serialized. This is useful when passing multiple instances into - ``elasticsearch.helpers.bulk``. - :arg skip_empty: if set to ``False`` will cause empty values (``None``, - ``[]``, ``{}``) to be left on the document. Those values will be - stripped out otherwise as they make no difference in elasticsearch. 
- """ - d = super().to_dict(skip_empty=skip_empty) - if not include_meta: - return d - - meta = {"_" + k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} - - # in case of to_dict include the index unlike save/update/delete - index = self._get_index(required=False) - if index is not None: - meta["_index"] = index - - meta["_source"] = d - return meta - - def update( - self, - using=None, - index=None, - detect_noop=True, - doc_as_upsert=False, - refresh=False, - retry_on_conflict=None, - script=None, - script_id=None, - scripted_upsert=False, - upsert=None, - return_doc_meta=False, - **fields, - ): - """ - Partial update of the document, specify fields you wish to update and - both the instance and the document in elasticsearch will be updated:: - - doc = MyDocument(title='Document Title!') - doc.save() - doc.update(title='New Document Title!') - - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - :arg detect_noop: Set to ``False`` to disable noop detection. - :arg refresh: Control when the changes made by this request are visible - to search. Set to ``True`` for immediate effect. - :arg retry_on_conflict: In between the get and indexing phases of the - update, it is possible that another process might have already - updated the same document. By default, the update will fail with a - version conflict exception. The retry_on_conflict parameter - controls how many times to retry the update before finally throwing - an exception. 
- :arg doc_as_upsert: Instead of sending a partial doc plus an upsert - doc, setting doc_as_upsert to true will use the contents of doc as - the upsert value - :arg return_doc_meta: set to ``True`` to return all metadata from the - index API call instead of only the operation result - - :return operation result noop/updated - """ - body = { - "doc_as_upsert": doc_as_upsert, - "detect_noop": detect_noop, - } - - # scripted update - if script or script_id: - if upsert is not None: - body["upsert"] = upsert - - if script: - script = {"source": script} - else: - script = {"id": script_id} - - script["params"] = fields - - body["script"] = script - body["scripted_upsert"] = scripted_upsert - - # partial document update - else: - if not fields: - raise IllegalOperation( - "You cannot call update() without updating individual fields or a script. " - "If you wish to update the entire object use save()." - ) - - # update given fields locally - merge(self, fields) - - # prepare data for ES - values = self.to_dict() - - # if fields were given: partial update - body["doc"] = {k: values.get(k) for k in fields.keys()} - - # extract routing etc from meta - doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} - - if retry_on_conflict is not None: - doc_meta["retry_on_conflict"] = retry_on_conflict - - # Optimistic concurrency control - if ( - retry_on_conflict in (None, 0) - and "seq_no" in self.meta - and "primary_term" in self.meta - ): - doc_meta["if_seq_no"] = self.meta["seq_no"] - doc_meta["if_primary_term"] = self.meta["primary_term"] - - meta = self._get_connection(using).update( - index=self._get_index(index), body=body, refresh=refresh, **doc_meta - ) - # update meta information from ES - for k in META_FIELDS: - if "_" + k in meta: - setattr(self.meta, k, meta["_" + k]) - - return meta if return_doc_meta else meta["result"] - - def save( - self, - using=None, - index=None, - validate=True, - skip_empty=True, - return_doc_meta=False, - **kwargs, - ): - 
""" - Save the document into elasticsearch. If the document doesn't exist it - is created, it is overwritten otherwise. Returns ``True`` if this - operations resulted in new document being created. - - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - :arg validate: set to ``False`` to skip validating the document - :arg skip_empty: if set to ``False`` will cause empty values (``None``, - ``[]``, ``{}``) to be left on the document. Those values will be - stripped out otherwise as they make no difference in elasticsearch. - :arg return_doc_meta: set to ``True`` to return all metadata from the - update API call instead of only the operation result - - Any additional keyword arguments will be passed to - ``Elasticsearch.index`` unchanged. - - :return operation result created/updated - """ - if validate: - self.full_clean() - - es = self._get_connection(using) - # extract routing etc from meta - doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} - - # Optimistic concurrency control - if "seq_no" in self.meta and "primary_term" in self.meta: - doc_meta["if_seq_no"] = self.meta["seq_no"] - doc_meta["if_primary_term"] = self.meta["primary_term"] - - doc_meta.update(kwargs) - meta = es.index( - index=self._get_index(index), - body=self.to_dict(skip_empty=skip_empty), - **doc_meta, - ) - # update meta information from ES - for k in META_FIELDS: - if "_" + k in meta: - setattr(self.meta, k, meta["_" + k]) - - return meta if return_doc_meta else meta["result"] +from elasticsearch_dsl._async.document import AsyncDocument # noqa: F401 +from elasticsearch_dsl._sync.document import Document # noqa: F401 +from elasticsearch_dsl.document_base import InnerDoc, MetaField # noqa: F401 diff --git a/elasticsearch_dsl/document_base.py b/elasticsearch_dsl/document_base.py new file mode 100644 index 000000000..8abbc7967 --- /dev/null +++ 
b/elasticsearch_dsl/document_base.py @@ -0,0 +1,153 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from fnmatch import fnmatch + +from .exceptions import ValidationException +from .field import Field +from .mapping import Mapping +from .utils import DOC_META_FIELDS, ObjectBase + + +class MetaField: + def __init__(self, *args, **kwargs): + self.args, self.kwargs = args, kwargs + + +class DocumentMeta(type): + def __new__(cls, name, bases, attrs): + # DocumentMeta filters attrs in place + attrs["_doc_type"] = DocumentOptions(name, bases, attrs) + return super().__new__(cls, name, bases, attrs) + + +class DocumentOptions: + def __init__(self, name, bases, attrs): + meta = attrs.pop("Meta", None) + + # create the mapping instance + self.mapping = getattr(meta, "mapping", Mapping()) + + # register all declared fields into the mapping + for name, value in list(attrs.items()): + if isinstance(value, Field): + self.mapping.field(name, value) + del attrs[name] + + # add all the mappings for meta fields + for name in dir(meta): + if isinstance(getattr(meta, name, None), MetaField): + params = getattr(meta, name) + self.mapping.meta(name, *params.args, **params.kwargs) + + # document inheritance - include the fields from parents' mappings + 
for b in bases: + if hasattr(b, "_doc_type") and hasattr(b._doc_type, "mapping"): + self.mapping.update(b._doc_type.mapping, update_only=True) + + @property + def name(self): + return self.mapping.properties.name + + +class InnerDoc(ObjectBase, metaclass=DocumentMeta): + """ + Common class for inner documents like Object or Nested + """ + + @classmethod + def from_es(cls, data, data_only=False): + if data_only: + data = {"_source": data} + return super().from_es(data) + + +class DocumentBase(ObjectBase): + """ + Model-like class for persisting documents in elasticsearch. + """ + + @classmethod + def _matches(cls, hit): + if cls._index._name is None: + return True + return fnmatch(hit.get("_index", ""), cls._index._name) + + @classmethod + def _get_using(cls, using=None): + return using or cls._index._using + + @classmethod + def _default_index(cls, index=None): + return index or cls._index._name + + @classmethod + def init(cls, index=None, using=None): + """ + Create the index and populate the mappings in elasticsearch. + """ + i = cls._index + if index: + i = i.clone(name=index) + i.save(using=using) + + def _get_index(self, index=None, required=True): + if index is None: + index = getattr(self.meta, "index", None) + if index is None: + index = getattr(self._index, "_name", None) + if index is None and required: + raise ValidationException("No index") + if index and "*" in index: + raise ValidationException("You cannot write to a wildcard index.") + return index + + def __repr__(self): + return "{}({})".format( + self.__class__.__name__, + ", ".join( + f"{key}={getattr(self.meta, key)!r}" + for key in ("index", "id") + if key in self.meta + ), + ) + + def to_dict(self, include_meta=False, skip_empty=True): + """ + Serialize the instance into a dictionary so that it can be saved in elasticsearch. + + :arg include_meta: if set to ``True`` will include all the metadata + (``_index``, ``_id`` etc). Otherwise just the document's + data is serialized. 
This is useful when passing multiple instances into + ``elasticsearch.helpers.bulk``. + :arg skip_empty: if set to ``False`` will cause empty values (``None``, + ``[]``, ``{}``) to be left on the document. Those values will be + stripped out otherwise as they make no difference in elasticsearch. + """ + d = super().to_dict(skip_empty=skip_empty) + if not include_meta: + return d + + meta = {"_" + k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # in case of to_dict include the index unlike save/update/delete + index = self._get_index(required=False) + if index is not None: + meta["_index"] = index + + meta["_source"] = d + return meta diff --git a/elasticsearch_dsl/faceted_search.py b/elasticsearch_dsl/faceted_search.py index 395eb8225..706be6274 100644 --- a/elasticsearch_dsl/faceted_search.py +++ b/elasticsearch_dsl/faceted_search.py @@ -15,444 +15,14 @@ # specific language governing permissions and limitations # under the License. -from datetime import datetime, timedelta - -from .aggs import A -from .query import MatchAll, Nested, Range, Terms -from .response import Response -from .search import Search -from .utils import AttrDict - -__all__ = [ - "FacetedSearch", - "HistogramFacet", - "TermsFacet", - "DateHistogramFacet", - "RangeFacet", - "NestedFacet", -] - - -class Facet: - """ - A facet on faceted search. Wraps and aggregation and provides functionality - to create a filter for selected values and return a list of facet values - from the result of the aggregation. - """ - - agg_type = None - - def __init__(self, metric=None, metric_sort="desc", **kwargs): - self.filter_values = () - self._params = kwargs - self._metric = metric - if metric and metric_sort: - self._params["order"] = {"metric": metric_sort} - - def get_aggregation(self): - """ - Return the aggregation object. 
- """ - agg = A(self.agg_type, **self._params) - if self._metric: - agg.metric("metric", self._metric) - return agg - - def add_filter(self, filter_values): - """ - Construct a filter. - """ - if not filter_values: - return - - f = self.get_value_filter(filter_values[0]) - for v in filter_values[1:]: - f |= self.get_value_filter(v) - return f - - def get_value_filter(self, filter_value): - """ - Construct a filter for an individual value - """ - pass - - def is_filtered(self, key, filter_values): - """ - Is a filter active on the given key. - """ - return key in filter_values - - def get_value(self, bucket): - """ - return a value representing a bucket. Its key as default. - """ - return bucket["key"] - - def get_metric(self, bucket): - """ - Return a metric, by default doc_count for a bucket. - """ - if self._metric: - return bucket["metric"]["value"] - return bucket["doc_count"] - - def get_values(self, data, filter_values): - """ - Turn the raw bucket data into a list of tuples containing the key, - number of documents and a flag indicating whether this value has been - selected or not. 
- """ - out = [] - for bucket in data.buckets: - key = self.get_value(bucket) - out.append( - (key, self.get_metric(bucket), self.is_filtered(key, filter_values)) - ) - return out - - -class TermsFacet(Facet): - agg_type = "terms" - - def add_filter(self, filter_values): - """Create a terms filter instead of bool containing term filters.""" - if filter_values: - return Terms( - _expand__to_dot=False, **{self._params["field"]: filter_values} - ) - - -class RangeFacet(Facet): - agg_type = "range" - - def _range_to_dict(self, range): - key, range = range - out = {"key": key} - if range[0] is not None: - out["from"] = range[0] - if range[1] is not None: - out["to"] = range[1] - return out - - def __init__(self, ranges, **kwargs): - super().__init__(**kwargs) - self._params["ranges"] = list(map(self._range_to_dict, ranges)) - self._params["keyed"] = False - self._ranges = dict(ranges) - - def get_value_filter(self, filter_value): - f, t = self._ranges[filter_value] - limits = {} - if f is not None: - limits["gte"] = f - if t is not None: - limits["lt"] = t - - return Range(_expand__to_dot=False, **{self._params["field"]: limits}) - - -class HistogramFacet(Facet): - agg_type = "histogram" - - def get_value_filter(self, filter_value): - return Range( - _expand__to_dot=False, - **{ - self._params["field"]: { - "gte": filter_value, - "lt": filter_value + self._params["interval"], - } - }, - ) - - -def _date_interval_year(d): - return d.replace( - year=d.year + 1, day=(28 if d.month == 2 and d.day == 29 else d.day) - ) - - -def _date_interval_month(d): - return (d + timedelta(days=32)).replace(day=1) - - -def _date_interval_week(d): - return d + timedelta(days=7) - - -def _date_interval_day(d): - return d + timedelta(days=1) - - -def _date_interval_hour(d): - return d + timedelta(hours=1) - - -class DateHistogramFacet(Facet): - agg_type = "date_histogram" - - DATE_INTERVALS = { - "year": _date_interval_year, - "1Y": _date_interval_year, - "month": _date_interval_month, - 
"1M": _date_interval_month, - "week": _date_interval_week, - "1w": _date_interval_week, - "day": _date_interval_day, - "1d": _date_interval_day, - "hour": _date_interval_hour, - "1h": _date_interval_hour, - } - - def __init__(self, **kwargs): - kwargs.setdefault("min_doc_count", 0) - super().__init__(**kwargs) - - def get_value(self, bucket): - if not isinstance(bucket["key"], datetime): - # Elasticsearch returns key=None instead of 0 for date 1970-01-01, - # so we need to set key to 0 to avoid TypeError exception - if bucket["key"] is None: - bucket["key"] = 0 - # Preserve milliseconds in the datetime - return datetime.utcfromtimestamp(int(bucket["key"]) / 1000.0) - else: - return bucket["key"] - - def get_value_filter(self, filter_value): - for interval_type in ("calendar_interval", "fixed_interval"): - if interval_type in self._params: - break - else: - interval_type = "interval" - - return Range( - _expand__to_dot=False, - **{ - self._params["field"]: { - "gte": filter_value, - "lt": self.DATE_INTERVALS[self._params[interval_type]]( - filter_value - ), - } - }, - ) - - -class NestedFacet(Facet): - agg_type = "nested" - - def __init__(self, path, nested_facet): - self._path = path - self._inner = nested_facet - super().__init__(path=path, aggs={"inner": nested_facet.get_aggregation()}) - - def get_values(self, data, filter_values): - return self._inner.get_values(data.inner, filter_values) - - def add_filter(self, filter_values): - inner_q = self._inner.add_filter(filter_values) - if inner_q: - return Nested(path=self._path, query=inner_q) - - -class FacetedResponse(Response): - @property - def query_string(self): - return self._faceted_search._query - - @property - def facets(self): - if not hasattr(self, "_facets"): - super(AttrDict, self).__setattr__("_facets", AttrDict({})) - for name, facet in self._faceted_search.facets.items(): - self._facets[name] = facet.get_values( - getattr(getattr(self.aggregations, "_filter_" + name), name), - 
self._faceted_search.filter_values.get(name, ()), - ) - return self._facets - - -class FacetedSearch: - """ - Abstraction for creating faceted navigation searches that takes care of - composing the queries, aggregations and filters as needed as well as - presenting the results in an easy-to-consume fashion:: - - class BlogSearch(FacetedSearch): - index = 'blogs' - doc_types = [Blog, Post] - fields = ['title^5', 'category', 'description', 'body'] - - facets = { - 'type': TermsFacet(field='_type'), - 'category': TermsFacet(field='category'), - 'weekly_posts': DateHistogramFacet(field='published_from', interval='week') - } - - def search(self): - ' Override search to add your own filters ' - s = super(BlogSearch, self).search() - return s.filter('term', published=True) - - # when using: - blog_search = BlogSearch("web framework", filters={"category": "python"}) - - # supports pagination - blog_search[10:20] - - response = blog_search.execute() - - # easy access to aggregation results: - for category, hit_count, is_selected in response.facets.category: - print( - "Category %s has %d hits%s." % ( - category, - hit_count, - ' and is chosen' if is_selected else '' - ) - ) - - """ - - index = None - doc_types = None - fields = None - facets = {} - using = "default" - - def __init__(self, query=None, filters={}, sort=()): - """ - :arg query: the text to search for - :arg filters: facet values to filter - :arg sort: sort information to be passed to :class:`~elasticsearch_dsl.Search` - """ - self._query = query - self._filters = {} - self._sort = sort - self.filter_values = {} - for name, value in filters.items(): - self.add_filter(name, value) - - self._s = self.build_search() - - def count(self): - return self._s.count() - - def __getitem__(self, k): - self._s = self._s[k] - return self - - def __iter__(self): - return iter(self._s) - - def add_filter(self, name, filter_values): - """ - Add a filter for a facet. 
- """ - # normalize the value into a list - if not isinstance(filter_values, (tuple, list)): - if filter_values is None: - return - filter_values = [ - filter_values, - ] - - # remember the filter values for use in FacetedResponse - self.filter_values[name] = filter_values - - # get the filter from the facet - f = self.facets[name].add_filter(filter_values) - if f is None: - return - - self._filters[name] = f - - def search(self): - """ - Returns the base Search object to which the facets are added. - - You can customize the query by overriding this method and returning a - modified search object. - """ - s = Search(doc_type=self.doc_types, index=self.index, using=self.using) - return s.response_class(FacetedResponse) - - def query(self, search, query): - """ - Add query part to ``search``. - - Override this if you wish to customize the query used. - """ - if query: - if self.fields: - return search.query("multi_match", fields=self.fields, query=query) - else: - return search.query("multi_match", query=query) - return search - - def aggregate(self, search): - """ - Add aggregations representing the facets selected, including potential - filters. - """ - for f, facet in self.facets.items(): - agg = facet.get_aggregation() - agg_filter = MatchAll() - for field, filter in self._filters.items(): - if f == field: - continue - agg_filter &= filter - search.aggs.bucket("_filter_" + f, "filter", filter=agg_filter).bucket( - f, agg - ) - - def filter(self, search): - """ - Add a ``post_filter`` to the search request narrowing the results based - on the facet filters. 
- """ - if not self._filters: - return search - - post_filter = MatchAll() - for f in self._filters.values(): - post_filter &= f - return search.post_filter(post_filter) - - def highlight(self, search): - """ - Add highlighting for all the fields - """ - return search.highlight( - *(f if "^" not in f else f.split("^", 1)[0] for f in self.fields) - ) - - def sort(self, search): - """ - Add sorting information to the request. - """ - if self._sort: - search = search.sort(*self._sort) - return search - - def build_search(self): - """ - Construct the ``Search`` object. - """ - s = self.search() - s = self.query(s, self._query) - s = self.filter(s) - if self.fields: - s = self.highlight(s) - s = self.sort(s) - self.aggregate(s) - return s - - def execute(self): - """ - Execute the search and return the response. - """ - r = self._s.execute() - r._faceted_search = self - return r +from elasticsearch_dsl._async.faceted_search import AsyncFacetedSearch # noqa: F401 +from elasticsearch_dsl._sync.faceted_search import FacetedSearch # noqa: F401 +from elasticsearch_dsl.faceted_search_base import ( # noqa: F401 + DateHistogramFacet, + Facet, + FacetedResponse, + HistogramFacet, + NestedFacet, + RangeFacet, + TermsFacet, +) diff --git a/elasticsearch_dsl/faceted_search_base.py b/elasticsearch_dsl/faceted_search_base.py new file mode 100644 index 000000000..e14297202 --- /dev/null +++ b/elasticsearch_dsl/faceted_search_base.py @@ -0,0 +1,439 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime, timedelta + +from .aggs import A +from .query import MatchAll, Nested, Range, Terms +from .response import Response +from .utils import AttrDict + +__all__ = [ + "FacetedSearchBase", + "HistogramFacet", + "TermsFacet", + "DateHistogramFacet", + "RangeFacet", + "NestedFacet", +] + + +class Facet: + """ + A facet on faceted search. Wraps and aggregation and provides functionality + to create a filter for selected values and return a list of facet values + from the result of the aggregation. + """ + + agg_type = None + + def __init__(self, metric=None, metric_sort="desc", **kwargs): + self.filter_values = () + self._params = kwargs + self._metric = metric + if metric and metric_sort: + self._params["order"] = {"metric": metric_sort} + + def get_aggregation(self): + """ + Return the aggregation object. + """ + agg = A(self.agg_type, **self._params) + if self._metric: + agg.metric("metric", self._metric) + return agg + + def add_filter(self, filter_values): + """ + Construct a filter. + """ + if not filter_values: + return + + f = self.get_value_filter(filter_values[0]) + for v in filter_values[1:]: + f |= self.get_value_filter(v) + return f + + def get_value_filter(self, filter_value): + """ + Construct a filter for an individual value + """ + pass + + def is_filtered(self, key, filter_values): + """ + Is a filter active on the given key. + """ + return key in filter_values + + def get_value(self, bucket): + """ + return a value representing a bucket. Its key as default. 
+ """ + return bucket["key"] + + def get_metric(self, bucket): + """ + Return a metric, by default doc_count for a bucket. + """ + if self._metric: + return bucket["metric"]["value"] + return bucket["doc_count"] + + def get_values(self, data, filter_values): + """ + Turn the raw bucket data into a list of tuples containing the key, + number of documents and a flag indicating whether this value has been + selected or not. + """ + out = [] + for bucket in data.buckets: + key = self.get_value(bucket) + out.append( + (key, self.get_metric(bucket), self.is_filtered(key, filter_values)) + ) + return out + + +class TermsFacet(Facet): + agg_type = "terms" + + def add_filter(self, filter_values): + """Create a terms filter instead of bool containing term filters.""" + if filter_values: + return Terms( + _expand__to_dot=False, **{self._params["field"]: filter_values} + ) + + +class RangeFacet(Facet): + agg_type = "range" + + def _range_to_dict(self, range): + key, range = range + out = {"key": key} + if range[0] is not None: + out["from"] = range[0] + if range[1] is not None: + out["to"] = range[1] + return out + + def __init__(self, ranges, **kwargs): + super().__init__(**kwargs) + self._params["ranges"] = list(map(self._range_to_dict, ranges)) + self._params["keyed"] = False + self._ranges = dict(ranges) + + def get_value_filter(self, filter_value): + f, t = self._ranges[filter_value] + limits = {} + if f is not None: + limits["gte"] = f + if t is not None: + limits["lt"] = t + + return Range(_expand__to_dot=False, **{self._params["field"]: limits}) + + +class HistogramFacet(Facet): + agg_type = "histogram" + + def get_value_filter(self, filter_value): + return Range( + _expand__to_dot=False, + **{ + self._params["field"]: { + "gte": filter_value, + "lt": filter_value + self._params["interval"], + } + }, + ) + + +def _date_interval_year(d): + return d.replace( + year=d.year + 1, day=(28 if d.month == 2 and d.day == 29 else d.day) + ) + + +def _date_interval_month(d): + 
return (d + timedelta(days=32)).replace(day=1) + + +def _date_interval_week(d): + return d + timedelta(days=7) + + +def _date_interval_day(d): + return d + timedelta(days=1) + + +def _date_interval_hour(d): + return d + timedelta(hours=1) + + +class DateHistogramFacet(Facet): + agg_type = "date_histogram" + + DATE_INTERVALS = { + "year": _date_interval_year, + "1Y": _date_interval_year, + "month": _date_interval_month, + "1M": _date_interval_month, + "week": _date_interval_week, + "1w": _date_interval_week, + "day": _date_interval_day, + "1d": _date_interval_day, + "hour": _date_interval_hour, + "1h": _date_interval_hour, + } + + def __init__(self, **kwargs): + kwargs.setdefault("min_doc_count", 0) + super().__init__(**kwargs) + + def get_value(self, bucket): + if not isinstance(bucket["key"], datetime): + # Elasticsearch returns key=None instead of 0 for date 1970-01-01, + # so we need to set key to 0 to avoid TypeError exception + if bucket["key"] is None: + bucket["key"] = 0 + # Preserve milliseconds in the datetime + return datetime.utcfromtimestamp(int(bucket["key"]) / 1000.0) + else: + return bucket["key"] + + def get_value_filter(self, filter_value): + for interval_type in ("calendar_interval", "fixed_interval"): + if interval_type in self._params: + break + else: + interval_type = "interval" + + return Range( + _expand__to_dot=False, + **{ + self._params["field"]: { + "gte": filter_value, + "lt": self.DATE_INTERVALS[self._params[interval_type]]( + filter_value + ), + } + }, + ) + + +class NestedFacet(Facet): + agg_type = "nested" + + def __init__(self, path, nested_facet): + self._path = path + self._inner = nested_facet + super().__init__(path=path, aggs={"inner": nested_facet.get_aggregation()}) + + def get_values(self, data, filter_values): + return self._inner.get_values(data.inner, filter_values) + + def add_filter(self, filter_values): + inner_q = self._inner.add_filter(filter_values) + if inner_q: + return Nested(path=self._path, query=inner_q) + + 
+class FacetedResponse(Response): + @property + def query_string(self): + return self._faceted_search._query + + @property + def facets(self): + if not hasattr(self, "_facets"): + super(AttrDict, self).__setattr__("_facets", AttrDict({})) + for name, facet in self._faceted_search.facets.items(): + self._facets[name] = facet.get_values( + getattr(getattr(self.aggregations, "_filter_" + name), name), + self._faceted_search.filter_values.get(name, ()), + ) + return self._facets + + +class FacetedSearchBase: + """ + Abstraction for creating faceted navigation searches that takes care of + composing the queries, aggregations and filters as needed as well as + presenting the results in an easy-to-consume fashion:: + + class BlogSearch(FacetedSearch): + index = 'blogs' + doc_types = [Blog, Post] + fields = ['title^5', 'category', 'description', 'body'] + + facets = { + 'type': TermsFacet(field='_type'), + 'category': TermsFacet(field='category'), + 'weekly_posts': DateHistogramFacet(field='published_from', interval='week') + } + + def search(self): + ' Override search to add your own filters ' + s = super(BlogSearch, self).search() + return s.filter('term', published=True) + + # when using: + blog_search = BlogSearch("web framework", filters={"category": "python"}) + + # supports pagination + blog_search[10:20] + + response = blog_search.execute() + + # easy access to aggregation results: + for category, hit_count, is_selected in response.facets.category: + print( + "Category %s has %d hits%s." 
% ( + category, + hit_count, + ' and is chosen' if is_selected else '' + ) + ) + + """ + + index = None + doc_types = None + fields = None + facets = {} + using = "default" + + def __init__(self, query=None, filters={}, sort=()): + """ + :arg query: the text to search for + :arg filters: facet values to filter + :arg sort: sort information to be passed to :class:`~elasticsearch_dsl.Search` + """ + self._query = query + self._filters = {} + self._sort = sort + self.filter_values = {} + for name, value in filters.items(): + self.add_filter(name, value) + + self._s = self.build_search() + + def count(self): + return self._s.count() + + def __getitem__(self, k): + self._s = self._s[k] + return self + + def __iter__(self): + return iter(self._s) + + def add_filter(self, name, filter_values): + """ + Add a filter for a facet. + """ + # normalize the value into a list + if not isinstance(filter_values, (tuple, list)): + if filter_values is None: + return + filter_values = [ + filter_values, + ] + + # remember the filter values for use in FacetedResponse + self.filter_values[name] = filter_values + + # get the filter from the facet + f = self.facets[name].add_filter(filter_values) + if f is None: + return + + self._filters[name] = f + + def query(self, search, query): + """ + Add query part to ``search``. + + Override this if you wish to customize the query used. + """ + if query: + if self.fields: + return search.query("multi_match", fields=self.fields, query=query) + else: + return search.query("multi_match", query=query) + return search + + def aggregate(self, search): + """ + Add aggregations representing the facets selected, including potential + filters. 
+ """ + for f, facet in self.facets.items(): + agg = facet.get_aggregation() + agg_filter = MatchAll() + for field, filter in self._filters.items(): + if f == field: + continue + agg_filter &= filter + search.aggs.bucket("_filter_" + f, "filter", filter=agg_filter).bucket( + f, agg + ) + + def filter(self, search): + """ + Add a ``post_filter`` to the search request narrowing the results based + on the facet filters. + """ + if not self._filters: + return search + + post_filter = MatchAll() + for f in self._filters.values(): + post_filter &= f + return search.post_filter(post_filter) + + def highlight(self, search): + """ + Add highlighting for all the fields + """ + return search.highlight( + *(f if "^" not in f else f.split("^", 1)[0] for f in self.fields) + ) + + def sort(self, search): + """ + Add sorting information to the request. + """ + if self._sort: + search = search.sort(*self._sort) + return search + + def build_search(self): + """ + Construct the ``Search`` object. + """ + s = self.search() + s = self.query(s, self._query) + s = self.filter(s) + if self.fields: + s = self.highlight(s) + s = self.sort(s) + self.aggregate(s) + return s diff --git a/elasticsearch_dsl/index.py b/elasticsearch_dsl/index.py index ffd95dab9..ef9f4b0b9 100644 --- a/elasticsearch_dsl/index.py +++ b/elasticsearch_dsl/index.py @@ -15,637 +15,5 @@ # specific language governing permissions and limitations # under the License. -from . import analysis -from .connections import get_connection -from .exceptions import IllegalOperation -from .mapping import Mapping -from .search import Search -from .update_by_query import UpdateByQuery -from .utils import merge - - -class IndexTemplate: - def __init__(self, name, template, index=None, order=None, **kwargs): - if index is None: - self._index = Index(template, **kwargs) - else: - if kwargs: - raise ValueError( - "You cannot specify options for Index when" - " passing an Index instance." 
- ) - self._index = index.clone() - self._index._name = template - self._template_name = name - self.order = order - - def __getattr__(self, attr_name): - return getattr(self._index, attr_name) - - def to_dict(self): - d = self._index.to_dict() - d["index_patterns"] = [self._index._name] - if self.order is not None: - d["order"] = self.order - return d - - def save(self, using=None): - es = get_connection(using or self._index._using) - return es.indices.put_template(name=self._template_name, body=self.to_dict()) - - -class Index: - def __init__(self, name, using="default"): - """ - :arg name: name of the index - :arg using: connection alias to use, defaults to ``'default'`` - """ - self._name = name - self._doc_types = [] - self._using = using - self._settings = {} - self._aliases = {} - self._analysis = {} - self._mapping = None - - def get_or_create_mapping(self): - if self._mapping is None: - self._mapping = Mapping() - return self._mapping - - def as_template(self, template_name, pattern=None, order=None): - # TODO: should we allow pattern to be a top-level arg? - # or maybe have an IndexPattern that allows for it and have - # Document._index be that? 
- return IndexTemplate( - template_name, pattern or self._name, index=self, order=order - ) - - def resolve_nested(self, field_path): - for doc in self._doc_types: - nested, field = doc._doc_type.mapping.resolve_nested(field_path) - if field is not None: - return nested, field - if self._mapping: - return self._mapping.resolve_nested(field_path) - return (), None - - def resolve_field(self, field_path): - for doc in self._doc_types: - field = doc._doc_type.mapping.resolve_field(field_path) - if field is not None: - return field - if self._mapping: - return self._mapping.resolve_field(field_path) - return None - - def load_mappings(self, using=None): - self.get_or_create_mapping().update_from_es( - self._name, using=using or self._using - ) - - def clone(self, name=None, using=None): - """ - Create a copy of the instance with another name or connection alias. - Useful for creating multiple indices with shared configuration:: - - i = Index('base-index') - i.settings(number_of_shards=1) - i.create() - - i2 = i.clone('other-index') - i2.create() - - :arg name: name of the index - :arg using: connection alias to use, defaults to ``'default'`` - """ - i = Index(name or self._name, using=using or self._using) - i._settings = self._settings.copy() - i._aliases = self._aliases.copy() - i._analysis = self._analysis.copy() - i._doc_types = self._doc_types[:] - if self._mapping is not None: - i._mapping = self._mapping._clone() - return i - - def _get_connection(self, using=None): - if self._name is None: - raise ValueError("You cannot perform API calls on the default index.") - return get_connection(using or self._using) - - connection = property(_get_connection) - - def mapping(self, mapping): - """ - Associate a mapping (an instance of - :class:`~elasticsearch_dsl.Mapping`) with this index. - This means that, when this index is created, it will contain the - mappings for the document type defined by those mappings. 
- """ - self.get_or_create_mapping().update(mapping) - - def document(self, document): - """ - Associate a :class:`~elasticsearch_dsl.Document` subclass with an index. - This means that, when this index is created, it will contain the - mappings for the ``Document``. If the ``Document`` class doesn't have a - default index yet (by defining ``class Index``), this instance will be - used. Can be used as a decorator:: - - i = Index('blog') - - @i.document - class Post(Document): - title = Text() - - # create the index, including Post mappings - i.create() - - # .search() will now return a Search object that will return - # properly deserialized Post instances - s = i.search() - """ - self._doc_types.append(document) - - # If the document index does not have any name, that means the user - # did not set any index already to the document. - # So set this index as document index - if document._index._name is None: - document._index = self - - return document - - def settings(self, **kwargs): - """ - Add settings to the index:: - - i = Index('i') - i.settings(number_of_shards=1, number_of_replicas=0) - - Multiple calls to ``settings`` will merge the keys, later overriding - the earlier. - """ - self._settings.update(kwargs) - return self - - def aliases(self, **kwargs): - """ - Add aliases to the index definition:: - - i = Index('blog-v2') - i.aliases(blog={}, published={'filter': Q('term', published=True)}) - """ - self._aliases.update(kwargs) - return self - - def analyzer(self, *args, **kwargs): - """ - Explicitly add an analyzer to an index. Note that all custom analyzers - defined in mappings will also be created. This is useful for search analyzers. 
- - Example:: - - from elasticsearch_dsl import analyzer, tokenizer - - my_analyzer = analyzer('my_analyzer', - tokenizer=tokenizer('trigram', 'nGram', min_gram=3, max_gram=3), - filter=['lowercase'] - ) - - i = Index('blog') - i.analyzer(my_analyzer) - - """ - analyzer = analysis.analyzer(*args, **kwargs) - d = analyzer.get_analysis_definition() - # empty custom analyzer, probably already defined out of our control - if not d: - return - - # merge the definition - merge(self._analysis, d, True) - - def to_dict(self): - out = {} - if self._settings: - out["settings"] = self._settings - if self._aliases: - out["aliases"] = self._aliases - mappings = self._mapping.to_dict() if self._mapping else {} - analysis = self._mapping._collect_analysis() if self._mapping else {} - for d in self._doc_types: - mapping = d._doc_type.mapping - merge(mappings, mapping.to_dict(), True) - merge(analysis, mapping._collect_analysis(), True) - if mappings: - out["mappings"] = mappings - if analysis or self._analysis: - merge(analysis, self._analysis) - out.setdefault("settings", {})["analysis"] = analysis - return out - - def search(self, using=None): - """ - Return a :class:`~elasticsearch_dsl.Search` object searching over the - index (or all the indices belonging to this template) and its - ``Document``\\s. - """ - return Search( - using=using or self._using, index=self._name, doc_type=self._doc_types - ) - - def updateByQuery(self, using=None): - """ - Return a :class:`~elasticsearch_dsl.UpdateByQuery` object searching over the index - (or all the indices belonging to this template) and updating Documents that match - the search criteria. - - For more information, see here: - https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-update-by-query.html - """ - return UpdateByQuery( - using=using or self._using, - index=self._name, - ) - - def create(self, using=None, **kwargs): - """ - Creates the index in elasticsearch. 
- - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.create`` unchanged. - """ - return self._get_connection(using).indices.create( - index=self._name, body=self.to_dict(), **kwargs - ) - - def is_closed(self, using=None): - state = self._get_connection(using).cluster.state( - index=self._name, metric="metadata" - ) - return state["metadata"]["indices"][self._name]["state"] == "close" - - def save(self, using=None): - """ - Sync the index definition with elasticsearch, creating the index if it - doesn't exist and updating its settings and mappings if it does. - - Note some settings and mapping changes cannot be done on an open - index (or at all on an existing index) and for those this method will - fail with the underlying exception. - """ - if not self.exists(using=using): - return self.create(using=using) - - body = self.to_dict() - settings = body.pop("settings", {}) - analysis = settings.pop("analysis", None) - current_settings = self.get_settings(using=using)[self._name]["settings"][ - "index" - ] - if analysis: - if self.is_closed(using=using): - # closed index, update away - settings["analysis"] = analysis - else: - # compare analysis definition, if all analysis objects are - # already defined as requested, skip analysis update and - # proceed, otherwise raise IllegalOperation - existing_analysis = current_settings.get("analysis", {}) - if any( - existing_analysis.get(section, {}).get(k, None) - != analysis[section][k] - for section in analysis - for k in analysis[section] - ): - raise IllegalOperation( - "You cannot update analysis configuration on an open index, " - "you need to close index %s first." 
% self._name - ) - - # try and update the settings - if settings: - settings = settings.copy() - for k, v in list(settings.items()): - if k in current_settings and current_settings[k] == str(v): - del settings[k] - - if settings: - self.put_settings(using=using, body=settings) - - # update the mappings, any conflict in the mappings will result in an - # exception - mappings = body.pop("mappings", {}) - if mappings: - self.put_mapping(using=using, body=mappings) - - def analyze(self, using=None, **kwargs): - """ - Perform the analysis process on a text and return the tokens breakdown - of the text. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.analyze`` unchanged. - """ - return self._get_connection(using).indices.analyze(index=self._name, **kwargs) - - def refresh(self, using=None, **kwargs): - """ - Performs a refresh operation on the index. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.refresh`` unchanged. - """ - return self._get_connection(using).indices.refresh(index=self._name, **kwargs) - - def flush(self, using=None, **kwargs): - """ - Performs a flush operation on the index. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.flush`` unchanged. - """ - return self._get_connection(using).indices.flush(index=self._name, **kwargs) - - def get(self, using=None, **kwargs): - """ - The get index API allows to retrieve information about the index. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.get`` unchanged. - """ - return self._get_connection(using).indices.get(index=self._name, **kwargs) - - def open(self, using=None, **kwargs): - """ - Opens the index in elasticsearch. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.open`` unchanged. 
- """ - return self._get_connection(using).indices.open(index=self._name, **kwargs) - - def close(self, using=None, **kwargs): - """ - Closes the index in elasticsearch. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.close`` unchanged. - """ - return self._get_connection(using).indices.close(index=self._name, **kwargs) - - def delete(self, using=None, **kwargs): - """ - Deletes the index in elasticsearch. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.delete`` unchanged. - """ - return self._get_connection(using).indices.delete(index=self._name, **kwargs) - - def exists(self, using=None, **kwargs): - """ - Returns ``True`` if the index already exists in elasticsearch. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.exists`` unchanged. - """ - return self._get_connection(using).indices.exists(index=self._name, **kwargs) - - def exists_type(self, using=None, **kwargs): - """ - Check if a type/types exists in the index. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.exists_type`` unchanged. - """ - return self._get_connection(using).indices.exists_type( - index=self._name, **kwargs - ) - - def put_mapping(self, using=None, **kwargs): - """ - Register specific mapping definition for a specific type. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.put_mapping`` unchanged. - """ - return self._get_connection(using).indices.put_mapping( - index=self._name, **kwargs - ) - - def get_mapping(self, using=None, **kwargs): - """ - Retrieve specific mapping definition for a specific type. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.get_mapping`` unchanged. - """ - return self._get_connection(using).indices.get_mapping( - index=self._name, **kwargs - ) - - def get_field_mapping(self, using=None, **kwargs): - """ - Retrieve mapping definition of a specific field. 
- - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.get_field_mapping`` unchanged. - """ - return self._get_connection(using).indices.get_field_mapping( - index=self._name, **kwargs - ) - - def put_alias(self, using=None, **kwargs): - """ - Create an alias for the index. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.put_alias`` unchanged. - """ - return self._get_connection(using).indices.put_alias(index=self._name, **kwargs) - - def exists_alias(self, using=None, **kwargs): - """ - Return a boolean indicating whether given alias exists for this index. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.exists_alias`` unchanged. - """ - return self._get_connection(using).indices.exists_alias( - index=self._name, **kwargs - ) - - def get_alias(self, using=None, **kwargs): - """ - Retrieve a specified alias. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.get_alias`` unchanged. - """ - return self._get_connection(using).indices.get_alias(index=self._name, **kwargs) - - def delete_alias(self, using=None, **kwargs): - """ - Delete specific alias. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.delete_alias`` unchanged. - """ - return self._get_connection(using).indices.delete_alias( - index=self._name, **kwargs - ) - - def get_settings(self, using=None, **kwargs): - """ - Retrieve settings for the index. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.get_settings`` unchanged. - """ - return self._get_connection(using).indices.get_settings( - index=self._name, **kwargs - ) - - def put_settings(self, using=None, **kwargs): - """ - Change specific index level settings in real time. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.put_settings`` unchanged. 
- """ - return self._get_connection(using).indices.put_settings( - index=self._name, **kwargs - ) - - def stats(self, using=None, **kwargs): - """ - Retrieve statistics on different operations happening on the index. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.stats`` unchanged. - """ - return self._get_connection(using).indices.stats(index=self._name, **kwargs) - - def segments(self, using=None, **kwargs): - """ - Provide low level segments information that a Lucene index (shard - level) is built with. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.segments`` unchanged. - """ - return self._get_connection(using).indices.segments(index=self._name, **kwargs) - - def validate_query(self, using=None, **kwargs): - """ - Validate a potentially expensive query without executing it. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.validate_query`` unchanged. - """ - return self._get_connection(using).indices.validate_query( - index=self._name, **kwargs - ) - - def clear_cache(self, using=None, **kwargs): - """ - Clear all caches or specific cached associated with the index. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.clear_cache`` unchanged. - """ - return self._get_connection(using).indices.clear_cache( - index=self._name, **kwargs - ) - - def recovery(self, using=None, **kwargs): - """ - The indices recovery API provides insight into on-going shard - recoveries for the index. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.recovery`` unchanged. - """ - return self._get_connection(using).indices.recovery(index=self._name, **kwargs) - - def upgrade(self, using=None, **kwargs): - """ - Upgrade the index to the latest format. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.upgrade`` unchanged. 
- """ - return self._get_connection(using).indices.upgrade(index=self._name, **kwargs) - - def get_upgrade(self, using=None, **kwargs): - """ - Monitor how much of the index is upgraded. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.get_upgrade`` unchanged. - """ - return self._get_connection(using).indices.get_upgrade( - index=self._name, **kwargs - ) - - def flush_synced(self, using=None, **kwargs): - """ - Perform a normal flush, then add a generated unique marker (sync_id) to - all shards. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.flush_synced`` unchanged. - """ - return self._get_connection(using).indices.flush_synced( - index=self._name, **kwargs - ) - - def shard_stores(self, using=None, **kwargs): - """ - Provides store information for shard copies of the index. Store - information reports on which nodes shard copies exist, the shard copy - version, indicating how recent they are, and any exceptions encountered - while opening the shard index or from earlier engine failure. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.shard_stores`` unchanged. - """ - return self._get_connection(using).indices.shard_stores( - index=self._name, **kwargs - ) - - def forcemerge(self, using=None, **kwargs): - """ - The force merge API allows to force merging of the index through an - API. The merge relates to the number of segments a Lucene index holds - within each shard. The force merge operation allows to reduce the - number of segments by merging them. - - This call will block until the merge is complete. If the http - connection is lost, the request will continue in the background, and - any new requests will block until the previous force merge is complete. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.forcemerge`` unchanged. 
- """ - return self._get_connection(using).indices.forcemerge( - index=self._name, **kwargs - ) - - def shrink(self, using=None, **kwargs): - """ - The shrink index API allows you to shrink an existing index into a new - index with fewer primary shards. The number of primary shards in the - target index must be a factor of the shards in the source index. For - example an index with 8 primary shards can be shrunk into 4, 2 or 1 - primary shards or an index with 15 primary shards can be shrunk into 5, - 3 or 1. If the number of shards in the index is a prime number it can - only be shrunk into a single primary shard. Before shrinking, a - (primary or replica) copy of every shard in the index must be present - on the same node. - - Any additional keyword arguments will be passed to - ``Elasticsearch.indices.shrink`` unchanged. - """ - return self._get_connection(using).indices.shrink(index=self._name, **kwargs) +from ._async.index import AsyncIndex, AsyncIndexTemplate # noqa: F401 +from ._sync.index import Index, IndexTemplate # noqa: F401 diff --git a/elasticsearch_dsl/index_base.py b/elasticsearch_dsl/index_base.py new file mode 100644 index 000000000..a0cda4e24 --- /dev/null +++ b/elasticsearch_dsl/index_base.py @@ -0,0 +1,167 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from . import analysis +from .utils import merge + + +class IndexBase: + def __init__(self, name, mapping_class, using="default"): + """ + :arg name: name of the index + :arg using: connection alias to use, defaults to ``'default'`` + """ + self._name = name + self._doc_types = [] + self._using = using + self._settings = {} + self._aliases = {} + self._analysis = {} + self._mapping_class = mapping_class + self._mapping = None + + def resolve_nested(self, field_path): + for doc in self._doc_types: + nested, field = doc._doc_type.mapping.resolve_nested(field_path) + if field is not None: + return nested, field + if self._mapping: + return self._mapping.resolve_nested(field_path) + return (), None + + def resolve_field(self, field_path): + for doc in self._doc_types: + field = doc._doc_type.mapping.resolve_field(field_path) + if field is not None: + return field + if self._mapping: + return self._mapping.resolve_field(field_path) + return None + + def get_or_create_mapping(self): + if self._mapping is None: + self._mapping = self._mapping_class() + return self._mapping + + def mapping(self, mapping): + """ + Associate a mapping (an instance of + :class:`~elasticsearch_dsl.Mapping`) with this index. + This means that, when this index is created, it will contain the + mappings for the document type defined by those mappings. + """ + self.get_or_create_mapping().update(mapping) + + def document(self, document): + """ + Associate a :class:`~elasticsearch_dsl.Document` subclass with an index. + This means that, when this index is created, it will contain the + mappings for the ``Document``. If the ``Document`` class doesn't have a + default index yet (by defining ``class Index``), this instance will be + used. 
Can be used as a decorator:: + + i = Index('blog') + + @i.document + class Post(Document): + title = Text() + + # create the index, including Post mappings + i.create() + + # .search() will now return a Search object that will return + # properly deserialized Post instances + s = i.search() + """ + self._doc_types.append(document) + + # If the document index does not have any name, that means the user + # did not set any index already to the document. + # So set this index as document index + if document._index._name is None: + document._index = self + + return document + + def settings(self, **kwargs): + """ + Add settings to the index:: + + i = Index('i') + i.settings(number_of_shards=1, number_of_replicas=0) + + Multiple calls to ``settings`` will merge the keys, later overriding + the earlier. + """ + self._settings.update(kwargs) + return self + + def aliases(self, **kwargs): + """ + Add aliases to the index definition:: + + i = Index('blog-v2') + i.aliases(blog={}, published={'filter': Q('term', published=True)}) + """ + self._aliases.update(kwargs) + return self + + def analyzer(self, *args, **kwargs): + """ + Explicitly add an analyzer to an index. Note that all custom analyzers + defined in mappings will also be created. This is useful for search analyzers. 
+ + Example:: + + from elasticsearch_dsl import analyzer, tokenizer + + my_analyzer = analyzer('my_analyzer', + tokenizer=tokenizer('trigram', 'nGram', min_gram=3, max_gram=3), + filter=['lowercase'] + ) + + i = Index('blog') + i.analyzer(my_analyzer) + + """ + analyzer = analysis.analyzer(*args, **kwargs) + d = analyzer.get_analysis_definition() + # empty custom analyzer, probably already defined out of our control + if not d: + return + + # merge the definition + merge(self._analysis, d, True) + + def to_dict(self): + out = {} + if self._settings: + out["settings"] = self._settings + if self._aliases: + out["aliases"] = self._aliases + mappings = self._mapping.to_dict() if self._mapping else {} + analysis = self._mapping._collect_analysis() if self._mapping else {} + for d in self._doc_types: + mapping = d._doc_type.mapping + merge(mappings, mapping.to_dict(), True) + merge(analysis, mapping._collect_analysis(), True) + if mappings: + out["mappings"] = mappings + if analysis or self._analysis: + merge(analysis, self._analysis) + out.setdefault("settings", {})["analysis"] = analysis + return out diff --git a/elasticsearch_dsl/mapping.py b/elasticsearch_dsl/mapping.py index f48608475..301de2811 100644 --- a/elasticsearch_dsl/mapping.py +++ b/elasticsearch_dsl/mapping.py @@ -15,217 +15,5 @@ # specific language governing permissions and limitations # under the License. 
-import collections.abc -from itertools import chain - -from .connections import get_connection -from .field import Nested, Text, construct_field -from .utils import DslBase - -META_FIELDS = frozenset( - ( - "dynamic", - "transform", - "dynamic_date_formats", - "date_detection", - "numeric_detection", - "dynamic_templates", - "enabled", - ) -) - - -class Properties(DslBase): - name = "properties" - _param_defs = {"properties": {"type": "field", "hash": True}} - - def __init__(self): - super().__init__() - - def __repr__(self): - return "Properties()" - - def __getitem__(self, name): - return self.properties[name] - - def __contains__(self, name): - return name in self.properties - - def to_dict(self): - return super().to_dict()["properties"] - - def field(self, name, *args, **kwargs): - self.properties[name] = construct_field(*args, **kwargs) - return self - - def _collect_fields(self): - """Iterate over all Field objects within, including multi fields.""" - for f in self.properties.to_dict().values(): - yield f - # multi fields - if hasattr(f, "fields"): - yield from f.fields.to_dict().values() - # nested and inner objects - if hasattr(f, "_collect_fields"): - yield from f._collect_fields() - - def update(self, other_object): - if not hasattr(other_object, "properties"): - # not an inner/nested object, no merge possible - return - - our, other = self.properties, other_object.properties - for name in other: - if name in our: - if hasattr(our[name], "update"): - our[name].update(other[name]) - continue - our[name] = other[name] - - -class Mapping: - def __init__(self): - self.properties = Properties() - self._meta = {} - - def __repr__(self): - return "Mapping()" - - def _clone(self): - m = Mapping() - m.properties._params = self.properties._params.copy() - return m - - @classmethod - def from_es(cls, index, using="default"): - m = cls() - m.update_from_es(index, using) - return m - - def resolve_nested(self, field_path): - field = self - nested = [] - parts = 
field_path.split(".") - for i, step in enumerate(parts): - try: - field = field[step] - except KeyError: - return (), None - if isinstance(field, Nested): - nested.append(".".join(parts[: i + 1])) - return nested, field - - def resolve_field(self, field_path): - field = self - for step in field_path.split("."): - try: - field = field[step] - except KeyError: - return - return field - - def _collect_analysis(self): - analysis = {} - fields = [] - if "_all" in self._meta: - fields.append(Text(**self._meta["_all"])) - - for f in chain(fields, self.properties._collect_fields()): - for analyzer_name in ( - "analyzer", - "normalizer", - "search_analyzer", - "search_quote_analyzer", - ): - if not hasattr(f, analyzer_name): - continue - analyzer = getattr(f, analyzer_name) - d = analyzer.get_analysis_definition() - # empty custom analyzer, probably already defined out of our control - if not d: - continue - - # merge the definition - # TODO: conflict detection/resolution - for key in d: - analysis.setdefault(key, {}).update(d[key]) - - return analysis - - def save(self, index, using="default"): - from .index import Index - - index = Index(index, using=using) - index.mapping(self) - return index.save() - - def update_from_es(self, index, using="default"): - es = get_connection(using) - raw = es.indices.get_mapping(index=index) - _, raw = raw.popitem() - self._update_from_dict(raw["mappings"]) - - def _update_from_dict(self, raw): - for name, definition in raw.get("properties", {}).items(): - self.field(name, definition) - - # metadata like _all etc - for name, value in raw.items(): - if name != "properties": - if isinstance(value, collections.abc.Mapping): - self.meta(name, **value) - else: - self.meta(name, value) - - def update(self, mapping, update_only=False): - for name in mapping: - if update_only and name in self: - # nested and inner objects, merge recursively - if hasattr(self[name], "update"): - # FIXME only merge subfields, not the settings - 
self[name].update(mapping[name], update_only) - continue - self.field(name, mapping[name]) - - if update_only: - for name in mapping._meta: - if name not in self._meta: - self._meta[name] = mapping._meta[name] - else: - self._meta.update(mapping._meta) - - def __contains__(self, name): - return name in self.properties.properties - - def __getitem__(self, name): - return self.properties.properties[name] - - def __iter__(self): - return iter(self.properties.properties) - - def field(self, *args, **kwargs): - self.properties.field(*args, **kwargs) - return self - - def meta(self, name, params=None, **kwargs): - if not name.startswith("_") and name not in META_FIELDS: - name = "_" + name - - if params and kwargs: - raise ValueError("Meta configs cannot have both value and a dictionary.") - - self._meta[name] = kwargs if params is None else params - return self - - def to_dict(self): - meta = self._meta - - # hard coded serialization of analyzers in _all - if "_all" in meta: - meta = meta.copy() - _all = meta["_all"] = meta["_all"].copy() - for f in ("analyzer", "search_analyzer", "search_quote_analyzer"): - if hasattr(_all.get(f, None), "to_dict"): - _all[f] = _all[f].to_dict() - meta.update(self.properties.to_dict()) - return meta +from elasticsearch_dsl._async.mapping import AsyncMapping # noqa: F401 +from elasticsearch_dsl._sync.mapping import Mapping # noqa: F401 diff --git a/elasticsearch_dsl/mapping_base.py b/elasticsearch_dsl/mapping_base.py new file mode 100644 index 000000000..a8e5ea41f --- /dev/null +++ b/elasticsearch_dsl/mapping_base.py @@ -0,0 +1,211 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import collections.abc +from itertools import chain + +from .field import Nested, Text, construct_field +from .utils import DslBase + +META_FIELDS = frozenset( + ( + "dynamic", + "transform", + "dynamic_date_formats", + "date_detection", + "numeric_detection", + "dynamic_templates", + "enabled", + ) +) + + +class Properties(DslBase): + name = "properties" + _param_defs = {"properties": {"type": "field", "hash": True}} + + def __init__(self): + super().__init__() + + def __repr__(self): + return "Properties()" + + def __getitem__(self, name): + return self.properties[name] + + def __contains__(self, name): + return name in self.properties + + def to_dict(self): + return super().to_dict()["properties"] + + def field(self, name, *args, **kwargs): + self.properties[name] = construct_field(*args, **kwargs) + return self + + def _collect_fields(self): + """Iterate over all Field objects within, including multi fields.""" + for f in self.properties.to_dict().values(): + yield f + # multi fields + if hasattr(f, "fields"): + yield from f.fields.to_dict().values() + # nested and inner objects + if hasattr(f, "_collect_fields"): + yield from f._collect_fields() + + def update(self, other_object): + if not hasattr(other_object, "properties"): + # not an inner/nested object, no merge possible + return + + our, other = self.properties, other_object.properties + for name in other: + if name in our: + if hasattr(our[name], "update"): + our[name].update(other[name]) + continue + our[name] = other[name] + + +class MappingBase: + def __init__(self): + self.properties = 
Properties() + self._meta = {} + + def __repr__(self): + return "Mapping()" + + def _clone(self): + m = self.__class__() + m.properties._params = self.properties._params.copy() + return m + + def resolve_nested(self, field_path): + field = self + nested = [] + parts = field_path.split(".") + for i, step in enumerate(parts): + try: + field = field[step] + except KeyError: + return (), None + if isinstance(field, Nested): + nested.append(".".join(parts[: i + 1])) + return nested, field + + def resolve_field(self, field_path): + field = self + for step in field_path.split("."): + try: + field = field[step] + except KeyError: + return + return field + + def _collect_analysis(self): + analysis = {} + fields = [] + if "_all" in self._meta: + fields.append(Text(**self._meta["_all"])) + + for f in chain(fields, self.properties._collect_fields()): + for analyzer_name in ( + "analyzer", + "normalizer", + "search_analyzer", + "search_quote_analyzer", + ): + if not hasattr(f, analyzer_name): + continue + analyzer = getattr(f, analyzer_name) + d = analyzer.get_analysis_definition() + # empty custom analyzer, probably already defined out of our control + if not d: + continue + + # merge the definition + # TODO: conflict detection/resolution + for key in d: + analysis.setdefault(key, {}).update(d[key]) + + return analysis + + def _update_from_dict(self, raw): + for name, definition in raw.get("properties", {}).items(): + self.field(name, definition) + + # metadata like _all etc + for name, value in raw.items(): + if name != "properties": + if isinstance(value, collections.abc.Mapping): + self.meta(name, **value) + else: + self.meta(name, value) + + def update(self, mapping, update_only=False): + for name in mapping: + if update_only and name in self: + # nested and inner objects, merge recursively + if hasattr(self[name], "update"): + # FIXME only merge subfields, not the settings + self[name].update(mapping[name], update_only) + continue + self.field(name, mapping[name]) + + if 
update_only: + for name in mapping._meta: + if name not in self._meta: + self._meta[name] = mapping._meta[name] + else: + self._meta.update(mapping._meta) + + def __contains__(self, name): + return name in self.properties.properties + + def __getitem__(self, name): + return self.properties.properties[name] + + def __iter__(self): + return iter(self.properties.properties) + + def field(self, *args, **kwargs): + self.properties.field(*args, **kwargs) + return self + + def meta(self, name, params=None, **kwargs): + if not name.startswith("_") and name not in META_FIELDS: + name = "_" + name + + if params and kwargs: + raise ValueError("Meta configs cannot have both value and a dictionary.") + + self._meta[name] = kwargs if params is None else params + return self + + def to_dict(self): + meta = self._meta + + # hard coded serialization of analyzers in _all + if "_all" in meta: + meta = meta.copy() + _all = meta["_all"] = meta["_all"].copy() + for f in ("analyzer", "search_analyzer", "search_quote_analyzer"): + if hasattr(_all.get(f, None), "to_dict"): + _all[f] = _all[f].to_dict() + meta.update(self.properties.to_dict()) + return meta diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index 4ffff1883..a94e569be 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -15,924 +15,6 @@ # specific language governing permissions and limitations # under the License. -import collections.abc -import copy - -from elasticsearch.exceptions import ApiError -from elasticsearch.helpers import scan - -from .aggs import A, AggBase -from .connections import get_connection -from .exceptions import IllegalOperation -from .query import Bool, Q, Query -from .response import Hit, Response -from .utils import AttrDict, DslBase, recursive_to_dict - - -class QueryProxy: - """ - Simple proxy around DSL objects (queries) that can be called - (to add query/post_filter) and also allows attribute access which is proxied to - the wrapped query. 
- """ - - def __init__(self, search, attr_name): - self._search = search - self._proxied = None - self._attr_name = attr_name - - def __nonzero__(self): - return self._proxied is not None - - __bool__ = __nonzero__ - - def __call__(self, *args, **kwargs): - s = self._search._clone() - - # we cannot use self._proxied since we just cloned self._search and - # need to access the new self on the clone - proxied = getattr(s, self._attr_name) - if proxied._proxied is None: - proxied._proxied = Q(*args, **kwargs) - else: - proxied._proxied &= Q(*args, **kwargs) - - # always return search to be chainable - return s - - def __getattr__(self, attr_name): - return getattr(self._proxied, attr_name) - - def __setattr__(self, attr_name, value): - if not attr_name.startswith("_"): - self._proxied = Q(self._proxied.to_dict()) - setattr(self._proxied, attr_name, value) - super().__setattr__(attr_name, value) - - def __getstate__(self): - return self._search, self._proxied, self._attr_name - - def __setstate__(self, state): - self._search, self._proxied, self._attr_name = state - - -class ProxyDescriptor: - """ - Simple descriptor to enable setting of queries and filters as: - - s = Search() - s.query = Q(...) 
- - """ - - def __init__(self, name): - self._attr_name = f"_{name}_proxy" - - def __get__(self, instance, owner): - return getattr(instance, self._attr_name) - - def __set__(self, instance, value): - proxy = getattr(instance, self._attr_name) - proxy._proxied = Q(value) - - -class AggsProxy(AggBase, DslBase): - name = "aggs" - - def __init__(self, search): - self._base = self - self._search = search - self._params = {"aggs": {}} - - def to_dict(self): - return super().to_dict().get("aggs", {}) - - -class Request: - def __init__(self, using="default", index=None, doc_type=None, extra=None): - self._using = using - - self._index = None - if isinstance(index, (tuple, list)): - self._index = list(index) - elif index: - self._index = [index] - - self._doc_type = [] - self._doc_type_map = {} - if isinstance(doc_type, (tuple, list)): - self._doc_type.extend(doc_type) - elif isinstance(doc_type, collections.abc.Mapping): - self._doc_type.extend(doc_type.keys()) - self._doc_type_map.update(doc_type) - elif doc_type: - self._doc_type.append(doc_type) - - self._params = {} - self._extra = extra or {} - - def __eq__(self, other): - return ( - isinstance(other, Request) - and other._params == self._params - and other._index == self._index - and other._doc_type == self._doc_type - and other.to_dict() == self.to_dict() - ) - - def __copy__(self): - return self._clone() - - def params(self, **kwargs): - """ - Specify query params to be used when executing the search. All the - keyword arguments will override the current values. See - https://elasticsearch-py.readthedocs.io/en/master/api.html#elasticsearch.Elasticsearch.search - for all available parameters. - - Example:: - - s = Search() - s = s.params(routing='user-1', preference='local') - """ - s = self._clone() - s._params.update(kwargs) - return s - - def index(self, *index): - """ - Set the index for the search. If called empty it will remove all information. 
- - Example: - - s = Search() - s = s.index('twitter-2015.01.01', 'twitter-2015.01.02') - s = s.index(['twitter-2015.01.01', 'twitter-2015.01.02']) - """ - # .index() resets - s = self._clone() - if not index: - s._index = None - else: - indexes = [] - for i in index: - if isinstance(i, str): - indexes.append(i) - elif isinstance(i, list): - indexes += i - elif isinstance(i, tuple): - indexes += list(i) - - s._index = (self._index or []) + indexes - - return s - - def _resolve_field(self, path): - for dt in self._doc_type: - if not hasattr(dt, "_index"): - continue - field = dt._index.resolve_field(path) - if field is not None: - return field - - def _resolve_nested(self, hit, parent_class=None): - doc_class = Hit - - nested_path = [] - nesting = hit["_nested"] - while nesting and "field" in nesting: - nested_path.append(nesting["field"]) - nesting = nesting.get("_nested") - nested_path = ".".join(nested_path) - - if hasattr(parent_class, "_index"): - nested_field = parent_class._index.resolve_field(nested_path) - else: - nested_field = self._resolve_field(nested_path) - - if nested_field is not None: - return nested_field._doc_class - - return doc_class - - def _get_result(self, hit, parent_class=None): - doc_class = Hit - dt = hit.get("_type") - - if "_nested" in hit: - doc_class = self._resolve_nested(hit, parent_class) - - elif dt in self._doc_type_map: - doc_class = self._doc_type_map[dt] - - else: - for doc_type in self._doc_type: - if hasattr(doc_type, "_matches") and doc_type._matches(hit): - doc_class = doc_type - break - - for t in hit.get("inner_hits", ()): - hit["inner_hits"][t] = Response( - self, hit["inner_hits"][t], doc_class=doc_class - ) - - callback = getattr(doc_class, "from_es", doc_class) - return callback(hit) - - def doc_type(self, *doc_type, **kwargs): - """ - Set the type to search through. You can supply a single value or - multiple. Values can be strings or subclasses of ``Document``. 
- - You can also pass in any keyword arguments, mapping a doc_type to a - callback that should be used instead of the Hit class. - - If no doc_type is supplied any information stored on the instance will - be erased. - - Example: - - s = Search().doc_type('product', 'store', User, custom=my_callback) - """ - # .doc_type() resets - s = self._clone() - if not doc_type and not kwargs: - s._doc_type = [] - s._doc_type_map = {} - else: - s._doc_type.extend(doc_type) - s._doc_type.extend(kwargs.keys()) - s._doc_type_map.update(kwargs) - return s - - def using(self, client): - """ - Associate the search request with an elasticsearch client. A fresh copy - will be returned with current instance remaining unchanged. - - :arg client: an instance of ``elasticsearch.Elasticsearch`` to use or - an alias to look up in ``elasticsearch_dsl.connections`` - - """ - s = self._clone() - s._using = client - return s - - def extra(self, **kwargs): - """ - Add extra keys to the request body. Mostly here for backwards - compatibility. - """ - s = self._clone() - if "from_" in kwargs: - kwargs["from"] = kwargs.pop("from_") - s._extra.update(kwargs) - return s - - def _clone(self): - s = self.__class__( - using=self._using, index=self._index, doc_type=self._doc_type - ) - s._doc_type_map = self._doc_type_map.copy() - s._extra = self._extra.copy() - s._params = self._params.copy() - return s - - -class Search(Request): - query = ProxyDescriptor("query") - post_filter = ProxyDescriptor("post_filter") - - def __init__(self, **kwargs): - """ - Search request to elasticsearch. - - :arg using: `Elasticsearch` instance to use - :arg index: limit the search to index - :arg doc_type: only query this type. - - All the parameters supplied (or omitted) at creation type can be later - overridden by methods (`using`, `index` and `doc_type` respectively). 
- """ - super().__init__(**kwargs) - - self.aggs = AggsProxy(self) - self._sort = [] - self._knn = [] - self._rank = {} - self._collapse = {} - self._source = None - self._highlight = {} - self._highlight_opts = {} - self._suggest = {} - self._script_fields = {} - self._response_class = Response - - self._query_proxy = QueryProxy(self, "query") - self._post_filter_proxy = QueryProxy(self, "post_filter") - - def filter(self, *args, **kwargs): - return self.query(Bool(filter=[Q(*args, **kwargs)])) - - def exclude(self, *args, **kwargs): - return self.query(Bool(filter=[~Q(*args, **kwargs)])) - - def __iter__(self): - """ - Iterate over the hits. - """ - return iter(self.execute()) - - def __getitem__(self, n): - """ - Support slicing the `Search` instance for pagination. - - Slicing equates to the from/size parameters. E.g.:: - - s = Search().query(...)[0:25] - - is equivalent to:: - - s = Search().query(...).extra(from_=0, size=25) - - """ - s = self._clone() - - if isinstance(n, slice): - # If negative slicing, abort. - if n.start and n.start < 0 or n.stop and n.stop < 0: - raise ValueError("Search does not support negative slicing.") - # Elasticsearch won't get all results so we default to size: 10 if - # stop not given. - s._extra["from"] = n.start or 0 - s._extra["size"] = max( - 0, n.stop - (n.start or 0) if n.stop is not None else 10 - ) - return s - else: # This is an index lookup, equivalent to slicing by [n:n+1]. - # If negative index, abort. - if n < 0: - raise ValueError("Search does not support negative indexing.") - s._extra["from"] = n - s._extra["size"] = 1 - return s - - @classmethod - def from_dict(cls, d): - """ - Construct a new `Search` instance from a raw dict containing the search - body. Useful when migrating from raw dictionaries. - - Example:: - - s = Search.from_dict({ - "query": { - "bool": { - "must": [...] 
- } - }, - "aggs": {...} - }) - s = s.filter('term', published=True) - """ - s = cls() - s.update_from_dict(d) - return s - - def _clone(self): - """ - Return a clone of the current search request. Performs a shallow copy - of all the underlying objects. Used internally by most state modifying - APIs. - """ - s = super()._clone() - - s._response_class = self._response_class - s._knn = [knn.copy() for knn in self._knn] - s._rank = self._rank.copy() - s._collapse = self._collapse.copy() - s._sort = self._sort[:] - s._source = copy.copy(self._source) if self._source is not None else None - s._highlight = self._highlight.copy() - s._highlight_opts = self._highlight_opts.copy() - s._suggest = self._suggest.copy() - s._script_fields = self._script_fields.copy() - for x in ("query", "post_filter"): - getattr(s, x)._proxied = getattr(self, x)._proxied - - # copy top-level bucket definitions - if self.aggs._params.get("aggs"): - s.aggs._params = {"aggs": self.aggs._params["aggs"].copy()} - return s - - def response_class(self, cls): - """ - Override the default wrapper used for the response. - """ - s = self._clone() - s._response_class = cls - return s - - def update_from_dict(self, d): - """ - Apply options from a serialized body to the current instance. Modifies - the object in-place. Used mostly by ``from_dict``. 
- """ - d = d.copy() - if "query" in d: - self.query._proxied = Q(d.pop("query")) - if "post_filter" in d: - self.post_filter._proxied = Q(d.pop("post_filter")) - - aggs = d.pop("aggs", d.pop("aggregations", {})) - if aggs: - self.aggs._params = { - "aggs": {name: A(value) for (name, value) in aggs.items()} - } - if "knn" in d: - self._knn = d.pop("knn") - if isinstance(self._knn, dict): - self._knn = [self._knn] - if "rank" in d: - self._rank = d.pop("rank") - if "collapse" in d: - self._collapse = d.pop("collapse") - if "sort" in d: - self._sort = d.pop("sort") - if "_source" in d: - self._source = d.pop("_source") - if "highlight" in d: - high = d.pop("highlight").copy() - self._highlight = high.pop("fields") - self._highlight_opts = high - if "suggest" in d: - self._suggest = d.pop("suggest") - if "text" in self._suggest: - text = self._suggest.pop("text") - for s in self._suggest.values(): - s.setdefault("text", text) - if "script_fields" in d: - self._script_fields = d.pop("script_fields") - self._extra.update(d) - return self - - def script_fields(self, **kwargs): - """ - Define script fields to be calculated on hits. See - https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-script-fields.html - for more details. - - Example:: - - s = Search() - s = s.script_fields(times_two="doc['field'].value * 2") - s = s.script_fields( - times_three={ - 'script': { - 'lang': 'painless', - 'source': "doc['field'].value * params.n", - 'params': {'n': 3} - } - } - ) - - """ - s = self._clone() - for name in kwargs: - if isinstance(kwargs[name], str): - kwargs[name] = {"script": kwargs[name]} - s._script_fields.update(kwargs) - return s - - def knn( - self, - field, - k, - num_candidates, - query_vector=None, - query_vector_builder=None, - boost=None, - filter=None, - similarity=None, - ): - """ - Add a k-nearest neighbor (kNN) search. 
- - :arg field: the name of the vector field to search against - :arg k: number of nearest neighbors to return as top hits - :arg num_candidates: number of nearest neighbor candidates to consider per shard - :arg query_vector: the vector to search for - :arg query_vector_builder: A dictionary indicating how to build a query vector - :arg boost: A floating-point boost factor for kNN scores - :arg filter: query to filter the documents that can match - :arg similarity: the minimum similarity required for a document to be considered a match, as a float value - - Example:: - - s = Search() - s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector, - filter=Q('term', category='blog'))) - """ - s = self._clone() - s._knn.append( - { - "field": field, - "k": k, - "num_candidates": num_candidates, - } - ) - if query_vector is None and query_vector_builder is None: - raise ValueError("one of query_vector and query_vector_builder is required") - if query_vector is not None and query_vector_builder is not None: - raise ValueError( - "only one of query_vector and query_vector_builder must be given" - ) - if query_vector is not None: - s._knn[-1]["query_vector"] = query_vector - if query_vector_builder is not None: - s._knn[-1]["query_vector_builder"] = query_vector_builder - if boost is not None: - s._knn[-1]["boost"] = boost - if filter is not None: - if isinstance(filter, Query): - s._knn[-1]["filter"] = filter.to_dict() - else: - s._knn[-1]["filter"] = filter - if similarity is not None: - s._knn[-1]["similarity"] = similarity - return s - - def rank(self, rrf=None): - """ - Defines a method for combining and ranking results sets from a combination - of searches. Requires a minimum of 2 results sets. - - :arg rrf: Set to ``True`` or an options dictionary to set the rank method to reciprocal rank fusion (RRF). 
- - Example:: - s = Search() - s = s.query('match', content='search text') - s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector) - s = s.rank(rrf=True) - - Note: This option is in technical preview and may change in the future. The syntax will likely change before GA. - """ - s = self._clone() - s._rank = {} - if rrf is not None and rrf is not False: - s._rank["rrf"] = {} if rrf is True else rrf - return s - - def source(self, fields=None, **kwargs): - """ - Selectively control how the _source field is returned. - - :arg fields: wildcard string, array of wildcards, or dictionary of includes and excludes - - If ``fields`` is None, the entire document will be returned for - each hit. If fields is a dictionary with keys of 'includes' and/or - 'excludes' the fields will be either included or excluded appropriately. - - Calling this multiple times with the same named parameter will override the - previous values with the new ones. - - Example:: - - s = Search() - s = s.source(includes=['obj1.*'], excludes=["*.description"]) - - s = Search() - s = s.source(includes=['obj1.*']).source(excludes=["*.description"]) - - """ - s = self._clone() - - if fields and kwargs: - raise ValueError("You cannot specify fields and kwargs at the same time.") - - if fields is not None: - s._source = fields - return s - - if kwargs and not isinstance(s._source, dict): - s._source = {} - - for key, value in kwargs.items(): - if value is None: - try: - del s._source[key] - except KeyError: - pass - else: - s._source[key] = value - - return s - - def sort(self, *keys): - """ - Add sorting information to the search request. If called without - arguments it will remove all sort requirements. Otherwise it will - replace them. 
Acceptable arguments are:: - - 'some.field' - '-some.other.field' - {'different.field': {'any': 'dict'}} - - so for example:: - - s = Search().sort( - 'category', - '-title', - {"price" : {"order" : "asc", "mode" : "avg"}} - ) - - will sort by ``category``, ``title`` (in descending order) and - ``price`` in ascending order using the ``avg`` mode. - - The API returns a copy of the Search object and can thus be chained. - """ - s = self._clone() - s._sort = [] - for k in keys: - if isinstance(k, str) and k.startswith("-"): - if k[1:] == "_score": - raise IllegalOperation("Sorting by `-_score` is not allowed.") - k = {k[1:]: {"order": "desc"}} - s._sort.append(k) - return s - - def collapse(self, field=None, inner_hits=None, max_concurrent_group_searches=None): - """ - Add collapsing information to the search request. - If called without providing ``field``, it will remove all collapse - requirements, otherwise it will replace them with the provided - arguments. - The API returns a copy of the Search object and can thus be chained. - """ - s = self._clone() - s._collapse = {} - - if field is None: - return s - - s._collapse["field"] = field - if inner_hits: - s._collapse["inner_hits"] = inner_hits - if max_concurrent_group_searches: - s._collapse["max_concurrent_group_searches"] = max_concurrent_group_searches - return s - - def highlight_options(self, **kwargs): - """ - Update the global highlighting options used for this request. For - example:: - - s = Search() - s = s.highlight_options(order='score') - """ - s = self._clone() - s._highlight_opts.update(kwargs) - return s - - def highlight(self, *fields, **kwargs): - """ - Request highlighting of some fields. All keyword arguments passed in will be - used as parameters for all the fields in the ``fields`` parameter. 
Example:: - - Search().highlight('title', 'body', fragment_size=50) - - will produce the equivalent of:: - - { - "highlight": { - "fields": { - "body": {"fragment_size": 50}, - "title": {"fragment_size": 50} - } - } - } - - If you want to have different options for different fields - you can call ``highlight`` twice:: - - Search().highlight('title', fragment_size=50).highlight('body', fragment_size=100) - - which will produce:: - - { - "highlight": { - "fields": { - "body": {"fragment_size": 100}, - "title": {"fragment_size": 50} - } - } - } - - """ - s = self._clone() - for f in fields: - s._highlight[f] = kwargs - return s - - def suggest(self, name, text, **kwargs): - """ - Add a suggestions request to the search. - - :arg name: name of the suggestion - :arg text: text to suggest on - - All keyword arguments will be added to the suggestions body. For example:: - - s = Search() - s = s.suggest('suggestion-1', 'Elasticsearch', term={'field': 'body'}) - """ - s = self._clone() - s._suggest[name] = {"text": text} - s._suggest[name].update(kwargs) - return s - - def to_dict(self, count=False, **kwargs): - """ - Serialize the search into the dictionary that will be sent over as the - request's body. - - :arg count: a flag to specify if we are interested in a body for count - - no aggregations, no pagination bounds etc. - - All additional keyword arguments will be included into the dictionary. 
- """ - d = {} - - if self.query: - d["query"] = self.query.to_dict() - - if self._knn: - if len(self._knn) == 1: - d["knn"] = self._knn[0] - else: - d["knn"] = self._knn - - if self._rank: - d["rank"] = self._rank - - # count request doesn't care for sorting and other things - if not count: - if self.post_filter: - d["post_filter"] = self.post_filter.to_dict() - - if self.aggs.aggs: - d.update(self.aggs.to_dict()) - - if self._sort: - d["sort"] = self._sort - - if self._collapse: - d["collapse"] = self._collapse - - d.update(recursive_to_dict(self._extra)) - - if self._source not in (None, {}): - d["_source"] = self._source - - if self._highlight: - d["highlight"] = {"fields": self._highlight} - d["highlight"].update(self._highlight_opts) - - if self._suggest: - d["suggest"] = self._suggest - - if self._script_fields: - d["script_fields"] = self._script_fields - - d.update(recursive_to_dict(kwargs)) - return d - - def count(self): - """ - Return the number of hits matching the query and filters. Note that - only the actual number is returned. - """ - if hasattr(self, "_response") and self._response.hits.total.relation == "eq": - return self._response.hits.total.value - - es = get_connection(self._using) - - d = self.to_dict(count=True) - # TODO: failed shards detection - resp = es.count(index=self._index, query=d.get("query", None), **self._params) - return resp["count"] - - def execute(self, ignore_cache=False): - """ - Execute the search and return an instance of ``Response`` wrapping all - the data. - - :arg ignore_cache: if set to ``True``, consecutive calls will hit - ES, while cached result will be ignored. 
Defaults to `False` - """ - if ignore_cache or not hasattr(self, "_response"): - es = get_connection(self._using) - - self._response = self._response_class( - self, - es.search(index=self._index, body=self.to_dict(), **self._params).body, - ) - return self._response - - def scan(self): - """ - Turn the search into a scan search and return a generator that will - iterate over all the documents matching the query. - - Use ``params`` method to specify any additional arguments you with to - pass to the underlying ``scan`` helper from ``elasticsearch-py`` - - https://elasticsearch-py.readthedocs.io/en/master/helpers.html#elasticsearch.helpers.scan - - """ - es = get_connection(self._using) - - for hit in scan(es, query=self.to_dict(), index=self._index, **self._params): - yield self._get_result(hit) - - def delete(self): - """ - delete() executes the query by delegating to delete_by_query() - """ - - es = get_connection(self._using) - - return AttrDict( - es.delete_by_query(index=self._index, body=self.to_dict(), **self._params) - ) - - -class MultiSearch(Request): - """ - Combine multiple :class:`~elasticsearch_dsl.Search` objects into a single - request. 
- """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._searches = [] - - def __getitem__(self, key): - return self._searches[key] - - def __iter__(self): - return iter(self._searches) - - def _clone(self): - ms = super()._clone() - ms._searches = self._searches[:] - return ms - - def add(self, search): - """ - Adds a new :class:`~elasticsearch_dsl.Search` object to the request:: - - ms = MultiSearch(index='my-index') - ms = ms.add(Search(doc_type=Category).filter('term', category='python')) - ms = ms.add(Search(doc_type=Blog)) - """ - ms = self._clone() - ms._searches.append(search) - return ms - - def to_dict(self): - out = [] - for s in self._searches: - meta = {} - if s._index: - meta["index"] = s._index - meta.update(s._params) - - out.append(meta) - out.append(s.to_dict()) - - return out - - def execute(self, ignore_cache=False, raise_on_error=True): - """ - Execute the multi search request and return a list of search results. - """ - if ignore_cache or not hasattr(self, "_response"): - es = get_connection(self._using) - - responses = es.msearch( - index=self._index, body=self.to_dict(), **self._params - ) - - out = [] - for s, r in zip(self._searches, responses["responses"]): - if r.get("error", False): - if raise_on_error: - raise ApiError("N/A", meta=responses.meta, body=r) - r = None - else: - r = Response(s, r) - out.append(r) - - self._response = out - - return self._response +from elasticsearch_dsl._async.search import AsyncMultiSearch, AsyncSearch # noqa: F401 +from elasticsearch_dsl._sync.search import MultiSearch, Search # noqa: F401 +from elasticsearch_dsl.search_base import Q # noqa: F401 diff --git a/elasticsearch_dsl/search_base.py b/elasticsearch_dsl/search_base.py new file mode 100644 index 000000000..3241baf80 --- /dev/null +++ b/elasticsearch_dsl/search_base.py @@ -0,0 +1,846 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import collections.abc +import copy + +from .aggs import A, AggBase +from .exceptions import IllegalOperation +from .query import Bool, Q, Query +from .response import Hit, Response +from .utils import DslBase, recursive_to_dict + + +class QueryProxy: + """ + Simple proxy around DSL objects (queries) that can be called + (to add query/post_filter) and also allows attribute access which is proxied to + the wrapped query. 
+ """ + + def __init__(self, search, attr_name): + self._search = search + self._proxied = None + self._attr_name = attr_name + + def __nonzero__(self): + return self._proxied is not None + + __bool__ = __nonzero__ + + def __call__(self, *args, **kwargs): + s = self._search._clone() + + # we cannot use self._proxied since we just cloned self._search and + # need to access the new self on the clone + proxied = getattr(s, self._attr_name) + if proxied._proxied is None: + proxied._proxied = Q(*args, **kwargs) + else: + proxied._proxied &= Q(*args, **kwargs) + + # always return search to be chainable + return s + + def __getattr__(self, attr_name): + return getattr(self._proxied, attr_name) + + def __setattr__(self, attr_name, value): + if not attr_name.startswith("_"): + self._proxied = Q(self._proxied.to_dict()) + setattr(self._proxied, attr_name, value) + super().__setattr__(attr_name, value) + + def __getstate__(self): + return self._search, self._proxied, self._attr_name + + def __setstate__(self, state): + self._search, self._proxied, self._attr_name = state + + +class ProxyDescriptor: + """ + Simple descriptor to enable setting of queries and filters as: + + s = Search() + s.query = Q(...) 
+ + """ + + def __init__(self, name): + self._attr_name = f"_{name}_proxy" + + def __get__(self, instance, owner): + return getattr(instance, self._attr_name) + + def __set__(self, instance, value): + proxy = getattr(instance, self._attr_name) + proxy._proxied = Q(value) + + +class AggsProxy(AggBase, DslBase): + name = "aggs" + + def __init__(self, search): + self._base = self + self._search = search + self._params = {"aggs": {}} + + def to_dict(self): + return super().to_dict().get("aggs", {}) + + +class Request: + def __init__(self, using="default", index=None, doc_type=None, extra=None): + self._using = using + + self._index = None + if isinstance(index, (tuple, list)): + self._index = list(index) + elif index: + self._index = [index] + + self._doc_type = [] + self._doc_type_map = {} + if isinstance(doc_type, (tuple, list)): + self._doc_type.extend(doc_type) + elif isinstance(doc_type, collections.abc.Mapping): + self._doc_type.extend(doc_type.keys()) + self._doc_type_map.update(doc_type) + elif doc_type: + self._doc_type.append(doc_type) + + self._params = {} + self._extra = extra or {} + + def __eq__(self, other): + return ( + isinstance(other, Request) + and other._params == self._params + and other._index == self._index + and other._doc_type == self._doc_type + and other.to_dict() == self.to_dict() + ) + + def __copy__(self): + return self._clone() + + def params(self, **kwargs): + """ + Specify query params to be used when executing the search. All the + keyword arguments will override the current values. See + https://elasticsearch-py.readthedocs.io/en/master/api.html#elasticsearch.Elasticsearch.search + for all available parameters. + + Example:: + + s = Search() + s = s.params(routing='user-1', preference='local') + """ + s = self._clone() + s._params.update(kwargs) + return s + + def index(self, *index): + """ + Set the index for the search. If called empty it will remove all information. 
+ + Example:: + + s = Search() + s = s.index('twitter-2015.01.01', 'twitter-2015.01.02') + s = s.index(['twitter-2015.01.01', 'twitter-2015.01.02']) + """ + # .index() resets + s = self._clone() + if not index: + s._index = None + else: + indexes = [] + for i in index: + if isinstance(i, str): + indexes.append(i) + elif isinstance(i, list): + indexes += i + elif isinstance(i, tuple): + indexes += list(i) + + s._index = (self._index or []) + indexes + + return s + + def _resolve_field(self, path): + for dt in self._doc_type: + if not hasattr(dt, "_index"): + continue + field = dt._index.resolve_field(path) + if field is not None: + return field + + def _resolve_nested(self, hit, parent_class=None): + doc_class = Hit + + nested_path = [] + nesting = hit["_nested"] + while nesting and "field" in nesting: + nested_path.append(nesting["field"]) + nesting = nesting.get("_nested") + nested_path = ".".join(nested_path) + + if hasattr(parent_class, "_index"): + nested_field = parent_class._index.resolve_field(nested_path) + else: + nested_field = self._resolve_field(nested_path) + + if nested_field is not None: + return nested_field._doc_class + + return doc_class + + def _get_result(self, hit, parent_class=None): + doc_class = Hit + dt = hit.get("_type") + + if "_nested" in hit: + doc_class = self._resolve_nested(hit, parent_class) + + elif dt in self._doc_type_map: + doc_class = self._doc_type_map[dt] + + else: + for doc_type in self._doc_type: + if hasattr(doc_type, "_matches") and doc_type._matches(hit): + doc_class = doc_type + break + + for t in hit.get("inner_hits", ()): + hit["inner_hits"][t] = Response( + self, hit["inner_hits"][t], doc_class=doc_class + ) + + callback = getattr(doc_class, "from_es", doc_class) + return callback(hit) + + def doc_type(self, *doc_type, **kwargs): + """ + Set the type to search through. You can supply a single value or + multiple. Values can be strings or subclasses of ``Document``. 
+ + You can also pass in any keyword arguments, mapping a doc_type to a + callback that should be used instead of the Hit class. + + If no doc_type is supplied any information stored on the instance will + be erased. + + Example: + + s = Search().doc_type('product', 'store', User, custom=my_callback) + """ + # .doc_type() resets + s = self._clone() + if not doc_type and not kwargs: + s._doc_type = [] + s._doc_type_map = {} + else: + s._doc_type.extend(doc_type) + s._doc_type.extend(kwargs.keys()) + s._doc_type_map.update(kwargs) + return s + + def using(self, client): + """ + Associate the search request with an elasticsearch client. A fresh copy + will be returned with current instance remaining unchanged. + + :arg client: an instance of ``elasticsearch.Elasticsearch`` to use or + an alias to look up in ``elasticsearch_dsl.connections`` + + """ + s = self._clone() + s._using = client + return s + + def extra(self, **kwargs): + """ + Add extra keys to the request body. Mostly here for backwards + compatibility. + """ + s = self._clone() + if "from_" in kwargs: + kwargs["from"] = kwargs.pop("from_") + s._extra.update(kwargs) + return s + + def _clone(self): + s = self.__class__( + using=self._using, index=self._index, doc_type=self._doc_type + ) + s._doc_type_map = self._doc_type_map.copy() + s._extra = self._extra.copy() + s._params = self._params.copy() + return s + + +class SearchBase(Request): + query = ProxyDescriptor("query") + post_filter = ProxyDescriptor("post_filter") + + def __init__(self, **kwargs): + """ + Search request to elasticsearch. + + :arg using: `Elasticsearch` instance to use + :arg index: limit the search to index + :arg doc_type: only query this type. + + All the parameters supplied (or omitted) at creation type can be later + overridden by methods (`using`, `index` and `doc_type` respectively). 
+ """ + super().__init__(**kwargs) + + self.aggs = AggsProxy(self) + self._sort = [] + self._knn = [] + self._rank = {} + self._collapse = {} + self._source = None + self._highlight = {} + self._highlight_opts = {} + self._suggest = {} + self._script_fields = {} + self._response_class = Response + + self._query_proxy = QueryProxy(self, "query") + self._post_filter_proxy = QueryProxy(self, "post_filter") + + def filter(self, *args, **kwargs): + return self.query(Bool(filter=[Q(*args, **kwargs)])) + + def exclude(self, *args, **kwargs): + return self.query(Bool(filter=[~Q(*args, **kwargs)])) + + def __getitem__(self, n): + """ + Support slicing the `Search` instance for pagination. + + Slicing equates to the from/size parameters. E.g.:: + + s = Search().query(...)[0:25] + + is equivalent to:: + + s = Search().query(...).extra(from_=0, size=25) + + """ + s = self._clone() + + if isinstance(n, slice): + # If negative slicing, abort. + if n.start and n.start < 0 or n.stop and n.stop < 0: + raise ValueError("Search does not support negative slicing.") + # Elasticsearch won't get all results so we default to size: 10 if + # stop not given. + s._extra["from"] = n.start or 0 + s._extra["size"] = max( + 0, n.stop - (n.start or 0) if n.stop is not None else 10 + ) + return s + else: # This is an index lookup, equivalent to slicing by [n:n+1]. + # If negative index, abort. + if n < 0: + raise ValueError("Search does not support negative indexing.") + s._extra["from"] = n + s._extra["size"] = 1 + return s + + @classmethod + def from_dict(cls, d): + """ + Construct a new `Search` instance from a raw dict containing the search + body. Useful when migrating from raw dictionaries. + + Example:: + + s = Search.from_dict({ + "query": { + "bool": { + "must": [...] + } + }, + "aggs": {...} + }) + s = s.filter('term', published=True) + """ + s = cls() + s.update_from_dict(d) + return s + + def _clone(self): + """ + Return a clone of the current search request. 
Performs a shallow copy + of all the underlying objects. Used internally by most state modifying + APIs. + """ + s = super()._clone() + + s._response_class = self._response_class + s._knn = [knn.copy() for knn in self._knn] + s._rank = self._rank.copy() + s._collapse = self._collapse.copy() + s._sort = self._sort[:] + s._source = copy.copy(self._source) if self._source is not None else None + s._highlight = self._highlight.copy() + s._highlight_opts = self._highlight_opts.copy() + s._suggest = self._suggest.copy() + s._script_fields = self._script_fields.copy() + for x in ("query", "post_filter"): + getattr(s, x)._proxied = getattr(self, x)._proxied + + # copy top-level bucket definitions + if self.aggs._params.get("aggs"): + s.aggs._params = {"aggs": self.aggs._params["aggs"].copy()} + return s + + def response_class(self, cls): + """ + Override the default wrapper used for the response. + """ + s = self._clone() + s._response_class = cls + return s + + def update_from_dict(self, d): + """ + Apply options from a serialized body to the current instance. Modifies + the object in-place. Used mostly by ``from_dict``. 
+ """ + d = d.copy() + if "query" in d: + self.query._proxied = Q(d.pop("query")) + if "post_filter" in d: + self.post_filter._proxied = Q(d.pop("post_filter")) + + aggs = d.pop("aggs", d.pop("aggregations", {})) + if aggs: + self.aggs._params = { + "aggs": {name: A(value) for (name, value) in aggs.items()} + } + if "knn" in d: + self._knn = d.pop("knn") + if isinstance(self._knn, dict): + self._knn = [self._knn] + if "rank" in d: + self._rank = d.pop("rank") + if "collapse" in d: + self._collapse = d.pop("collapse") + if "sort" in d: + self._sort = d.pop("sort") + if "_source" in d: + self._source = d.pop("_source") + if "highlight" in d: + high = d.pop("highlight").copy() + self._highlight = high.pop("fields") + self._highlight_opts = high + if "suggest" in d: + self._suggest = d.pop("suggest") + if "text" in self._suggest: + text = self._suggest.pop("text") + for s in self._suggest.values(): + s.setdefault("text", text) + if "script_fields" in d: + self._script_fields = d.pop("script_fields") + self._extra.update(d) + return self + + def script_fields(self, **kwargs): + """ + Define script fields to be calculated on hits. See + https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-script-fields.html + for more details. + + Example:: + + s = Search() + s = s.script_fields(times_two="doc['field'].value * 2") + s = s.script_fields( + times_three={ + 'script': { + 'lang': 'painless', + 'source': "doc['field'].value * params.n", + 'params': {'n': 3} + } + } + ) + + """ + s = self._clone() + for name in kwargs: + if isinstance(kwargs[name], str): + kwargs[name] = {"script": kwargs[name]} + s._script_fields.update(kwargs) + return s + + def knn( + self, + field, + k, + num_candidates, + query_vector=None, + query_vector_builder=None, + boost=None, + filter=None, + similarity=None, + ): + """ + Add a k-nearest neighbor (kNN) search. 
+ + :arg field: the name of the vector field to search against + :arg k: number of nearest neighbors to return as top hits + :arg num_candidates: number of nearest neighbor candidates to consider per shard + :arg query_vector: the vector to search for + :arg query_vector_builder: A dictionary indicating how to build a query vector + :arg boost: A floating-point boost factor for kNN scores + :arg filter: query to filter the documents that can match + :arg similarity: the minimum similarity required for a document to be considered a match, as a float value + + Example:: + + s = Search() + s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector, + filter=Q('term', category='blog'))) + """ + s = self._clone() + s._knn.append( + { + "field": field, + "k": k, + "num_candidates": num_candidates, + } + ) + if query_vector is None and query_vector_builder is None: + raise ValueError("one of query_vector and query_vector_builder is required") + if query_vector is not None and query_vector_builder is not None: + raise ValueError( + "only one of query_vector and query_vector_builder must be given" + ) + if query_vector is not None: + s._knn[-1]["query_vector"] = query_vector + if query_vector_builder is not None: + s._knn[-1]["query_vector_builder"] = query_vector_builder + if boost is not None: + s._knn[-1]["boost"] = boost + if filter is not None: + if isinstance(filter, Query): + s._knn[-1]["filter"] = filter.to_dict() + else: + s._knn[-1]["filter"] = filter + if similarity is not None: + s._knn[-1]["similarity"] = similarity + return s + + def rank(self, rrf=None): + """ + Defines a method for combining and ranking results sets from a combination + of searches. Requires a minimum of 2 results sets. + + :arg rrf: Set to ``True`` or an options dictionary to set the rank method to reciprocal rank fusion (RRF). 
+ + Example:: + + s = Search() + s = s.query('match', content='search text') + s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector) + s = s.rank(rrf=True) + + Note: This option is in technical preview and may change in the future. The syntax will likely change before GA. + """ + s = self._clone() + s._rank = {} + if rrf is not None and rrf is not False: + s._rank["rrf"] = {} if rrf is True else rrf + return s + + def source(self, fields=None, **kwargs): + """ + Selectively control how the _source field is returned. + + :arg fields: wildcard string, array of wildcards, or dictionary of includes and excludes + + If ``fields`` is None, the entire document will be returned for + each hit. If fields is a dictionary with keys of 'includes' and/or + 'excludes' the fields will be either included or excluded appropriately. + + Calling this multiple times with the same named parameter will override the + previous values with the new ones. + + Example:: + + s = Search() + s = s.source(includes=['obj1.*'], excludes=["*.description"]) + + s = Search() + s = s.source(includes=['obj1.*']).source(excludes=["*.description"]) + + """ + s = self._clone() + + if fields and kwargs: + raise ValueError("You cannot specify fields and kwargs at the same time.") + + if fields is not None: + s._source = fields + return s + + if kwargs and not isinstance(s._source, dict): + s._source = {} + + for key, value in kwargs.items(): + if value is None: + try: + del s._source[key] + except KeyError: + pass + else: + s._source[key] = value + + return s + + def sort(self, *keys): + """ + Add sorting information to the search request. If called without + arguments it will remove all sort requirements. Otherwise it will + replace them. 
Acceptable arguments are:: + + 'some.field' + '-some.other.field' + {'different.field': {'any': 'dict'}} + + so for example:: + + s = Search().sort( + 'category', + '-title', + {"price" : {"order" : "asc", "mode" : "avg"}} + ) + + will sort by ``category``, ``title`` (in descending order) and + ``price`` in ascending order using the ``avg`` mode. + + The API returns a copy of the Search object and can thus be chained. + """ + s = self._clone() + s._sort = [] + for k in keys: + if isinstance(k, str) and k.startswith("-"): + if k[1:] == "_score": + raise IllegalOperation("Sorting by `-_score` is not allowed.") + k = {k[1:]: {"order": "desc"}} + s._sort.append(k) + return s + + def collapse(self, field=None, inner_hits=None, max_concurrent_group_searches=None): + """ + Add collapsing information to the search request. + If called without providing ``field``, it will remove all collapse + requirements, otherwise it will replace them with the provided + arguments. + The API returns a copy of the Search object and can thus be chained. + """ + s = self._clone() + s._collapse = {} + + if field is None: + return s + + s._collapse["field"] = field + if inner_hits: + s._collapse["inner_hits"] = inner_hits + if max_concurrent_group_searches: + s._collapse["max_concurrent_group_searches"] = max_concurrent_group_searches + return s + + def highlight_options(self, **kwargs): + """ + Update the global highlighting options used for this request. For + example:: + + s = Search() + s = s.highlight_options(order='score') + """ + s = self._clone() + s._highlight_opts.update(kwargs) + return s + + def highlight(self, *fields, **kwargs): + """ + Request highlighting of some fields. All keyword arguments passed in will be + used as parameters for all the fields in the ``fields`` parameter. 
Example:: + + Search().highlight('title', 'body', fragment_size=50) + + will produce the equivalent of:: + + { + "highlight": { + "fields": { + "body": {"fragment_size": 50}, + "title": {"fragment_size": 50} + } + } + } + + If you want to have different options for different fields + you can call ``highlight`` twice:: + + Search().highlight('title', fragment_size=50).highlight('body', fragment_size=100) + + which will produce:: + + { + "highlight": { + "fields": { + "body": {"fragment_size": 100}, + "title": {"fragment_size": 50} + } + } + } + + """ + s = self._clone() + for f in fields: + s._highlight[f] = kwargs + return s + + def suggest(self, name, text, **kwargs): + """ + Add a suggestions request to the search. + + :arg name: name of the suggestion + :arg text: text to suggest on + + All keyword arguments will be added to the suggestions body. For example:: + + s = Search() + s = s.suggest('suggestion-1', 'Elasticsearch', term={'field': 'body'}) + """ + s = self._clone() + s._suggest[name] = {"text": text} + s._suggest[name].update(kwargs) + return s + + def to_dict(self, count=False, **kwargs): + """ + Serialize the search into the dictionary that will be sent over as the + request's body. + + :arg count: a flag to specify if we are interested in a body for count - + no aggregations, no pagination bounds etc. + + All additional keyword arguments will be included into the dictionary. 
+ """ + d = {} + + if self.query: + d["query"] = self.query.to_dict() + + if self._knn: + if len(self._knn) == 1: + d["knn"] = self._knn[0] + else: + d["knn"] = self._knn + + if self._rank: + d["rank"] = self._rank + + # count request doesn't care for sorting and other things + if not count: + if self.post_filter: + d["post_filter"] = self.post_filter.to_dict() + + if self.aggs.aggs: + d.update(self.aggs.to_dict()) + + if self._sort: + d["sort"] = self._sort + + if self._collapse: + d["collapse"] = self._collapse + + d.update(recursive_to_dict(self._extra)) + + if self._source not in (None, {}): + d["_source"] = self._source + + if self._highlight: + d["highlight"] = {"fields": self._highlight} + d["highlight"].update(self._highlight_opts) + + if self._suggest: + d["suggest"] = self._suggest + + if self._script_fields: + d["script_fields"] = self._script_fields + + d.update(recursive_to_dict(kwargs)) + return d + + +class MultiSearchBase(Request): + """ + Combine multiple :class:`~elasticsearch_dsl.Search` objects into a single + request. 
+ """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._searches = [] + + def __getitem__(self, key): + return self._searches[key] + + def __iter__(self): + return iter(self._searches) + + def _clone(self): + ms = super()._clone() + ms._searches = self._searches[:] + return ms + + def add(self, search): + """ + Adds a new :class:`~elasticsearch_dsl.Search` object to the request:: + + ms = MultiSearch(index='my-index') + ms = ms.add(Search(doc_type=Category).filter('term', category='python')) + ms = ms.add(Search(doc_type=Blog)) + """ + ms = self._clone() + ms._searches.append(search) + return ms + + def to_dict(self): + out = [] + for s in self._searches: + meta = {} + if s._index: + meta["index"] = s._index + meta.update(s._params) + + out.append(meta) + out.append(s.to_dict()) + + return out diff --git a/elasticsearch_dsl/update_by_query.py b/elasticsearch_dsl/update_by_query.py index dd11856b1..fdff22bc8 100644 --- a/elasticsearch_dsl/update_by_query.py +++ b/elasticsearch_dsl/update_by_query.py @@ -15,145 +15,5 @@ # specific language governing permissions and limitations # under the License. -from .connections import get_connection -from .query import Bool, Q -from .response import UpdateByQueryResponse -from .search import ProxyDescriptor, QueryProxy, Request -from .utils import recursive_to_dict - - -class UpdateByQuery(Request): - query = ProxyDescriptor("query") - - def __init__(self, **kwargs): - """ - Update by query request to elasticsearch. - - :arg using: `Elasticsearch` instance to use - :arg index: limit the search to index - :arg doc_type: only query this type. - - All the parameters supplied (or omitted) at creation type can be later - overridden by methods (`using`, `index` and `doc_type` respectively). 
- - """ - super().__init__(**kwargs) - self._response_class = UpdateByQueryResponse - self._script = {} - self._query_proxy = QueryProxy(self, "query") - - def filter(self, *args, **kwargs): - return self.query(Bool(filter=[Q(*args, **kwargs)])) - - def exclude(self, *args, **kwargs): - return self.query(Bool(filter=[~Q(*args, **kwargs)])) - - @classmethod - def from_dict(cls, d): - """ - Construct a new `UpdateByQuery` instance from a raw dict containing the search - body. Useful when migrating from raw dictionaries. - - Example:: - - ubq = UpdateByQuery.from_dict({ - "query": { - "bool": { - "must": [...] - } - }, - "script": {...} - }) - ubq = ubq.filter('term', published=True) - """ - u = cls() - u.update_from_dict(d) - return u - - def _clone(self): - """ - Return a clone of the current search request. Performs a shallow copy - of all the underlying objects. Used internally by most state modifying - APIs. - """ - ubq = super()._clone() - - ubq._response_class = self._response_class - ubq._script = self._script.copy() - ubq.query._proxied = self.query._proxied - return ubq - - def response_class(self, cls): - """ - Override the default wrapper used for the response. - """ - ubq = self._clone() - ubq._response_class = cls - return ubq - - def update_from_dict(self, d): - """ - Apply options from a serialized body to the current instance. Modifies - the object in-place. Used mostly by ``from_dict``. - """ - d = d.copy() - if "query" in d: - self.query._proxied = Q(d.pop("query")) - if "script" in d: - self._script = d.pop("script") - self._extra.update(d) - return self - - def script(self, **kwargs): - """ - Define update action to take: - https://www.elastic.co/guide/en/elasticsearch/reference/current/modules-scripting-using.html - for more details. - - Note: the API only accepts a single script, so - calling the script multiple times will overwrite. 
- - Example:: - - ubq = Search() - ubq = ubq.script(source="ctx._source.likes++"") - ubq = ubq.script(source="ctx._source.likes += params.f"", - lang="expression", - params={'f': 3}) - """ - ubq = self._clone() - if ubq._script: - ubq._script = {} - ubq._script.update(kwargs) - return ubq - - def to_dict(self, **kwargs): - """ - Serialize the search into the dictionary that will be sent over as the - request'ubq body. - - All additional keyword arguments will be included into the dictionary. - """ - d = {} - if self.query: - d["query"] = self.query.to_dict() - - if self._script: - d["script"] = self._script - - d.update(recursive_to_dict(self._extra)) - d.update(recursive_to_dict(kwargs)) - return d - - def execute(self): - """ - Execute the search and return an instance of ``Response`` wrapping all - the data. - """ - es = get_connection(self._using) - - self._response = self._response_class( - self, - es.update_by_query(index=self._index, **self.to_dict(), **self._params), - ) - return self._response +from ._async.update_by_query import AsyncUpdateByQuery # noqa: F401 +from ._sync.update_by_query import UpdateByQuery # noqa: F401 diff --git a/elasticsearch_dsl/update_by_query_base.py b/elasticsearch_dsl/update_by_query_base.py new file mode 100644 index 000000000..204241ba7 --- /dev/null +++ b/elasticsearch_dsl/update_by_query_base.py @@ -0,0 +1,145 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .query import Bool, Q +from .response import UpdateByQueryResponse +from .search_base import ProxyDescriptor, QueryProxy, Request +from .utils import recursive_to_dict + + +class UpdateByQueryBase(Request): + query = ProxyDescriptor("query") + + def __init__(self, **kwargs): + """ + Update by query request to elasticsearch. + + :arg using: `Elasticsearch` instance to use + :arg index: limit the search to index + :arg doc_type: only query this type. + + All the parameters supplied (or omitted) at creation time can be later + overridden by methods (`using`, `index` and `doc_type` respectively). + + """ + super().__init__(**kwargs) + self._response_class = UpdateByQueryResponse + self._script = {} + self._query_proxy = QueryProxy(self, "query") + + def filter(self, *args, **kwargs): + return self.query(Bool(filter=[Q(*args, **kwargs)])) + + def exclude(self, *args, **kwargs): + return self.query(Bool(filter=[~Q(*args, **kwargs)])) + + @classmethod + def from_dict(cls, d): + """ + Construct a new `UpdateByQuery` instance from a raw dict containing the search + body. Useful when migrating from raw dictionaries. + + Example:: + + ubq = UpdateByQuery.from_dict({ + "query": { + "bool": { + "must": [...] + } + }, + "script": {...} + }) + ubq = ubq.filter('term', published=True) + """ + u = cls() + u.update_from_dict(d) + return u + + def _clone(self): + """ + Return a clone of the current search request. Performs a shallow copy + of all the underlying objects. Used internally by most state modifying + APIs.
+ """ + ubq = super()._clone() + + ubq._response_class = self._response_class + ubq._script = self._script.copy() + ubq.query._proxied = self.query._proxied + return ubq + + def response_class(self, cls): + """ + Override the default wrapper used for the response. + """ + ubq = self._clone() + ubq._response_class = cls + return ubq + + def update_from_dict(self, d): + """ + Apply options from a serialized body to the current instance. Modifies + the object in-place. Used mostly by ``from_dict``. + """ + d = d.copy() + if "query" in d: + self.query._proxied = Q(d.pop("query")) + if "script" in d: + self._script = d.pop("script") + self._extra.update(d) + return self + + def script(self, **kwargs): + """ + Define update action to take: + https://www.elastic.co/guide/en/elasticsearch/reference/current/modules-scripting-using.html + for more details. + + Note: the API only accepts a single script, so + calling the script multiple times will overwrite. + + Example:: + + ubq = UpdateByQuery() + ubq = ubq.script(source="ctx._source.likes++") + ubq = ubq.script(source="ctx._source.likes += params.f", + lang="expression", + params={'f': 3}) + """ + ubq = self._clone() + if ubq._script: + ubq._script = {} + ubq._script.update(kwargs) + return ubq + + def to_dict(self, **kwargs): + """ + Serialize the search into the dictionary that will be sent over as the + request's body. + + All additional keyword arguments will be included into the dictionary. + """ + d = {} + if self.query: + d["query"] = self.query.to_dict() + + if self._script: + d["script"] = self._script + + d.update(recursive_to_dict(self._extra)) + d.update(recursive_to_dict(kwargs)) + return d diff --git a/examples/alias_migration.py b/examples/alias_migration.py index c56627b9a..e407990c4 100644 --- a/examples/alias_migration.py +++ b/examples/alias_migration.py @@ -35,6 +35,7 @@ will have index set to the concrete index whereas the class refers to the alias.
""" +import os from datetime import datetime from fnmatch import fnmatch @@ -124,9 +125,9 @@ def migrate(move_data=True, update_alias=True): ) -if __name__ == "__main__": +def main(): # initiate the default connection to elasticsearch - connections.create_connection() + connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) # create the empty index setup() @@ -142,3 +143,10 @@ def migrate(move_data=True, update_alias=True): # create new index migrate() + + # close the connection + connections.get_connection().close() + + +if __name__ == "__main__": + main() diff --git a/examples/async/alias_migration.py b/examples/async/alias_migration.py new file mode 100644 index 000000000..07bb995a5 --- /dev/null +++ b/examples/async/alias_migration.py @@ -0,0 +1,153 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Simple example with a single Document demonstrating how schema can be managed, +including upgrading with reindexing. + +Key concepts: + + * setup() function to first initialize the schema (as index template) in + elasticsearch. Can be called any time (recommended with every deploy of + your app). 
+ + * migrate() function to be called any time when the schema changes - it + will create a new index (by incrementing the version) and update the alias. + By default it will also (before flipping the alias) move the data from the + previous index to the new one. + + * BlogPost._matches() class method is required for this code to work since + otherwise BlogPost will not be used to deserialize the documents as those + will have index set to the concrete index whereas the class refers to the + alias. +""" +import asyncio +import os +from datetime import datetime +from fnmatch import fnmatch + +from elasticsearch_dsl import AsyncDocument, Date, Keyword, Text, async_connections + +ALIAS = "test-blog" +PATTERN = ALIAS + "-*" + + +class BlogPost(AsyncDocument): + title = Text() + published = Date() + tags = Keyword(multi=True) + content = Text() + + def is_published(self): + return self.published and datetime.now() > self.published + + @classmethod + def _matches(cls, hit): + # override _matches to match indices in a pattern instead of just ALIAS + # hit is the raw dict as returned by elasticsearch + return fnmatch(hit["_index"], PATTERN) + + class Index: + # we will use an alias instead of the index + name = ALIAS + # set settings and possibly other attributes of the index like + # analyzers + settings = {"number_of_shards": 1, "number_of_replicas": 0} + + +async def setup(): + """ + Create the index template in elasticsearch specifying the mappings and any + settings to be used. This can be run at any time, ideally at every new code + deploy. 
+ """ + # create an index template + index_template = BlogPost._index.as_template(ALIAS, PATTERN) + # upload the template into elasticsearch + # potentially overriding the one already there + await index_template.save() + + # create the first index if it doesn't exist + if not await BlogPost._index.exists(): + await migrate(move_data=False) + + +async def migrate(move_data=True, update_alias=True): + """ + Upgrade function that creates a new index for the data. Optionally it also can + (and by default will) reindex previous copy of the data into the new index + (specify ``move_data=False`` to skip this step) and update the alias to + point to the latest index (set ``update_alias=False`` to skip). + + Note that while this function is running the application can still perform + any and all searches without any loss of functionality. It should, however, + not perform any writes at this time as those might be lost. + """ + # construct a new index name by appending current timestamp + next_index = PATTERN.replace("*", datetime.now().strftime("%Y%m%d%H%M%S%f")) + + # get the low level connection + es = async_connections.get_connection() + + # create new index, it will use the settings from the template + await es.indices.create(index=next_index) + + if move_data: + # move data from current alias to the new index + await es.options(request_timeout=3600).reindex( + body={"source": {"index": ALIAS}, "dest": {"index": next_index}} + ) + # refresh the index to make the changes visible + await es.indices.refresh(index=next_index) + + if update_alias: + # repoint the alias to point to the newly created index + await es.indices.update_aliases( + body={ + "actions": [ + {"remove": {"alias": ALIAS, "index": PATTERN}}, + {"add": {"alias": ALIAS, "index": next_index}}, + ] + } + ) + + +async def main(): + # initiate the default connection to elasticsearch + async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + # create the empty index + await setup() + + # 
create a new document + bp = BlogPost( + _id=0, + title="Hello World!", + tags=["testing", "dummy"], + content=open(__file__).read(), + ) + await bp.save(refresh=True) + + # create new index + await migrate() + + # close the connection + await async_connections.get_connection().close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/async/completion.py b/examples/async/completion.py new file mode 100644 index 000000000..96d60724e --- /dev/null +++ b/examples/async/completion.py @@ -0,0 +1,107 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example ``Document`` with completion suggester. + +In the ``Person`` class we index the person's name to allow auto completing in +any order ("first last", "middle last first", ...). For the weight we use a +value from the ``popularity`` field which is a long. + +To make the suggestions work in different languages we added a custom analyzer +that does ascii folding. 
+""" + +import asyncio +import os +from itertools import permutations + +from elasticsearch_dsl import ( + AsyncDocument, + Completion, + Keyword, + Long, + Text, + analyzer, + async_connections, + token_filter, +) + +# custom analyzer for names +ascii_fold = analyzer( + "ascii_fold", + # we don't want to split O'Brian or Toulouse-Lautrec + tokenizer="whitespace", + filter=["lowercase", token_filter("ascii_fold", "asciifolding")], +) + + +class Person(AsyncDocument): + name = Text(fields={"keyword": Keyword()}) + popularity = Long() + + # completion field with a custom analyzer + suggest = Completion(analyzer=ascii_fold) + + def clean(self): + """ + Automatically construct the suggestion input and weight by taking all + possible permutations of Person's name as ``input`` and taking their + popularity as ``weight``. + """ + self.suggest = { + "input": [" ".join(p) for p in permutations(self.name.split())], + "weight": self.popularity, + } + + class Index: + name = "test-suggest" + settings = {"number_of_shards": 1, "number_of_replicas": 0} + + +async def main(): + # initiate the default connection to elasticsearch + async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + # create the empty index + await Person.init() + + # index some sample data + for id, (name, popularity) in enumerate( + [("Henri de Toulouse-Lautrec", 42), ("Jára Cimrman", 124)] + ): + await Person(_id=id, name=name, popularity=popularity).save() + + # refresh index manually to make changes live + await Person._index.refresh() + + # run some suggestions + for text in ("já", "Jara Cimr", "tou", "de hen"): + s = Person.search() + s = s.suggest("auto_complete", text, completion={"field": "suggest"}) + response = await s.execute() + + # print out all the options we got + for option in response.suggest.auto_complete[0].options: + print("%10s: %25s (%d)" % (text, option._source.name, option._score)) + + # close the connection + await async_connections.get_connection().close() + +
+if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/async/composite_agg.py b/examples/async/composite_agg.py new file mode 100644 index 000000000..52726cb46 --- /dev/null +++ b/examples/async/composite_agg.py @@ -0,0 +1,68 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import os + +from elasticsearch_dsl import A, AsyncSearch, async_connections + + +async def scan_aggs(search, source_aggs, inner_aggs={}, size=10): + """ + Helper function used to iterate over all possible bucket combinations of + ``source_aggs``, returning results of ``inner_aggs`` for each. Uses the + ``composite`` aggregation under the hood to perform this. 
+ """ + + async def run_search(**kwargs): + s = search[:0] + s.aggs.bucket("comp", "composite", sources=source_aggs, size=size, **kwargs) + for agg_name, agg in inner_aggs.items(): + s.aggs["comp"][agg_name] = agg + return await s.execute() + + response = await run_search() + while response.aggregations.comp.buckets: + for b in response.aggregations.comp.buckets: + yield b + if "after_key" in response.aggregations.comp: + after = response.aggregations.comp.after_key + else: + after = response.aggregations.comp.buckets[-1].key + response = await run_search(after=after) + + +async def main(): + # initiate the default connection to elasticsearch + async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + async for b in scan_aggs( + AsyncSearch(index="git"), + {"files": A("terms", field="files")}, + {"first_seen": A("min", field="committed_date")}, + ): + print( + "File %s has been modified %d times, first seen at %s." + % (b.key.files, b.doc_count, b.first_seen.value_as_string) + ) + + # close the connection + await async_connections.get_connection().close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/async/parent_child.py b/examples/async/parent_child.py new file mode 100644 index 000000000..4ee61d4af --- /dev/null +++ b/examples/async/parent_child.py @@ -0,0 +1,255 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Complex data model example modeling stackoverflow-like data. + +It is used to showcase several key features of elasticsearch-dsl: + + * Object and Nested fields: see User and Comment classes and fields they + are used in + + * method add_comment is used to add comments + + * Parent/Child relationship + + * See the Join field on Post creating the relationship between Question + and Answer + + * Meta.matches allows the hits from same index to be wrapped in proper + classes + + * to see how child objects are created see Question.add_answer + + * Question.search_answers shows how to query for children of a + particular parent + +""" +import asyncio +import os +from datetime import datetime + +from elasticsearch_dsl import ( + AsyncDocument, + Boolean, + Date, + InnerDoc, + Join, + Keyword, + Long, + Nested, + Object, + Text, + async_connections, +) + + +class User(InnerDoc): + """ + Class used to represent a denormalized user stored on other objects. + """ + + id = Long(required=True) + signed_up = Date() + username = Text(fields={"keyword": Keyword()}, required=True) + email = Text(fields={"keyword": Keyword()}) + location = Text(fields={"keyword": Keyword()}) + + +class Comment(InnerDoc): + """ + Class wrapper for nested comment objects. + """ + + author = Object(User, required=True) + created = Date(required=True) + content = Text(required=True) + + +class Post(AsyncDocument): + """ + Base class for Question and Answer containing the common fields. 
+ """ + + author = Object(User, required=True) + created = Date(required=True) + body = Text(required=True) + comments = Nested(Comment) + question_answer = Join(relations={"question": "answer"}) + + @classmethod + def _matches(cls, hit): + # Post is an abstract class, make sure it never gets used for + # deserialization + return False + + class Index: + name = "test-qa-site" + settings = { + "number_of_shards": 1, + "number_of_replicas": 0, + } + + async def add_comment(self, user, content, created=None, commit=True): + c = Comment(author=user, content=content, created=created or datetime.now()) + self.comments.append(c) + if commit: + await self.save() + return c + + async def save(self, **kwargs): + # if there is no date, use now + if self.created is None: + self.created = datetime.now() + return await super().save(**kwargs) + + +class Question(Post): + # use multi True so that .tags will return empty list if not present + tags = Keyword(multi=True) + title = Text(fields={"keyword": Keyword()}) + + @classmethod + def _matches(cls, hit): + """Use Question class for parent documents""" + return hit["_source"]["question_answer"] == "question" + + @classmethod + def search(cls, **kwargs): + return cls._index.search(**kwargs).filter("term", question_answer="question") + + async def add_answer(self, user, body, created=None, accepted=False, commit=True): + answer = Answer( + # required make sure the answer is stored in the same shard + _routing=self.meta.id, + # since we don't have explicit index, ensure same index as self + _index=self.meta.index, + # set up the parent/child mapping + question_answer={"name": "answer", "parent": self.meta.id}, + # pass in the field values + author=user, + created=created, + body=body, + accepted=accepted, + ) + if commit: + await answer.save() + return answer + + def search_answers(self): + # search only our index + s = Answer.search() + # filter for answers belonging to us + s = s.filter("parent_id", type="answer", id=self.meta.id) 
+ # add routing to only go to specific shard + s = s.params(routing=self.meta.id) + return s + + async def get_answers(self): + """ + Get answers either from inner_hits already present or by searching + elasticsearch. + """ + if "inner_hits" in self.meta and "answer" in self.meta.inner_hits: + return self.meta.inner_hits.answer.hits + return [a async for a in self.search_answers()] + + async def save(self, **kwargs): + self.question_answer = "question" + return await super().save(**kwargs) + + +class Answer(Post): + is_accepted = Boolean() + + @classmethod + def _matches(cls, hit): + """Use Answer class for child documents with child name 'answer'""" + return ( + isinstance(hit["_source"]["question_answer"], dict) + and hit["_source"]["question_answer"].get("name") == "answer" + ) + + @classmethod + def search(cls, **kwargs): + return cls._index.search(**kwargs).exclude("term", question_answer="question") + + async def get_question(self): + # cache question in self.meta + # any attributes set on self would be interpreted as fields + if "question" not in self.meta: + self.meta.question = await Question.get( + id=self.question_answer.parent, index=self.meta.index + ) + return self.meta.question + + async def save(self, **kwargs): + # set routing to parent's id automatically + self.meta.routing = self.question_answer.parent + return await super().save(**kwargs) + + +async def setup(): + """Create an IndexTemplate and save it into elasticsearch.""" + index_template = Post._index.as_template("base") + await index_template.save() + + +async def main(): + # initiate the default connection to elasticsearch + async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + # create index + await setup() + + # user objects to use + nick = User( + id=47, + signed_up=datetime(2017, 4, 3), + username="fxdgear", + email="nick.lang@elastic.co", + location="Colorado", + ) + honza = User( + id=42, + signed_up=datetime(2013, 4, 3), + username="honzakral", +
email="honza@elastic.co", + location="Prague", + ) + + # create a question object + question = Question( + _id=1, + author=nick, + tags=["elasticsearch", "python"], + title="How do I use elasticsearch from Python?", + body=""" + I want to use elasticsearch, how do I do it from Python? + """, + ) + await question.save() + answer = await question.add_answer(honza, "Just use `elasticsearch-py`!") + + # close the connection + await async_connections.get_connection().close() + + return answer + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/async/percolate.py b/examples/async/percolate.py new file mode 100644 index 000000000..4b075cd8d --- /dev/null +++ b/examples/async/percolate.py @@ -0,0 +1,106 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import os + +from elasticsearch_dsl import ( + AsyncDocument, + AsyncSearch, + Keyword, + Percolator, + Q, + Text, + async_connections, +) + + +class BlogPost(AsyncDocument): + """ + Blog posts that will be automatically tagged based on percolation queries. + """ + + content = Text() + tags = Keyword(multi=True) + + class Index: + name = "test-blogpost" + + async def add_tags(self): + # run a percolation to automatically tag the blog post. 
+ s = AsyncSearch(index="test-percolator") + s = s.query( + "percolate", field="query", index=self._get_index(), document=self.to_dict() + ) + + # collect all the tags from matched percolators + async for percolator in s: + self.tags.extend(percolator.tags) + + # make sure tags are unique + self.tags = list(set(self.tags)) + + async def save(self, **kwargs): + await self.add_tags() + return await super().save(**kwargs) + + +class PercolatorDoc(AsyncDocument): + """ + Document class used for storing the percolation queries. + """ + + # relevant fields from BlogPost must be also present here for the queries + # to be able to use them. Another option would be to use document + # inheritance but save() would have to be reset to normal behavior. + content = Text() + + # the percolator query to be run against the doc + query = Percolator() + # list of tags to append to a document + tags = Keyword(multi=True) + + class Index: + name = "test-percolator" + settings = {"number_of_shards": 1, "number_of_replicas": 0} + + +async def setup(): + # create the percolator index if it doesn't exist + if not await PercolatorDoc._index.exists(): + await PercolatorDoc.init() + + # register a percolation query looking for documents about python + await PercolatorDoc( + _id="python", + tags=["programming", "development", "python"], + query=Q("match", content="python"), + ).save(refresh=True) + + +async def main(): + # initiate the default connection to elasticsearch + async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + await setup() + + # close the connection + await async_connections.get_connection().close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/async/search_as_you_type.py b/examples/async/search_as_you_type.py new file mode 100644 index 000000000..3b76622ed --- /dev/null +++ b/examples/async/search_as_you_type.py @@ -0,0 +1,102 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example ``Document`` with search_as_you_type field datatype and how to search it. + +When creating a field with search_as_you_type datatype Elasticsearch creates additional +subfields to enable efficient as-you-type completion, matching terms at any position +within the input. + +The custom analyzer with ASCII folding allows search to work in different languages. 
+""" + +import asyncio +import os + +from elasticsearch_dsl import ( + AsyncDocument, + SearchAsYouType, + analyzer, + async_connections, + token_filter, +) +from elasticsearch_dsl.query import MultiMatch + +# custom analyzer for names +ascii_fold = analyzer( + "ascii_fold", + # we don't want to split O'Brian or Toulouse-Lautrec + tokenizer="whitespace", + filter=["lowercase", token_filter("ascii_fold", "asciifolding")], +) + + +class Person(AsyncDocument): + name = SearchAsYouType(max_shingle_size=3) + + class Index: + name = "test-search-as-you-type" + settings = {"number_of_shards": 1, "number_of_replicas": 0} + + +async def main(): + # initiate the default connection to elasticsearch + async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + # create the empty index + await Person.init() + + import pprint + + pprint.pprint(Person().to_dict(), indent=2) + + # index some sample data + names = [ + "Andy Warhol", + "Alphonse Mucha", + "Henri de Toulouse-Lautrec", + "Jára Cimrman", + ] + for id, name in enumerate(names): + await Person(_id=id, name=name).save() + + # refresh index manually to make changes live + await Person._index.refresh() + + # run some suggestions + for text in ("já", "Cimr", "toulouse", "Henri Tou", "a"): + s = Person.search() + + s.query = MultiMatch( + query=text, + type="bool_prefix", + fields=["name", "name._2gram", "name._3gram"], + ) + + response = await s.execute() + + # print out all the options we got + for h in response: + print("%15s: %25s" % (text, h.name)) + + # close the connection + await async_connections.get_connection().close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/completion.py b/examples/completion.py index 8d910f1a2..4e68ce0ca 100644 --- a/examples/completion.py +++ b/examples/completion.py @@ -26,6 +26,7 @@ that does ascii folding. 
""" +import os from itertools import permutations from elasticsearch_dsl import ( @@ -71,9 +72,9 @@ class Index: settings = {"number_of_shards": 1, "number_of_replicas": 0} -if __name__ == "__main__": +def main(): # initiate the default connection to elasticsearch - connections.create_connection() + connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) # create the empty index Person.init() @@ -96,3 +97,10 @@ class Index: # print out all the options we got for option in response.suggest.auto_complete[0].options: print("%10s: %25s (%d)" % (text, option._source.name, option._score)) + + # close the connection + connections.get_connection().close() + + +if __name__ == "__main__": + main() diff --git a/examples/composite_agg.py b/examples/composite_agg.py index a9fab676c..753a9cbaa 100644 --- a/examples/composite_agg.py +++ b/examples/composite_agg.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import os from elasticsearch_dsl import A, Search, connections @@ -35,7 +36,8 @@ def run_search(**kwargs): response = run_search() while response.aggregations.comp.buckets: - yield from response.aggregations.comp.buckets + for b in response.aggregations.comp.buckets: + yield b if "after_key" in response.aggregations.comp: after = response.aggregations.comp.after_key else: @@ -43,9 +45,9 @@ def run_search(**kwargs): response = run_search(after=after) -if __name__ == "__main__": +def main(): # initiate the default connection to elasticsearch - connections.create_connection() + connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) for b in scan_aggs( Search(index="git"), @@ -56,3 +58,10 @@ def run_search(**kwargs): "File %s has been modified %d times, first seen at %s." 
% (b.key.files, b.doc_count, b.first_seen.value_as_string) ) + + # close the connection + connections.get_connection().close() + + +if __name__ == "__main__": + main() diff --git a/examples/parent_child.py b/examples/parent_child.py index df832d650..3d0e369cc 100644 --- a/examples/parent_child.py +++ b/examples/parent_child.py @@ -39,6 +39,7 @@ particular parent """ +import os from datetime import datetime from elasticsearch_dsl import ( @@ -164,7 +165,7 @@ def get_answers(self): """ if "inner_hits" in self.meta and "answer" in self.meta.inner_hits: return self.meta.inner_hits.answer.hits - return list(self.search_answers()) + return [a for a in self.search_answers()] def save(self, **kwargs): self.question_answer = "question" @@ -186,8 +187,7 @@ def _matches(cls, hit): def search(cls, **kwargs): return cls._index.search(**kwargs).exclude("term", question_answer="question") - @property - def question(self): + def get_question(self): # cache question in self.meta # any attributes set on self would be interpretted as fields if "question" not in self.meta: @@ -208,9 +208,9 @@ def setup(): index_template.save() -if __name__ == "__main__": +def main(): # initiate the default connection to elasticsearch - connections.create_connection() + connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) # create index setup() @@ -243,3 +243,12 @@ def setup(): ) question.save() answer = question.add_answer(honza, "Just use `elasticsearch-py`!") + + # close the connection + connections.get_connection().close() + + return answer + + +if __name__ == "__main__": + main() diff --git a/examples/percolate.py b/examples/percolate.py index fbe6a923a..df4709541 100644 --- a/examples/percolate.py +++ b/examples/percolate.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
+import os + from elasticsearch_dsl import ( Document, Keyword, @@ -89,8 +91,15 @@ def setup(): ).save(refresh=True) -if __name__ == "__main__": +def main(): # initiate the default connection to elasticsearch - connections.create_connection() + connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) setup() + + # close the connection + connections.get_connection().close() + + +if __name__ == "__main__": + main() diff --git a/examples/search_as_you_type.py b/examples/search_as_you_type.py index f82e12027..31ff3f028 100644 --- a/examples/search_as_you_type.py +++ b/examples/search_as_you_type.py @@ -25,6 +25,8 @@ To custom analyzer with ascii folding allow search to work in different languages. """ +import os + from elasticsearch_dsl import ( Document, SearchAsYouType, @@ -51,9 +53,9 @@ class Index: settings = {"number_of_shards": 1, "number_of_replicas": 0} -if __name__ == "__main__": +def main(): # initiate the default connection to elasticsearch - connections.create_connection() + connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) # create the empty index Person.init() @@ -90,3 +92,10 @@ class Index: # print out all the options we got for h in response: print("%15s: %25s" % (text, h.name)) + + # close the connection + connections.get_connection().close() + + +if __name__ == "__main__": + main() diff --git a/noxfile.py b/noxfile.py index 3031effd0..d730593d8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -54,7 +54,8 @@ def test(session): @nox.session(python="3.12") def format(session): - session.install("black~=24.0", "isort") + session.install("black~=24.0", "isort", "unasync", "setuptools") + session.run("python", "utils/run-unasync.py") session.run("black", "--target-version=py38", *SOURCE_FILES) session.run("isort", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "fix", *SOURCE_FILES) @@ -64,9 +65,10 @@ def format(session): @nox.session(python="3.12") def lint(session): - session.install("flake8", "black~=24.0", 
"isort") + session.install("flake8", "black~=24.0", "isort", "unasync", "setuptools") session.run("black", "--check", "--target-version=py38", *SOURCE_FILES) session.run("isort", "--check", *SOURCE_FILES) + session.run("python", "utils/run-unasync.py", "--check") session.run("flake8", "--ignore=E501,E741,W503", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) diff --git a/setup.cfg b/setup.cfg index d902949f3..2a34c2fab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,3 +11,4 @@ filterwarnings = error ignore:Legacy index templates are deprecated in favor of composable templates.:elasticsearch.exceptions.ElasticsearchWarning ignore:datetime.datetime.utcfromtimestamp\(\) is deprecated and scheduled for removal in a future version..*:DeprecationWarning +asyncio_mode = auto diff --git a/setup.py b/setup.py index 3f3aa7864..adb490a4a 100644 --- a/setup.py +++ b/setup.py @@ -32,10 +32,17 @@ "elasticsearch>=8.0.0,<9.0.0", ] +async_requires = [ + "elasticsearch[async]>=8.0.0,<9.0.0", +] + develop_requires = [ + "elasticsearch[async]", + "unasync", "pytest", "pytest-cov", "pytest-mock", + "pytest-asyncio", "pytz", "coverage", # Override Read the Docs default (sphinx<2 and sphinx-rtd-theme<0.5) @@ -72,5 +79,5 @@ "Programming Language :: Python :: Implementation :: PyPy", ], install_requires=install_requires, - extras_require={"develop": develop_requires}, + extras_require={"async": async_requires, "develop": develop_requires}, ) diff --git a/tests/_async/__init__.py b/tests/_async/__init__.py new file mode 100644 index 000000000..2a87d183f --- /dev/null +++ b/tests/_async/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. 
licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/_async/test_document.py b/tests/_async/test_document.py new file mode 100644 index 000000000..c26539d8e --- /dev/null +++ b/tests/_async/test_document.py @@ -0,0 +1,638 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import codecs +import ipaddress +import pickle +from datetime import datetime +from hashlib import md5 + +from pytest import raises + +from elasticsearch_dsl import ( + AsyncDocument, + Index, + InnerDoc, + Mapping, + MetaField, + Range, + analyzer, + field, + utils, +) +from elasticsearch_dsl.exceptions import IllegalOperation, ValidationException + + +class MyInner(InnerDoc): + old_field = field.Text() + + +class MyDoc(AsyncDocument): + title = field.Keyword() + name = field.Text() + created_at = field.Date() + inner = field.Object(MyInner) + + +class MySubDoc(MyDoc): + name = field.Keyword() + + class Index: + name = "default-index" + + +class MyDoc2(AsyncDocument): + extra = field.Long() + + +class MyMultiSubDoc(MyDoc2, MySubDoc): + pass + + +class Comment(InnerDoc): + title = field.Text() + tags = field.Keyword(multi=True) + + +class DocWithNested(AsyncDocument): + comments = field.Nested(Comment) + + class Index: + name = "test-doc-with-nested" + + +class SimpleCommit(AsyncDocument): + files = field.Text(multi=True) + + class Index: + name = "test-git" + + +class Secret(str): + pass + + +class SecretField(field.CustomField): + builtin_type = "text" + + def _serialize(self, data): + return codecs.encode(data, "rot_13") + + def _deserialize(self, data): + if isinstance(data, Secret): + return data + return Secret(codecs.decode(data, "rot_13")) + + +class SecretDoc(AsyncDocument): + title = SecretField(index="no") + + class Index: + name = "test-secret-doc" + + +class NestedSecret(AsyncDocument): + secrets = field.Nested(SecretDoc) + + class Index: + name = "test-nested-secret" + + +class OptionalObjectWithRequiredField(AsyncDocument): + comments = field.Nested(properties={"title": field.Keyword(required=True)}) + + class Index: + name = "test-required" + + +class Host(AsyncDocument): + ip = field.Ip() + + class Index: + name = "test-host" + + +def test_range_serializes_properly(): + class D(AsyncDocument): + lr = field.LongRange() + + d = D(lr=Range(lt=42)) 
+ assert 40 in d.lr + assert 47 not in d.lr + assert {"lr": {"lt": 42}} == d.to_dict() + + d = D(lr={"lt": 42}) + assert {"lr": {"lt": 42}} == d.to_dict() + + +def test_range_deserializes_properly(): + class D(InnerDoc): + lr = field.LongRange() + + d = D.from_es({"lr": {"lt": 42}}, True) + assert isinstance(d.lr, Range) + assert 40 in d.lr + assert 47 not in d.lr + + +def test_resolve_nested(): + nested, field = NestedSecret._index.resolve_nested("secrets.title") + assert nested == ["secrets"] + assert field is NestedSecret._doc_type.mapping["secrets"]["title"] + + +def test_conflicting_mapping_raises_error_in_index_to_dict(): + class A(AsyncDocument): + name = field.Text() + + class B(AsyncDocument): + name = field.Keyword() + + i = Index("i") + i.document(A) + i.document(B) + + with raises(ValueError): + i.to_dict() + + +def test_ip_address_serializes_properly(): + host = Host(ip=ipaddress.IPv4Address("10.0.0.1")) + + assert {"ip": "10.0.0.1"} == host.to_dict() + + +def test_matches_uses_index(): + assert SimpleCommit._matches({"_index": "test-git"}) + assert not SimpleCommit._matches({"_index": "not-test-git"}) + + +def test_matches_with_no_name_always_matches(): + class D(AsyncDocument): + pass + + assert D._matches({}) + assert D._matches({"_index": "whatever"}) + + +def test_matches_accepts_wildcards(): + class MyDoc(AsyncDocument): + class Index: + name = "my-*" + + assert MyDoc._matches({"_index": "my-index"}) + assert not MyDoc._matches({"_index": "not-my-index"}) + + +def test_assigning_attrlist_to_field(): + sc = SimpleCommit() + l = ["README", "README.rst"] + sc.files = utils.AttrList(l) + + assert sc.to_dict()["files"] is l + + +def test_optional_inner_objects_are_not_validated_if_missing(): + d = OptionalObjectWithRequiredField() + + assert d.full_clean() is None + + +def test_custom_field(): + s = SecretDoc(title=Secret("Hello")) + + assert {"title": "Uryyb"} == s.to_dict() + assert s.title == "Hello" + + s = SecretDoc.from_es({"_source": {"title": 
"Uryyb"}}) + assert s.title == "Hello" + assert isinstance(s.title, Secret) + + +def test_custom_field_mapping(): + assert { + "properties": {"title": {"index": "no", "type": "text"}} + } == SecretDoc._doc_type.mapping.to_dict() + + +def test_custom_field_in_nested(): + s = NestedSecret() + s.secrets.append(SecretDoc(title=Secret("Hello"))) + + assert {"secrets": [{"title": "Uryyb"}]} == s.to_dict() + assert s.secrets[0].title == "Hello" + + +def test_multi_works_after_doc_has_been_saved(): + c = SimpleCommit() + c.full_clean() + c.files.append("setup.py") + + assert c.to_dict() == {"files": ["setup.py"]} + + +def test_multi_works_in_nested_after_doc_has_been_serialized(): + # Issue #359 + c = DocWithNested(comments=[Comment(title="First!")]) + + assert [] == c.comments[0].tags + assert {"comments": [{"title": "First!"}]} == c.to_dict() + assert [] == c.comments[0].tags + + +def test_null_value_for_object(): + d = MyDoc(inner=None) + + assert d.inner is None + + +def test_inherited_doc_types_can_override_index(): + class MyDocDifferentIndex(MySubDoc): + class Index: + name = "not-default-index" + settings = {"number_of_replicas": 0} + aliases = {"a": {}} + analyzers = [analyzer("my_analizer", tokenizer="keyword")] + + assert MyDocDifferentIndex._index._name == "not-default-index" + assert MyDocDifferentIndex()._get_index() == "not-default-index" + assert MyDocDifferentIndex._index.to_dict() == { + "aliases": {"a": {}}, + "mappings": { + "properties": { + "created_at": {"type": "date"}, + "inner": { + "type": "object", + "properties": {"old_field": {"type": "text"}}, + }, + "name": {"type": "keyword"}, + "title": {"type": "keyword"}, + } + }, + "settings": { + "analysis": { + "analyzer": {"my_analizer": {"tokenizer": "keyword", "type": "custom"}} + }, + "number_of_replicas": 0, + }, + } + + +def test_to_dict_with_meta(): + d = MySubDoc(title="hello") + d.meta.routing = "some-parent" + + assert { + "_index": "default-index", + "_routing": "some-parent", + "_source": 
{"title": "hello"}, + } == d.to_dict(True) + + +def test_to_dict_with_meta_includes_custom_index(): + d = MySubDoc(title="hello") + d.meta.index = "other-index" + + assert {"_index": "other-index", "_source": {"title": "hello"}} == d.to_dict(True) + + +def test_to_dict_without_skip_empty_will_include_empty_fields(): + d = MySubDoc(tags=[], title=None, inner={}) + + assert {} == d.to_dict() + assert {"tags": [], "title": None, "inner": {}} == d.to_dict(skip_empty=False) + + +def test_attribute_can_be_removed(): + d = MyDoc(title="hello") + + del d.title + assert "title" not in d._d_ + + +def test_doc_type_can_be_correctly_pickled(): + d = DocWithNested( + title="Hello World!", comments=[Comment(title="hellp")], meta={"id": 42} + ) + s = pickle.dumps(d) + + d2 = pickle.loads(s) + + assert d2 == d + assert 42 == d2.meta.id + assert "Hello World!" == d2.title + assert [{"title": "hellp"}] == d2.comments + assert isinstance(d2.comments[0], Comment) + + +def test_meta_is_accessible_even_on_empty_doc(): + d = MyDoc() + d.meta + + d = MyDoc(title="aaa") + d.meta + + +def test_meta_field_mapping(): + class User(AsyncDocument): + username = field.Text() + + class Meta: + all = MetaField(enabled=False) + _index = MetaField(enabled=True) + dynamic = MetaField("strict") + dynamic_templates = MetaField([42]) + + assert { + "properties": {"username": {"type": "text"}}, + "_all": {"enabled": False}, + "_index": {"enabled": True}, + "dynamic": "strict", + "dynamic_templates": [42], + } == User._doc_type.mapping.to_dict() + + +def test_multi_value_fields(): + class Blog(AsyncDocument): + tags = field.Keyword(multi=True) + + b = Blog() + assert [] == b.tags + b.tags.append("search") + b.tags.append("python") + assert ["search", "python"] == b.tags + + +def test_docs_with_properties(): + class User(AsyncDocument): + pwd_hash = field.Text() + + def check_password(self, pwd): + return md5(pwd).hexdigest() == self.pwd_hash + + @property + def password(self): + raise 
AttributeError("readonly") + + @password.setter + def password(self, pwd): + self.pwd_hash = md5(pwd).hexdigest() + + u = User(pwd_hash=md5(b"secret").hexdigest()) + assert u.check_password(b"secret") + assert not u.check_password(b"not-secret") + + u.password = b"not-secret" + assert "password" not in u._d_ + assert not u.check_password(b"secret") + assert u.check_password(b"not-secret") + + with raises(AttributeError): + u.password + + +def test_nested_can_be_assigned_to(): + d1 = DocWithNested(comments=[Comment(title="First!")]) + d2 = DocWithNested() + + d2.comments = d1.comments + assert isinstance(d1.comments[0], Comment) + assert d2.comments == [{"title": "First!"}] + assert {"comments": [{"title": "First!"}]} == d2.to_dict() + assert isinstance(d2.comments[0], Comment) + + +def test_nested_can_be_none(): + d = DocWithNested(comments=None, title="Hello World!") + + assert {"title": "Hello World!"} == d.to_dict() + + +def test_nested_defaults_to_list_and_can_be_updated(): + md = DocWithNested() + + assert [] == md.comments + + md.comments.append({"title": "hello World!"}) + assert {"comments": [{"title": "hello World!"}]} == md.to_dict() + + +def test_to_dict_is_recursive_and_can_cope_with_multi_values(): + md = MyDoc(name=["a", "b", "c"]) + md.inner = [MyInner(old_field="of1"), MyInner(old_field="of2")] + + assert isinstance(md.inner[0], MyInner) + + assert { + "name": ["a", "b", "c"], + "inner": [{"old_field": "of1"}, {"old_field": "of2"}], + } == md.to_dict() + + +def test_to_dict_ignores_empty_collections(): + md = MySubDoc(name="", address={}, count=0, valid=False, tags=[]) + + assert {"name": "", "count": 0, "valid": False} == md.to_dict() + + +def test_declarative_mapping_definition(): + assert issubclass(MyDoc, AsyncDocument) + assert hasattr(MyDoc, "_doc_type") + assert { + "properties": { + "created_at": {"type": "date"}, + "name": {"type": "text"}, + "title": {"type": "keyword"}, + "inner": {"type": "object", "properties": {"old_field": {"type": 
"text"}}}, + } + } == MyDoc._doc_type.mapping.to_dict() + + +def test_you_can_supply_own_mapping_instance(): + class MyD(AsyncDocument): + title = field.Text() + + class Meta: + mapping = Mapping() + mapping.meta("_all", enabled=False) + + assert { + "_all": {"enabled": False}, + "properties": {"title": {"type": "text"}}, + } == MyD._doc_type.mapping.to_dict() + + +def test_document_can_be_created_dynamically(): + n = datetime.now() + md = MyDoc(title="hello") + md.name = "My Fancy Document!" + md.created_at = n + + inner = md.inner + # consistent returns + assert inner is md.inner + inner.old_field = "Already defined." + + md.inner.new_field = ["undefined", "field"] + + assert { + "title": "hello", + "name": "My Fancy Document!", + "created_at": n, + "inner": {"old_field": "Already defined.", "new_field": ["undefined", "field"]}, + } == md.to_dict() + + +def test_invalid_date_will_raise_exception(): + md = MyDoc() + md.created_at = "not-a-date" + with raises(ValidationException): + md.full_clean() + + +def test_document_inheritance(): + assert issubclass(MySubDoc, MyDoc) + assert issubclass(MySubDoc, AsyncDocument) + assert hasattr(MySubDoc, "_doc_type") + assert { + "properties": { + "created_at": {"type": "date"}, + "name": {"type": "keyword"}, + "title": {"type": "keyword"}, + "inner": {"type": "object", "properties": {"old_field": {"type": "text"}}}, + } + } == MySubDoc._doc_type.mapping.to_dict() + + +def test_child_class_can_override_parent(): + class A(AsyncDocument): + o = field.Object(dynamic=False, properties={"a": field.Text()}) + + class B(A): + o = field.Object(dynamic="strict", properties={"b": field.Text()}) + + assert { + "properties": { + "o": { + "dynamic": "strict", + "properties": {"a": {"type": "text"}, "b": {"type": "text"}}, + "type": "object", + } + } + } == B._doc_type.mapping.to_dict() + + +def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict(): + md = MySubDoc(meta={"id": 42}, name="My First doc!") + + md.meta.index = 
"my-index" + assert md.meta.index == "my-index" + assert md.meta.id == 42 + assert {"name": "My First doc!"} == md.to_dict() + assert {"id": 42, "index": "my-index"} == md.meta.to_dict() + + +def test_index_inheritance(): + assert issubclass(MyMultiSubDoc, MySubDoc) + assert issubclass(MyMultiSubDoc, MyDoc2) + assert issubclass(MyMultiSubDoc, AsyncDocument) + assert hasattr(MyMultiSubDoc, "_doc_type") + assert hasattr(MyMultiSubDoc, "_index") + assert { + "properties": { + "created_at": {"type": "date"}, + "name": {"type": "keyword"}, + "title": {"type": "keyword"}, + "inner": {"type": "object", "properties": {"old_field": {"type": "text"}}}, + "extra": {"type": "long"}, + } + } == MyMultiSubDoc._doc_type.mapping.to_dict() + + +def test_meta_fields_can_be_set_directly_in_init(): + p = object() + md = MyDoc(_id=p, title="Hello World!") + + assert md.meta.id is p + + +async def test_save_no_index(async_mock_client): + md = MyDoc() + with raises(ValidationException): + await md.save(using="mock") + + +async def test_delete_no_index(async_mock_client): + md = MyDoc() + with raises(ValidationException): + await md.delete(using="mock") + + +async def test_update_no_fields(): + md = MyDoc() + with raises(IllegalOperation): + await md.update() + + +def test_search_with_custom_alias_and_index(): + search_object = MyDoc.search( + using="staging", index=["custom_index1", "custom_index2"] + ) + + assert search_object._using == "staging" + assert search_object._index == ["custom_index1", "custom_index2"] + + +def test_from_es_respects_underscored_non_meta_fields(): + doc = { + "_index": "test-index", + "_id": "elasticsearch", + "_score": 12.0, + "fields": {"hello": "world", "_routing": "es", "_tags": ["search"]}, + "_source": { + "city": "Amsterdam", + "name": "Elasticsearch", + "_tagline": "You know, for search", + }, + } + + class Company(AsyncDocument): + class Index: + name = "test-company" + + c = Company.from_es(doc) + + assert c.meta.fields._tags == ["search"] + assert 
c.meta.fields._routing == "es" + assert c._tagline == "You know, for search" + + +def test_nested_and_object_inner_doc(): + class MySubDocWithNested(MyDoc): + nested_inner = field.Nested(MyInner) + + props = MySubDocWithNested._doc_type.mapping.to_dict()["properties"] + assert props == { + "created_at": {"type": "date"}, + "inner": {"properties": {"old_field": {"type": "text"}}, "type": "object"}, + "name": {"type": "text"}, + "nested_inner": { + "properties": {"old_field": {"type": "text"}}, + "type": "nested", + }, + "title": {"type": "keyword"}, + } diff --git a/tests/_async/test_faceted_search.py b/tests/_async/test_faceted_search.py new file mode 100644 index 000000000..701cb99a7 --- /dev/null +++ b/tests/_async/test_faceted_search.py @@ -0,0 +1,194 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from datetime import datetime + +import pytest + +from elasticsearch_dsl.faceted_search import ( + AsyncFacetedSearch, + DateHistogramFacet, + TermsFacet, +) + + +class BlogSearch(AsyncFacetedSearch): + doc_types = ["user", "post"] + fields = ( + "title^5", + "body", + ) + + facets = { + "category": TermsFacet(field="category.raw"), + "tags": TermsFacet(field="tags"), + } + + +def test_query_is_created_properly(): + bs = BlogSearch("python search") + s = bs.build_search() + + assert s._doc_type == ["user", "post"] + assert { + "aggs": { + "_filter_tags": { + "filter": {"match_all": {}}, + "aggs": {"tags": {"terms": {"field": "tags"}}}, + }, + "_filter_category": { + "filter": {"match_all": {}}, + "aggs": {"category": {"terms": {"field": "category.raw"}}}, + }, + }, + "query": { + "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + }, + "highlight": {"fields": {"body": {}, "title": {}}}, + } == s.to_dict() + + +def test_query_is_created_properly_with_sort_tuple(): + bs = BlogSearch("python search", sort=("category", "-title")) + s = bs.build_search() + + assert s._doc_type == ["user", "post"] + assert { + "aggs": { + "_filter_tags": { + "filter": {"match_all": {}}, + "aggs": {"tags": {"terms": {"field": "tags"}}}, + }, + "_filter_category": { + "filter": {"match_all": {}}, + "aggs": {"category": {"terms": {"field": "category.raw"}}}, + }, + }, + "query": { + "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + }, + "highlight": {"fields": {"body": {}, "title": {}}}, + "sort": ["category", {"title": {"order": "desc"}}], + } == s.to_dict() + + +def test_filter_is_applied_to_search_but_not_relevant_facet(): + bs = BlogSearch("python search", filters={"category": "elastic"}) + s = bs.build_search() + + assert { + "aggs": { + "_filter_tags": { + "filter": {"terms": {"category.raw": ["elastic"]}}, + "aggs": {"tags": {"terms": {"field": "tags"}}}, + }, + "_filter_category": { + "filter": {"match_all": {}}, + "aggs": 
{"category": {"terms": {"field": "category.raw"}}}, + }, + }, + "post_filter": {"terms": {"category.raw": ["elastic"]}}, + "query": { + "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + }, + "highlight": {"fields": {"body": {}, "title": {}}}, + } == s.to_dict() + + +def test_filters_are_applied_to_search_ant_relevant_facets(): + bs = BlogSearch( + "python search", filters={"category": "elastic", "tags": ["python", "django"]} + ) + s = bs.build_search() + + d = s.to_dict() + + # we need to test post_filter without relying on order + f = d["post_filter"]["bool"].pop("must") + assert len(f) == 2 + assert {"terms": {"category.raw": ["elastic"]}} in f + assert {"terms": {"tags": ["python", "django"]}} in f + + assert { + "aggs": { + "_filter_tags": { + "filter": {"terms": {"category.raw": ["elastic"]}}, + "aggs": {"tags": {"terms": {"field": "tags"}}}, + }, + "_filter_category": { + "filter": {"terms": {"tags": ["python", "django"]}}, + "aggs": {"category": {"terms": {"field": "category.raw"}}}, + }, + }, + "query": { + "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + }, + "post_filter": {"bool": {}}, + "highlight": {"fields": {"body": {}, "title": {}}}, + } == d + + +def test_date_histogram_facet_with_1970_01_01_date(): + dhf = DateHistogramFacet() + assert dhf.get_value({"key": None}) == datetime(1970, 1, 1, 0, 0) + assert dhf.get_value({"key": 0}) == datetime(1970, 1, 1, 0, 0) + + +@pytest.mark.parametrize( + ["interval_type", "interval"], + [ + ("interval", "year"), + ("calendar_interval", "year"), + ("interval", "month"), + ("calendar_interval", "month"), + ("interval", "week"), + ("calendar_interval", "week"), + ("interval", "day"), + ("calendar_interval", "day"), + ("fixed_interval", "day"), + ("interval", "hour"), + ("fixed_interval", "hour"), + ("interval", "1Y"), + ("calendar_interval", "1Y"), + ("interval", "1M"), + ("calendar_interval", "1M"), + ("interval", "1w"), + ("calendar_interval", "1w"), + 
("interval", "1d"), + ("calendar_interval", "1d"), + ("fixed_interval", "1d"), + ("interval", "1h"), + ("fixed_interval", "1h"), + ], +) +def test_date_histogram_interval_types(interval_type, interval): + dhf = DateHistogramFacet(field="@timestamp", **{interval_type: interval}) + assert dhf.get_aggregation().to_dict() == { + "date_histogram": { + "field": "@timestamp", + interval_type: interval, + "min_doc_count": 0, + } + } + dhf.get_value_filter(datetime.now()) + + +def test_date_histogram_no_interval_keyerror(): + dhf = DateHistogramFacet(field="@timestamp") + with pytest.raises(KeyError) as e: + dhf.get_value_filter(datetime.now()) + assert str(e.value) == "'interval'" diff --git a/tests/_async/test_index.py b/tests/_async/test_index.py new file mode 100644 index 000000000..aab603ed7 --- /dev/null +++ b/tests/_async/test_index.py @@ -0,0 +1,194 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import string +from random import choice + +from pytest import raises + +from elasticsearch_dsl import ( + AsyncDocument, + AsyncIndex, + AsyncIndexTemplate, + Date, + Text, + analyzer, +) + + +class Post(AsyncDocument): + title = Text() + published_from = Date() + + +def test_multiple_doc_types_will_combine_mappings(): + class User(AsyncDocument): + username = Text() + + i = AsyncIndex("i") + i.document(Post) + i.document(User) + assert { + "mappings": { + "properties": { + "title": {"type": "text"}, + "username": {"type": "text"}, + "published_from": {"type": "date"}, + } + } + } == i.to_dict() + + +def test_search_is_limited_to_index_name(): + i = AsyncIndex("my-index") + s = i.search() + + assert s._index == ["my-index"] + + +def test_cloned_index_has_copied_settings_and_using(): + client = object() + i = AsyncIndex("my-index", using=client) + i.settings(number_of_shards=1) + + i2 = i.clone("my-other-index") + + assert "my-other-index" == i2._name + assert client is i2._using + assert i._settings == i2._settings + assert i._settings is not i2._settings + + +def test_cloned_index_has_analysis_attribute(): + """ + Regression test for Issue #582 in which `AsyncIndex.clone()` was not copying + over the `_analysis` attribute. 
+ """ + client = object() + i = AsyncIndex("my-index", using=client) + + random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100)) + random_analyzer = analyzer( + random_analyzer_name, tokenizer="standard", filter="standard" + ) + + i.analyzer(random_analyzer) + + i2 = i.clone("my-clone-index") + + assert i.to_dict()["settings"]["analysis"] == i2.to_dict()["settings"]["analysis"] + + +def test_settings_are_saved(): + i = AsyncIndex("i") + i.settings(number_of_replicas=0) + i.settings(number_of_shards=1) + + assert {"settings": {"number_of_shards": 1, "number_of_replicas": 0}} == i.to_dict() + + +def test_registered_doc_type_included_in_to_dict(): + i = AsyncIndex("i", using="alias") + i.document(Post) + + assert { + "mappings": { + "properties": { + "title": {"type": "text"}, + "published_from": {"type": "date"}, + } + } + } == i.to_dict() + + +def test_registered_doc_type_included_in_search(): + i = AsyncIndex("i", using="alias") + i.document(Post) + + s = i.search() + + assert s._doc_type == [Post] + + +def test_aliases_add_to_object(): + random_alias = "".join(choice(string.ascii_letters) for _ in range(100)) + alias_dict = {random_alias: {}} + + index = AsyncIndex("i", using="alias") + index.aliases(**alias_dict) + + assert index._aliases == alias_dict + + +def test_aliases_returned_from_to_dict(): + random_alias = "".join(choice(string.ascii_letters) for _ in range(100)) + alias_dict = {random_alias: {}} + + index = AsyncIndex("i", using="alias") + index.aliases(**alias_dict) + + assert index._aliases == index.to_dict()["aliases"] == alias_dict + + +def test_analyzers_added_to_object(): + random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100)) + random_analyzer = analyzer( + random_analyzer_name, tokenizer="standard", filter="standard" + ) + + index = AsyncIndex("i", using="alias") + index.analyzer(random_analyzer) + + assert index._analysis["analyzer"][random_analyzer_name] == { + "filter": ["standard"], + 
"type": "custom", + "tokenizer": "standard", + } + + +def test_analyzers_returned_from_to_dict(): + random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100)) + random_analyzer = analyzer( + random_analyzer_name, tokenizer="standard", filter="standard" + ) + index = AsyncIndex("i", using="alias") + index.analyzer(random_analyzer) + + assert index.to_dict()["settings"]["analysis"]["analyzer"][ + random_analyzer_name + ] == {"filter": ["standard"], "type": "custom", "tokenizer": "standard"} + + +def test_conflicting_analyzer_raises_error(): + i = AsyncIndex("i") + i.analyzer("my_analyzer", tokenizer="whitespace", filter=["lowercase", "stop"]) + + with raises(ValueError): + i.analyzer("my_analyzer", tokenizer="keyword", filter=["lowercase", "stop"]) + + +def test_index_template_can_have_order(): + i = AsyncIndex("i-*") + it = i.as_template("i", order=2) + + assert {"index_patterns": ["i-*"], "order": 2} == it.to_dict() + + +async def test_index_template_save_result(async_mock_client): + it = AsyncIndexTemplate("test-template", "test-*") + + assert await it.save(using="mock") == await async_mock_client.indices.put_template() diff --git a/tests/test_mapping.py b/tests/_async/test_mapping.py similarity index 95% rename from tests/test_mapping.py rename to tests/_async/test_mapping.py index aa4939fbc..6d47901c9 100644 --- a/tests/test_mapping.py +++ b/tests/_async/test_mapping.py @@ -17,11 +17,11 @@ import json -from elasticsearch_dsl import Keyword, Nested, Text, analysis, mapping +from elasticsearch_dsl import AsyncMapping, Keyword, Nested, Text, analysis def test_mapping_can_has_fields(): - m = mapping.Mapping() + m = AsyncMapping() m.field("name", "text").field("tags", "keyword") assert { @@ -30,14 +30,14 @@ def test_mapping_can_has_fields(): def test_mapping_update_is_recursive(): - m1 = mapping.Mapping() + m1 = AsyncMapping() m1.field("title", "text") m1.field("author", "object") m1.field("author", "object", properties={"name": {"type": 
"text"}}) m1.meta("_all", enabled=False) m1.meta("dynamic", False) - m2 = mapping.Mapping() + m2 = AsyncMapping() m2.field("published_from", "date") m2.field("author", "object", properties={"email": {"type": "text"}}) m2.field("title", "text") @@ -63,7 +63,7 @@ def test_mapping_update_is_recursive(): def test_properties_can_iterate_over_all_the_fields(): - m = mapping.Mapping() + m = AsyncMapping() m.field("f1", "text", test_attr="f1", fields={"f2": Keyword(test_attr="f2")}) m.field("f3", Nested(test_attr="f3", properties={"f4": Text(test_attr="f4")})) @@ -100,7 +100,7 @@ def test_mapping_can_collect_all_analyzers_and_normalizers(): ) n3 = analysis.normalizer("unknown_custom") - m = mapping.Mapping() + m = AsyncMapping() m.field( "title", "text", @@ -159,7 +159,7 @@ def test_mapping_can_collect_multiple_analyzers(): tokenizer=analysis.tokenizer("trigram", "nGram", min_gram=3, max_gram=3), filter=[analysis.token_filter("my_filter2", "stop", stopwords=["c", "d"])], ) - m = mapping.Mapping() + m = AsyncMapping() m.field("title", "text", analyzer=a1, search_analyzer=a2) m.field( "text", @@ -193,7 +193,7 @@ def test_mapping_can_collect_multiple_analyzers(): def test_even_non_custom_analyzers_can_have_params(): a1 = analysis.analyzer("whitespace", type="pattern", pattern=r"\\s+") - m = mapping.Mapping() + m = AsyncMapping() m.field("title", "text", analyzer=a1) assert { @@ -202,14 +202,14 @@ def test_even_non_custom_analyzers_can_have_params(): def test_resolve_field_can_resolve_multifields(): - m = mapping.Mapping() + m = AsyncMapping() m.field("title", "text", fields={"keyword": Keyword()}) assert isinstance(m.resolve_field("title.keyword"), Keyword) def test_resolve_nested(): - m = mapping.Mapping() + m = AsyncMapping() m.field("n1", "nested", properties={"n2": Nested(properties={"k1": Keyword()})}) m.field("k2", "keyword") diff --git a/tests/_async/test_search.py b/tests/_async/test_search.py new file mode 100644 index 000000000..1f7ad1de8 --- /dev/null +++ 
b/tests/_async/test_search.py @@ -0,0 +1,681 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from copy import deepcopy + +from pytest import raises + +from elasticsearch_dsl import AsyncSearch, Document, Q, query +from elasticsearch_dsl.exceptions import IllegalOperation + + +def test_expand__to_dot_is_respected(): + s = AsyncSearch().query("match", a__b=42, _expand__to_dot=False) + + assert {"query": {"match": {"a__b": 42}}} == s.to_dict() + + +async def test_execute_uses_cache(): + s = AsyncSearch() + r = object() + s._response = r + + assert r is await s.execute() + + +async def test_cache_can_be_ignored(async_mock_client): + s = AsyncSearch(using="mock") + r = object() + s._response = r + await s.execute(ignore_cache=True) + + async_mock_client.search.assert_awaited_once_with(index=None, body={}) + + +async def test_iter_iterates_over_hits(): + s = AsyncSearch() + s._response = [1, 2, 3] + + assert [1, 2, 3] == [hit async for hit in s] + + +def test_cache_isnt_cloned(): + s = AsyncSearch() + s._response = object() + + assert not hasattr(s._clone(), "_response") + + +def test_search_starts_with_no_query(): + s = AsyncSearch() + + assert s.query._proxied is None + + +def test_search_query_combines_query(): + s = AsyncSearch() + + 
s2 = s.query("match", f=42) + assert s2.query._proxied == query.Match(f=42) + assert s.query._proxied is None + + s3 = s2.query("match", f=43) + assert s2.query._proxied == query.Match(f=42) + assert s3.query._proxied == query.Bool(must=[query.Match(f=42), query.Match(f=43)]) + + +def test_query_can_be_assigned_to(): + s = AsyncSearch() + + q = Q("match", title="python") + s.query = q + + assert s.query._proxied is q + + +def test_query_can_be_wrapped(): + s = AsyncSearch().query("match", title="python") + + s.query = Q("function_score", query=s.query, field_value_factor={"field": "rating"}) + + assert { + "query": { + "function_score": { + "functions": [{"field_value_factor": {"field": "rating"}}], + "query": {"match": {"title": "python"}}, + } + } + } == s.to_dict() + + +def test_using(): + o = object() + o2 = object() + s = AsyncSearch(using=o) + assert s._using is o + s2 = s.using(o2) + assert s._using is o + assert s2._using is o2 + + +def test_methods_are_proxied_to_the_query(): + s = AsyncSearch().query("match_all") + + assert s.query.to_dict() == {"match_all": {}} + + +def test_query_always_returns_search(): + s = AsyncSearch() + + assert isinstance(s.query("match", f=42), AsyncSearch) + + +def test_source_copied_on_clone(): + s = AsyncSearch().source(False) + assert s._clone()._source == s._source + assert s._clone()._source is False + + s2 = AsyncSearch().source([]) + assert s2._clone()._source == s2._source + assert s2._source == [] + + s3 = AsyncSearch().source(["some", "fields"]) + assert s3._clone()._source == s3._source + assert s3._clone()._source == ["some", "fields"] + + +def test_copy_clones(): + from copy import copy + + s1 = AsyncSearch().source(["some", "fields"]) + s2 = copy(s1) + + assert s1 == s2 + assert s1 is not s2 + + +def test_aggs_allow_two_metric(): + s = AsyncSearch() + + s.aggs.metric("a", "max", field="a").metric("b", "max", field="b") + + assert s.to_dict() == { + "aggs": {"a": {"max": {"field": "a"}}, "b": {"max": {"field": 
"b"}}} + } + + +def test_aggs_get_copied_on_change(): + s = AsyncSearch().query("match_all") + s.aggs.bucket("per_tag", "terms", field="f").metric( + "max_score", "max", field="score" + ) + + s2 = s.query("match_all") + s2.aggs.bucket("per_month", "date_histogram", field="date", interval="month") + s3 = s2.query("match_all") + s3.aggs["per_month"].metric("max_score", "max", field="score") + s4 = s3._clone() + s4.aggs.metric("max_score", "max", field="score") + + d = { + "query": {"match_all": {}}, + "aggs": { + "per_tag": { + "terms": {"field": "f"}, + "aggs": {"max_score": {"max": {"field": "score"}}}, + } + }, + } + + assert d == s.to_dict() + d["aggs"]["per_month"] = {"date_histogram": {"field": "date", "interval": "month"}} + assert d == s2.to_dict() + d["aggs"]["per_month"]["aggs"] = {"max_score": {"max": {"field": "score"}}} + assert d == s3.to_dict() + d["aggs"]["max_score"] = {"max": {"field": "score"}} + assert d == s4.to_dict() + + +def test_search_index(): + s = AsyncSearch(index="i") + assert s._index == ["i"] + s = s.index("i2") + assert s._index == ["i", "i2"] + s = s.index("i3") + assert s._index == ["i", "i2", "i3"] + s = s.index() + assert s._index is None + s = AsyncSearch(index=("i", "i2")) + assert s._index == ["i", "i2"] + s = AsyncSearch(index=["i", "i2"]) + assert s._index == ["i", "i2"] + s = AsyncSearch() + s = s.index("i", "i2") + assert s._index == ["i", "i2"] + s2 = s.index("i3") + assert s._index == ["i", "i2"] + assert s2._index == ["i", "i2", "i3"] + s = AsyncSearch() + s = s.index(["i", "i2"], "i3") + assert s._index == ["i", "i2", "i3"] + s2 = s.index("i4") + assert s._index == ["i", "i2", "i3"] + assert s2._index == ["i", "i2", "i3", "i4"] + s2 = s.index(["i4"]) + assert s2._index == ["i", "i2", "i3", "i4"] + s2 = s.index(("i4", "i5")) + assert s2._index == ["i", "i2", "i3", "i4", "i5"] + + +def test_doc_type_document_class(): + class MyDocument(Document): + pass + + s = AsyncSearch(doc_type=MyDocument) + assert s._doc_type == 
[MyDocument] + assert s._doc_type_map == {} + + s = AsyncSearch().doc_type(MyDocument) + assert s._doc_type == [MyDocument] + assert s._doc_type_map == {} + + +def test_knn(): + s = AsyncSearch() + + with raises(TypeError): + s.knn() + with raises(TypeError): + s.knn("field") + with raises(TypeError): + s.knn("field", 5) + with raises(ValueError): + s.knn("field", 5, 100) + with raises(ValueError): + s.knn("field", 5, 100, query_vector=[1, 2, 3], query_vector_builder={}) + + s = s.knn("field", 5, 100, query_vector=[1, 2, 3]) + assert { + "knn": { + "field": "field", + "k": 5, + "num_candidates": 100, + "query_vector": [1, 2, 3], + } + } == s.to_dict() + + s = s.knn( + k=4, + num_candidates=40, + boost=0.8, + field="name", + query_vector_builder={ + "text_embedding": {"model_id": "foo", "model_text": "search text"} + }, + ) + assert { + "knn": [ + { + "field": "field", + "k": 5, + "num_candidates": 100, + "query_vector": [1, 2, 3], + }, + { + "field": "name", + "k": 4, + "num_candidates": 40, + "query_vector_builder": { + "text_embedding": {"model_id": "foo", "model_text": "search text"} + }, + "boost": 0.8, + }, + ] + } == s.to_dict() + + +def test_rank(): + s = AsyncSearch() + s.rank(rrf=False) + assert {} == s.to_dict() + + s = s.rank(rrf=True) + assert {"rank": {"rrf": {}}} == s.to_dict() + + s = s.rank(rrf={"window_size": 50, "rank_constant": 20}) + assert {"rank": {"rrf": {"window_size": 50, "rank_constant": 20}}} == s.to_dict() + + +def test_sort(): + s = AsyncSearch() + s = s.sort("fielda", "-fieldb") + + assert ["fielda", {"fieldb": {"order": "desc"}}] == s._sort + assert {"sort": ["fielda", {"fieldb": {"order": "desc"}}]} == s.to_dict() + + s = s.sort() + assert [] == s._sort + assert AsyncSearch().to_dict() == s.to_dict() + + +def test_sort_by_score(): + s = AsyncSearch() + s = s.sort("_score") + assert {"sort": ["_score"]} == s.to_dict() + + s = AsyncSearch() + with raises(IllegalOperation): + s.sort("-_score") + + +def test_collapse(): + s = 
AsyncSearch() + + inner_hits = {"name": "most_recent", "size": 5, "sort": [{"@timestamp": "desc"}]} + s = s.collapse("user.id", inner_hits=inner_hits, max_concurrent_group_searches=4) + + assert { + "field": "user.id", + "inner_hits": { + "name": "most_recent", + "size": 5, + "sort": [{"@timestamp": "desc"}], + }, + "max_concurrent_group_searches": 4, + } == s._collapse + assert { + "collapse": { + "field": "user.id", + "inner_hits": { + "name": "most_recent", + "size": 5, + "sort": [{"@timestamp": "desc"}], + }, + "max_concurrent_group_searches": 4, + } + } == s.to_dict() + + s = s.collapse() + assert {} == s._collapse + assert AsyncSearch().to_dict() == s.to_dict() + + +def test_slice(): + s = AsyncSearch() + assert {"from": 3, "size": 7} == s[3:10].to_dict() + assert {"from": 0, "size": 5} == s[:5].to_dict() + assert {"from": 3, "size": 10} == s[3:].to_dict() + assert {"from": 0, "size": 0} == s[0:0].to_dict() + assert {"from": 20, "size": 0} == s[20:0].to_dict() + + +def test_index(): + s = AsyncSearch() + assert {"from": 3, "size": 1} == s[3].to_dict() + + +def test_search_to_dict(): + s = AsyncSearch() + assert {} == s.to_dict() + + s = s.query("match", f=42) + assert {"query": {"match": {"f": 42}}} == s.to_dict() + + assert {"query": {"match": {"f": 42}}, "size": 10} == s.to_dict(size=10) + + s.aggs.bucket("per_tag", "terms", field="f").metric( + "max_score", "max", field="score" + ) + d = { + "aggs": { + "per_tag": { + "terms": {"field": "f"}, + "aggs": {"max_score": {"max": {"field": "score"}}}, + } + }, + "query": {"match": {"f": 42}}, + } + assert d == s.to_dict() + + s = AsyncSearch(extra={"size": 5}) + assert {"size": 5} == s.to_dict() + s = s.extra(from_=42) + assert {"size": 5, "from": 42} == s.to_dict() + + +def test_complex_example(): + s = AsyncSearch() + s = ( + s.query("match", title="python") + .query(~Q("match", title="ruby")) + .filter(Q("term", category="meetup") | Q("term", category="conference")) + .collapse("user_id") + 
.post_filter("terms", tags=["prague", "czech"]) + .script_fields(more_attendees="doc['attendees'].value + 42") + ) + + s.aggs.bucket("per_country", "terms", field="country").metric( + "avg_attendees", "avg", field="attendees" + ) + + s.query.minimum_should_match = 2 + + s = s.highlight_options(order="score").highlight("title", "body", fragment_size=50) + + assert { + "query": { + "bool": { + "filter": [ + { + "bool": { + "should": [ + {"term": {"category": "meetup"}}, + {"term": {"category": "conference"}}, + ] + } + } + ], + "must": [{"match": {"title": "python"}}], + "must_not": [{"match": {"title": "ruby"}}], + "minimum_should_match": 2, + } + }, + "post_filter": {"terms": {"tags": ["prague", "czech"]}}, + "aggs": { + "per_country": { + "terms": {"field": "country"}, + "aggs": {"avg_attendees": {"avg": {"field": "attendees"}}}, + } + }, + "collapse": {"field": "user_id"}, + "highlight": { + "order": "score", + "fields": {"title": {"fragment_size": 50}, "body": {"fragment_size": 50}}, + }, + "script_fields": {"more_attendees": {"script": "doc['attendees'].value + 42"}}, + } == s.to_dict() + + +def test_reverse(): + d = { + "query": { + "filtered": { + "filter": { + "bool": { + "should": [ + {"term": {"category": "meetup"}}, + {"term": {"category": "conference"}}, + ] + } + }, + "query": { + "bool": { + "must": [{"match": {"title": "python"}}], + "must_not": [{"match": {"title": "ruby"}}], + "minimum_should_match": 2, + } + }, + } + }, + "post_filter": {"bool": {"must": [{"terms": {"tags": ["prague", "czech"]}}]}}, + "aggs": { + "per_country": { + "terms": {"field": "country"}, + "aggs": {"avg_attendees": {"avg": {"field": "attendees"}}}, + } + }, + "sort": ["title", {"category": {"order": "desc"}}, "_score"], + "size": 5, + "highlight": {"order": "score", "fields": {"title": {"fragment_size": 50}}}, + "suggest": { + "my-title-suggestions-1": { + "text": "devloping distibutd saerch engies", + "term": {"size": 3, "field": "title"}, + } + }, + "script_fields": 
{"more_attendees": {"script": "doc['attendees'].value + 42"}}, + } + + d2 = deepcopy(d) + + s = AsyncSearch.from_dict(d) + + # make sure we haven't modified anything in place + assert d == d2 + assert {"size": 5} == s._extra + assert d == s.to_dict() + + +def test_from_dict_doesnt_need_query(): + s = AsyncSearch.from_dict({"size": 5}) + + assert {"size": 5} == s.to_dict() + + +async def test_params_being_passed_to_search(async_mock_client): + s = AsyncSearch(using="mock") + s = s.params(routing="42") + await s.execute() + + async_mock_client.search.assert_awaited_once_with(index=None, body={}, routing="42") + + +def test_source(): + assert {} == AsyncSearch().source().to_dict() + + assert { + "_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]} + } == AsyncSearch().source(includes=["foo.bar.*"], excludes=["foo.one"]).to_dict() + + assert {"_source": False} == AsyncSearch().source(False).to_dict() + + assert {"_source": ["f1", "f2"]} == AsyncSearch().source( + includes=["foo.bar.*"], excludes=["foo.one"] + ).source(["f1", "f2"]).to_dict() + + +def test_source_on_clone(): + assert { + "_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]}, + "query": {"bool": {"filter": [{"term": {"title": "python"}}]}}, + } == AsyncSearch().source(includes=["foo.bar.*"]).source( + excludes=["foo.one"] + ).filter( + "term", title="python" + ).to_dict() + assert { + "_source": False, + "query": {"bool": {"filter": [{"term": {"title": "python"}}]}}, + } == AsyncSearch().source(False).filter("term", title="python").to_dict() + + +def test_source_on_clear(): + assert ( + {} + == AsyncSearch() + .source(includes=["foo.bar.*"]) + .source(includes=None, excludes=None) + .to_dict() + ) + + +def test_suggest_accepts_global_text(): + s = AsyncSearch.from_dict( + { + "suggest": { + "text": "the amsterdma meetpu", + "my-suggest-1": {"term": {"field": "title"}}, + "my-suggest-2": {"text": "other", "term": {"field": "body"}}, + } + } + ) + + assert { + "suggest": { + 
"my-suggest-1": { + "term": {"field": "title"}, + "text": "the amsterdma meetpu", + }, + "my-suggest-2": {"term": {"field": "body"}, "text": "other"}, + } + } == s.to_dict() + + +def test_suggest(): + s = AsyncSearch() + s = s.suggest("my_suggestion", "pyhton", term={"field": "title"}) + + assert { + "suggest": {"my_suggestion": {"term": {"field": "title"}, "text": "pyhton"}} + } == s.to_dict() + + +def test_exclude(): + s = AsyncSearch() + s = s.exclude("match", title="python") + + assert { + "query": { + "bool": { + "filter": [{"bool": {"must_not": [{"match": {"title": "python"}}]}}] + } + } + } == s.to_dict() + + +async def test_delete_by_query(async_mock_client): + s = AsyncSearch(using="mock").query("match", lang="java") + await s.delete() + + async_mock_client.delete_by_query.assert_awaited_once_with( + index=None, body={"query": {"match": {"lang": "java"}}} + ) + + +def test_update_from_dict(): + s = AsyncSearch() + s.update_from_dict({"indices_boost": [{"important-documents": 2}]}) + s.update_from_dict({"_source": ["id", "name"]}) + s.update_from_dict({"collapse": {"field": "user_id"}}) + + assert { + "indices_boost": [{"important-documents": 2}], + "_source": ["id", "name"], + "collapse": {"field": "user_id"}, + } == s.to_dict() + + +def test_rescore_query_to_dict(): + s = AsyncSearch(index="index-name") + + positive_query = Q( + "function_score", + query=Q("term", tags="a"), + script_score={"script": "_score * 1"}, + ) + + negative_query = Q( + "function_score", + query=Q("term", tags="b"), + script_score={"script": "_score * -100"}, + ) + + s = s.query(positive_query) + s = s.extra( + rescore={"window_size": 100, "query": {"rescore_query": negative_query}} + ) + assert s.to_dict() == { + "query": { + "function_score": { + "query": {"term": {"tags": "a"}}, + "functions": [{"script_score": {"script": "_score * 1"}}], + } + }, + "rescore": { + "window_size": 100, + "query": { + "rescore_query": { + "function_score": { + "query": {"term": {"tags": "b"}}, + 
"functions": [{"script_score": {"script": "_score * -100"}}], + } + } + }, + }, + } + + assert s.to_dict( + rescore={"window_size": 10, "query": {"rescore_query": positive_query}} + ) == { + "query": { + "function_score": { + "query": {"term": {"tags": "a"}}, + "functions": [{"script_score": {"script": "_score * 1"}}], + } + }, + "rescore": { + "window_size": 10, + "query": { + "rescore_query": { + "function_score": { + "query": {"term": {"tags": "a"}}, + "functions": [{"script_score": {"script": "_score * 1"}}], + } + } + }, + }, + } diff --git a/tests/_async/test_update_by_query.py b/tests/_async/test_update_by_query.py new file mode 100644 index 000000000..5da7fde97 --- /dev/null +++ b/tests/_async/test_update_by_query.py @@ -0,0 +1,171 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from copy import deepcopy + +from elasticsearch_dsl import AsyncUpdateByQuery, Q +from elasticsearch_dsl.response import UpdateByQueryResponse + + +def test_ubq_starts_with_no_query(): + ubq = AsyncUpdateByQuery() + + assert ubq.query._proxied is None + + +def test_ubq_to_dict(): + ubq = AsyncUpdateByQuery() + assert {} == ubq.to_dict() + + ubq = ubq.query("match", f=42) + assert {"query": {"match": {"f": 42}}} == ubq.to_dict() + + assert {"query": {"match": {"f": 42}}, "size": 10} == ubq.to_dict(size=10) + + ubq = AsyncUpdateByQuery(extra={"size": 5}) + assert {"size": 5} == ubq.to_dict() + + ubq = AsyncUpdateByQuery(extra={"extra_q": Q("term", category="conference")}) + assert {"extra_q": {"term": {"category": "conference"}}} == ubq.to_dict() + + +def test_complex_example(): + ubq = AsyncUpdateByQuery() + ubq = ( + ubq.query("match", title="python") + .query(~Q("match", title="ruby")) + .filter(Q("term", category="meetup") | Q("term", category="conference")) + .script( + source="ctx._source.likes += params.f", lang="painless", params={"f": 3} + ) + ) + + ubq.query.minimum_should_match = 2 + assert { + "query": { + "bool": { + "filter": [ + { + "bool": { + "should": [ + {"term": {"category": "meetup"}}, + {"term": {"category": "conference"}}, + ] + } + } + ], + "must": [{"match": {"title": "python"}}], + "must_not": [{"match": {"title": "ruby"}}], + "minimum_should_match": 2, + } + }, + "script": { + "source": "ctx._source.likes += params.f", + "lang": "painless", + "params": {"f": 3}, + }, + } == ubq.to_dict() + + +def test_exclude(): + ubq = AsyncUpdateByQuery() + ubq = ubq.exclude("match", title="python") + + assert { + "query": { + "bool": { + "filter": [{"bool": {"must_not": [{"match": {"title": "python"}}]}}] + } + } + } == ubq.to_dict() + + +def test_reverse(): + d = { + "query": { + "filtered": { + "filter": { + "bool": { + "should": [ + {"term": {"category": "meetup"}}, + {"term": {"category": "conference"}}, + ] + } + }, + "query": { + "bool": { + 
"must": [{"match": {"title": "python"}}], + "must_not": [{"match": {"title": "ruby"}}], + "minimum_should_match": 2, + } + }, + } + }, + "script": { + "source": "ctx._source.likes += params.f", + "lang": "painless", + "params": {"f": 3}, + }, + } + + d2 = deepcopy(d) + + ubq = AsyncUpdateByQuery.from_dict(d) + + assert d == d2 + assert d == ubq.to_dict() + + +def test_from_dict_doesnt_need_query(): + ubq = AsyncUpdateByQuery.from_dict({"script": {"source": "test"}}) + + assert {"script": {"source": "test"}} == ubq.to_dict() + + +async def test_params_being_passed_to_search(async_mock_client): + ubq = AsyncUpdateByQuery(using="mock") + ubq = ubq.params(routing="42") + await ubq.execute() + + async_mock_client.update_by_query.assert_called_once_with(index=None, routing="42") + + +def test_overwrite_script(): + ubq = AsyncUpdateByQuery() + ubq = ubq.script( + source="ctx._source.likes += params.f", lang="painless", params={"f": 3} + ) + assert { + "script": { + "source": "ctx._source.likes += params.f", + "lang": "painless", + "params": {"f": 3}, + } + } == ubq.to_dict() + ubq = ubq.script(source="ctx._source.likes++") + assert {"script": {"source": "ctx._source.likes++"}} == ubq.to_dict() + + +def test_update_by_query_response_success(): + ubqr = UpdateByQueryResponse({}, {"timed_out": False, "failures": []}) + assert ubqr.success() + + ubqr = UpdateByQueryResponse({}, {"timed_out": True, "failures": []}) + assert not ubqr.success() + + ubqr = UpdateByQueryResponse({}, {"timed_out": False, "failures": [{}]}) + assert not ubqr.success() diff --git a/tests/_sync/__init__.py b/tests/_sync/__init__.py new file mode 100644 index 000000000..2a87d183f --- /dev/null +++ b/tests/_sync/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. 
licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/test_document.py b/tests/_sync/test_document.py similarity index 92% rename from tests/test_document.py rename to tests/_sync/test_document.py index 3660d9bec..79c266ad0 100644 --- a/tests/test_document.py +++ b/tests/_sync/test_document.py @@ -24,12 +24,13 @@ from pytest import raises from elasticsearch_dsl import ( + Document, Index, InnerDoc, Mapping, + MetaField, Range, analyzer, - document, field, utils, ) @@ -40,7 +41,7 @@ class MyInner(InnerDoc): old_field = field.Text() -class MyDoc(document.Document): +class MyDoc(Document): title = field.Keyword() name = field.Text() created_at = field.Date() @@ -54,7 +55,7 @@ class Index: name = "default-index" -class MyDoc2(document.Document): +class MyDoc2(Document): extra = field.Long() @@ -62,19 +63,19 @@ class MyMultiSubDoc(MyDoc2, MySubDoc): pass -class Comment(document.InnerDoc): +class Comment(InnerDoc): title = field.Text() tags = field.Keyword(multi=True) -class DocWithNested(document.Document): +class DocWithNested(Document): comments = field.Nested(Comment) class Index: name = "test-doc-with-nested" -class SimpleCommit(document.Document): +class SimpleCommit(Document): files = field.Text(multi=True) class Index: @@ -97,28 +98,28 @@ def _deserialize(self, data): return Secret(codecs.decode(data, "rot_13")) -class SecretDoc(document.Document): +class SecretDoc(Document): title = SecretField(index="no") class Index: name = 
"test-secret-doc" -class NestedSecret(document.Document): +class NestedSecret(Document): secrets = field.Nested(SecretDoc) class Index: name = "test-nested-secret" -class OptionalObjectWithRequiredField(document.Document): +class OptionalObjectWithRequiredField(Document): comments = field.Nested(properties={"title": field.Keyword(required=True)}) class Index: name = "test-required" -class Host(document.Document): +class Host(Document): ip = field.Ip() class Index: @@ -126,7 +127,7 @@ class Index: def test_range_serializes_properly(): - class D(document.Document): + class D(Document): lr = field.LongRange() d = D(lr=Range(lt=42)) @@ -139,7 +140,7 @@ class D(document.Document): def test_range_deserializes_properly(): - class D(document.InnerDoc): + class D(InnerDoc): lr = field.LongRange() d = D.from_es({"lr": {"lt": 42}}, True) @@ -155,10 +156,10 @@ def test_resolve_nested(): def test_conflicting_mapping_raises_error_in_index_to_dict(): - class A(document.Document): + class A(Document): name = field.Text() - class B(document.Document): + class B(Document): name = field.Keyword() i = Index("i") @@ -181,7 +182,7 @@ def test_matches_uses_index(): def test_matches_with_no_name_always_matches(): - class D(document.Document): + class D(Document): pass assert D._matches({}) @@ -189,7 +190,7 @@ class D(document.Document): def test_matches_accepts_wildcards(): - class MyDoc(document.Document): + class MyDoc(Document): class Index: name = "my-*" @@ -347,14 +348,14 @@ def test_meta_is_accessible_even_on_empty_doc(): def test_meta_field_mapping(): - class User(document.Document): + class User(Document): username = field.Text() class Meta: - all = document.MetaField(enabled=False) - _index = document.MetaField(enabled=True) - dynamic = document.MetaField("strict") - dynamic_templates = document.MetaField([42]) + all = MetaField(enabled=False) + _index = MetaField(enabled=True) + dynamic = MetaField("strict") + dynamic_templates = MetaField([42]) assert { "properties": 
{"username": {"type": "text"}}, @@ -366,7 +367,7 @@ class Meta: def test_multi_value_fields(): - class Blog(document.Document): + class Blog(Document): tags = field.Keyword(multi=True) b = Blog() @@ -377,7 +378,7 @@ class Blog(document.Document): def test_docs_with_properties(): - class User(document.Document): + class User(Document): pwd_hash = field.Text() def check_password(self, pwd): @@ -449,7 +450,7 @@ def test_to_dict_ignores_empty_collections(): def test_declarative_mapping_definition(): - assert issubclass(MyDoc, document.Document) + assert issubclass(MyDoc, Document) assert hasattr(MyDoc, "_doc_type") assert { "properties": { @@ -462,7 +463,7 @@ def test_declarative_mapping_definition(): def test_you_can_supply_own_mapping_instance(): - class MyD(document.Document): + class MyD(Document): title = field.Text() class Meta: @@ -505,7 +506,7 @@ def test_invalid_date_will_raise_exception(): def test_document_inheritance(): assert issubclass(MySubDoc, MyDoc) - assert issubclass(MySubDoc, document.Document) + assert issubclass(MySubDoc, Document) assert hasattr(MySubDoc, "_doc_type") assert { "properties": { @@ -518,7 +519,7 @@ def test_document_inheritance(): def test_child_class_can_override_parent(): - class A(document.Document): + class A(Document): o = field.Object(dynamic=False, properties={"a": field.Text()}) class B(A): @@ -548,7 +549,7 @@ def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict(): def test_index_inheritance(): assert issubclass(MyMultiSubDoc, MySubDoc) assert issubclass(MyMultiSubDoc, MyDoc2) - assert issubclass(MyMultiSubDoc, document.Document) + assert issubclass(MyMultiSubDoc, Document) assert hasattr(MyMultiSubDoc, "_doc_type") assert hasattr(MyMultiSubDoc, "_index") assert { @@ -587,7 +588,7 @@ def test_update_no_fields(): md.update() -def test_search_with_custom_alias_and_index(mock_client): +def test_search_with_custom_alias_and_index(): search_object = MyDoc.search( using="staging", index=["custom_index1", "custom_index2"] 
) @@ -609,7 +610,7 @@ def test_from_es_respects_underscored_non_meta_fields(): }, } - class Company(document.Document): + class Company(Document): class Index: name = "test-company" diff --git a/tests/test_faceted_search.py b/tests/_sync/test_faceted_search.py similarity index 100% rename from tests/test_faceted_search.py rename to tests/_sync/test_faceted_search.py diff --git a/tests/test_index.py b/tests/_sync/test_index.py similarity index 98% rename from tests/test_index.py rename to tests/_sync/test_index.py index 420160e50..8bef36fa6 100644 --- a/tests/test_index.py +++ b/tests/_sync/test_index.py @@ -68,7 +68,7 @@ def test_cloned_index_has_copied_settings_and_using(): def test_cloned_index_has_analysis_attribute(): """ - Regression test for Issue #582 in which `Index.clone()` was not copying + Regression test for Issue #582 in which `AsyncIndex.clone()` was not copying over the `_analysis` attribute. """ client = object() diff --git a/tests/_sync/test_mapping.py b/tests/_sync/test_mapping.py new file mode 100644 index 000000000..500b5dde7 --- /dev/null +++ b/tests/_sync/test_mapping.py @@ -0,0 +1,222 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import json + +from elasticsearch_dsl import Keyword, Mapping, Nested, Text, analysis + + +def test_mapping_can_has_fields(): + m = Mapping() + m.field("name", "text").field("tags", "keyword") + + assert { + "properties": {"name": {"type": "text"}, "tags": {"type": "keyword"}} + } == m.to_dict() + + +def test_mapping_update_is_recursive(): + m1 = Mapping() + m1.field("title", "text") + m1.field("author", "object") + m1.field("author", "object", properties={"name": {"type": "text"}}) + m1.meta("_all", enabled=False) + m1.meta("dynamic", False) + + m2 = Mapping() + m2.field("published_from", "date") + m2.field("author", "object", properties={"email": {"type": "text"}}) + m2.field("title", "text") + m2.field("lang", "keyword") + m2.meta("_analyzer", path="lang") + + m1.update(m2, update_only=True) + + assert { + "_all": {"enabled": False}, + "_analyzer": {"path": "lang"}, + "dynamic": False, + "properties": { + "published_from": {"type": "date"}, + "title": {"type": "text"}, + "lang": {"type": "keyword"}, + "author": { + "type": "object", + "properties": {"name": {"type": "text"}, "email": {"type": "text"}}, + }, + }, + } == m1.to_dict() + + +def test_properties_can_iterate_over_all_the_fields(): + m = Mapping() + m.field("f1", "text", test_attr="f1", fields={"f2": Keyword(test_attr="f2")}) + m.field("f3", Nested(test_attr="f3", properties={"f4": Text(test_attr="f4")})) + + assert {"f1", "f2", "f3", "f4"} == { + f.test_attr for f in m.properties._collect_fields() + } + + +def test_mapping_can_collect_all_analyzers_and_normalizers(): + a1 = analysis.analyzer( + "my_analyzer1", + tokenizer="keyword", + filter=[ + "lowercase", + analysis.token_filter("my_filter1", "stop", stopwords=["a", "b"]), + ], + ) + a2 = analysis.analyzer("english") + a3 = analysis.analyzer("unknown_custom") + a4 = analysis.analyzer( + "my_analyzer2", + tokenizer=analysis.tokenizer("trigram", "nGram", min_gram=3, max_gram=3), + filter=[analysis.token_filter("my_filter2", "stop", stopwords=["c", 
"d"])], + ) + a5 = analysis.analyzer("my_analyzer3", tokenizer="keyword") + n1 = analysis.normalizer("my_normalizer1", filter=["lowercase"]) + n2 = analysis.normalizer( + "my_normalizer2", + filter=[ + "my_filter1", + "my_filter2", + analysis.token_filter("my_filter3", "stop", stopwords=["e", "f"]), + ], + ) + n3 = analysis.normalizer("unknown_custom") + + m = Mapping() + m.field( + "title", + "text", + analyzer=a1, + fields={"english": Text(analyzer=a2), "unknown": Keyword(search_analyzer=a3)}, + ) + m.field("comments", Nested(properties={"author": Text(analyzer=a4)})) + m.field("normalized_title", "keyword", normalizer=n1) + m.field("normalized_comment", "keyword", normalizer=n2) + m.field("unknown", "keyword", normalizer=n3) + m.meta("_all", analyzer=a5) + + assert { + "analyzer": { + "my_analyzer1": { + "filter": ["lowercase", "my_filter1"], + "tokenizer": "keyword", + "type": "custom", + }, + "my_analyzer2": { + "filter": ["my_filter2"], + "tokenizer": "trigram", + "type": "custom", + }, + "my_analyzer3": {"tokenizer": "keyword", "type": "custom"}, + }, + "normalizer": { + "my_normalizer1": {"filter": ["lowercase"], "type": "custom"}, + "my_normalizer2": { + "filter": ["my_filter1", "my_filter2", "my_filter3"], + "type": "custom", + }, + }, + "filter": { + "my_filter1": {"stopwords": ["a", "b"], "type": "stop"}, + "my_filter2": {"stopwords": ["c", "d"], "type": "stop"}, + "my_filter3": {"stopwords": ["e", "f"], "type": "stop"}, + }, + "tokenizer": {"trigram": {"max_gram": 3, "min_gram": 3, "type": "nGram"}}, + } == m._collect_analysis() + + assert json.loads(json.dumps(m.to_dict())) == m.to_dict() + + +def test_mapping_can_collect_multiple_analyzers(): + a1 = analysis.analyzer( + "my_analyzer1", + tokenizer="keyword", + filter=[ + "lowercase", + analysis.token_filter("my_filter1", "stop", stopwords=["a", "b"]), + ], + ) + a2 = analysis.analyzer( + "my_analyzer2", + tokenizer=analysis.tokenizer("trigram", "nGram", min_gram=3, max_gram=3), + 
filter=[analysis.token_filter("my_filter2", "stop", stopwords=["c", "d"])], + ) + m = Mapping() + m.field("title", "text", analyzer=a1, search_analyzer=a2) + m.field( + "text", + "text", + analyzer=a1, + fields={ + "english": Text(analyzer=a1), + "unknown": Keyword(analyzer=a1, search_analyzer=a2), + }, + ) + assert { + "analyzer": { + "my_analyzer1": { + "filter": ["lowercase", "my_filter1"], + "tokenizer": "keyword", + "type": "custom", + }, + "my_analyzer2": { + "filter": ["my_filter2"], + "tokenizer": "trigram", + "type": "custom", + }, + }, + "filter": { + "my_filter1": {"stopwords": ["a", "b"], "type": "stop"}, + "my_filter2": {"stopwords": ["c", "d"], "type": "stop"}, + }, + "tokenizer": {"trigram": {"max_gram": 3, "min_gram": 3, "type": "nGram"}}, + } == m._collect_analysis() + + +def test_even_non_custom_analyzers_can_have_params(): + a1 = analysis.analyzer("whitespace", type="pattern", pattern=r"\\s+") + m = Mapping() + m.field("title", "text", analyzer=a1) + + assert { + "analyzer": {"whitespace": {"type": "pattern", "pattern": r"\\s+"}} + } == m._collect_analysis() + + +def test_resolve_field_can_resolve_multifields(): + m = Mapping() + m.field("title", "text", fields={"keyword": Keyword()}) + + assert isinstance(m.resolve_field("title.keyword"), Keyword) + + +def test_resolve_nested(): + m = Mapping() + m.field("n1", "nested", properties={"n2": Nested(properties={"k1": Keyword()})}) + m.field("k2", "keyword") + + nested, field = m.resolve_nested("n1.n2.k1") + assert nested == ["n1", "n1.n2"] + assert isinstance(field, Keyword) + + nested, field = m.resolve_nested("k2") + assert nested == [] + assert isinstance(field, Keyword) diff --git a/tests/test_search.py b/tests/_sync/test_search.py similarity index 88% rename from tests/test_search.py rename to tests/_sync/test_search.py index 841caa7cc..306b09f53 100644 --- a/tests/test_search.py +++ b/tests/_sync/test_search.py @@ -19,18 +19,18 @@ from pytest import raises -from elasticsearch_dsl import 
Document, Q, query, search +from elasticsearch_dsl import Document, Q, Search, query from elasticsearch_dsl.exceptions import IllegalOperation def test_expand__to_dot_is_respected(): - s = search.Search().query("match", a__b=42, _expand__to_dot=False) + s = Search().query("match", a__b=42, _expand__to_dot=False) assert {"query": {"match": {"a__b": 42}}} == s.to_dict() def test_execute_uses_cache(): - s = search.Search() + s = Search() r = object() s._response = r @@ -38,7 +38,7 @@ def test_execute_uses_cache(): def test_cache_can_be_ignored(mock_client): - s = search.Search(using="mock") + s = Search(using="mock") r = object() s._response = r s.execute(ignore_cache=True) @@ -47,27 +47,27 @@ def test_cache_can_be_ignored(mock_client): def test_iter_iterates_over_hits(): - s = search.Search() + s = Search() s._response = [1, 2, 3] - assert [1, 2, 3] == list(s) + assert [1, 2, 3] == [hit for hit in s] def test_cache_isnt_cloned(): - s = search.Search() + s = Search() s._response = object() assert not hasattr(s._clone(), "_response") def test_search_starts_with_no_query(): - s = search.Search() + s = Search() assert s.query._proxied is None def test_search_query_combines_query(): - s = search.Search() + s = Search() s2 = s.query("match", f=42) assert s2.query._proxied == query.Match(f=42) @@ -79,7 +79,7 @@ def test_search_query_combines_query(): def test_query_can_be_assigned_to(): - s = search.Search() + s = Search() q = Q("match", title="python") s.query = q @@ -88,7 +88,7 @@ def test_query_can_be_assigned_to(): def test_query_can_be_wrapped(): - s = search.Search().query("match", title="python") + s = Search().query("match", title="python") s.query = Q("function_score", query=s.query, field_value_factor={"field": "rating"}) @@ -105,7 +105,7 @@ def test_query_can_be_wrapped(): def test_using(): o = object() o2 = object() - s = search.Search(using=o) + s = Search(using=o) assert s._using is o s2 = s.using(o2) assert s._using is o @@ -113,27 +113,27 @@ def 
test_using(): def test_methods_are_proxied_to_the_query(): - s = search.Search().query("match_all") + s = Search().query("match_all") assert s.query.to_dict() == {"match_all": {}} def test_query_always_returns_search(): - s = search.Search() + s = Search() - assert isinstance(s.query("match", f=42), search.Search) + assert isinstance(s.query("match", f=42), Search) def test_source_copied_on_clone(): - s = search.Search().source(False) + s = Search().source(False) assert s._clone()._source == s._source assert s._clone()._source is False - s2 = search.Search().source([]) + s2 = Search().source([]) assert s2._clone()._source == s2._source assert s2._source == [] - s3 = search.Search().source(["some", "fields"]) + s3 = Search().source(["some", "fields"]) assert s3._clone()._source == s3._source assert s3._clone()._source == ["some", "fields"] @@ -141,7 +141,7 @@ def test_source_copied_on_clone(): def test_copy_clones(): from copy import copy - s1 = search.Search().source(["some", "fields"]) + s1 = Search().source(["some", "fields"]) s2 = copy(s1) assert s1 == s2 @@ -149,7 +149,7 @@ def test_copy_clones(): def test_aggs_allow_two_metric(): - s = search.Search() + s = Search() s.aggs.metric("a", "max", field="a").metric("b", "max", field="b") @@ -159,7 +159,7 @@ def test_aggs_allow_two_metric(): def test_aggs_get_copied_on_change(): - s = search.Search().query("match_all") + s = Search().query("match_all") s.aggs.bucket("per_tag", "terms", field="f").metric( "max_score", "max", field="score" ) @@ -191,7 +191,7 @@ def test_aggs_get_copied_on_change(): def test_search_index(): - s = search.Search(index="i") + s = Search(index="i") assert s._index == ["i"] s = s.index("i2") assert s._index == ["i", "i2"] @@ -199,17 +199,17 @@ def test_search_index(): assert s._index == ["i", "i2", "i3"] s = s.index() assert s._index is None - s = search.Search(index=("i", "i2")) + s = Search(index=("i", "i2")) assert s._index == ["i", "i2"] - s = search.Search(index=["i", "i2"]) + s = 
Search(index=["i", "i2"]) assert s._index == ["i", "i2"] - s = search.Search() + s = Search() s = s.index("i", "i2") assert s._index == ["i", "i2"] s2 = s.index("i3") assert s._index == ["i", "i2"] assert s2._index == ["i", "i2", "i3"] - s = search.Search() + s = Search() s = s.index(["i", "i2"], "i3") assert s._index == ["i", "i2", "i3"] s2 = s.index("i4") @@ -225,17 +225,17 @@ def test_doc_type_document_class(): class MyDocument(Document): pass - s = search.Search(doc_type=MyDocument) + s = Search(doc_type=MyDocument) assert s._doc_type == [MyDocument] assert s._doc_type_map == {} - s = search.Search().doc_type(MyDocument) + s = Search().doc_type(MyDocument) assert s._doc_type == [MyDocument] assert s._doc_type_map == {} def test_knn(): - s = search.Search() + s = Search() with raises(TypeError): s.knn() @@ -289,7 +289,7 @@ def test_knn(): def test_rank(): - s = search.Search() + s = Search() s.rank(rrf=False) assert {} == s.to_dict() @@ -301,7 +301,7 @@ def test_rank(): def test_sort(): - s = search.Search() + s = Search() s = s.sort("fielda", "-fieldb") assert ["fielda", {"fieldb": {"order": "desc"}}] == s._sort @@ -309,21 +309,21 @@ def test_sort(): s = s.sort() assert [] == s._sort - assert search.Search().to_dict() == s.to_dict() + assert Search().to_dict() == s.to_dict() def test_sort_by_score(): - s = search.Search() + s = Search() s = s.sort("_score") assert {"sort": ["_score"]} == s.to_dict() - s = search.Search() + s = Search() with raises(IllegalOperation): s.sort("-_score") def test_collapse(): - s = search.Search() + s = Search() inner_hits = {"name": "most_recent", "size": 5, "sort": [{"@timestamp": "desc"}]} s = s.collapse("user.id", inner_hits=inner_hits, max_concurrent_group_searches=4) @@ -351,11 +351,11 @@ def test_collapse(): s = s.collapse() assert {} == s._collapse - assert search.Search().to_dict() == s.to_dict() + assert Search().to_dict() == s.to_dict() def test_slice(): - s = search.Search() + s = Search() assert {"from": 3, "size": 7} 
== s[3:10].to_dict() assert {"from": 0, "size": 5} == s[:5].to_dict() assert {"from": 3, "size": 10} == s[3:].to_dict() @@ -364,12 +364,12 @@ def test_slice(): def test_index(): - s = search.Search() + s = Search() assert {"from": 3, "size": 1} == s[3].to_dict() def test_search_to_dict(): - s = search.Search() + s = Search() assert {} == s.to_dict() s = s.query("match", f=42) @@ -391,14 +391,14 @@ def test_search_to_dict(): } assert d == s.to_dict() - s = search.Search(extra={"size": 5}) + s = Search(extra={"size": 5}) assert {"size": 5} == s.to_dict() s = s.extra(from_=42) assert {"size": 5, "from": 42} == s.to_dict() def test_complex_example(): - s = search.Search() + s = Search() s = ( s.query("match", title="python") .query(~Q("match", title="ruby")) @@ -492,7 +492,7 @@ def test_reverse(): d2 = deepcopy(d) - s = search.Search.from_dict(d) + s = Search.from_dict(d) # make sure we haven't modified anything in place assert d == d2 @@ -501,13 +501,13 @@ def test_reverse(): def test_from_dict_doesnt_need_query(): - s = search.Search.from_dict({"size": 5}) + s = Search.from_dict({"size": 5}) assert {"size": 5} == s.to_dict() def test_params_being_passed_to_search(mock_client): - s = search.Search(using="mock") + s = Search(using="mock") s = s.params(routing="42") s.execute() @@ -515,15 +515,15 @@ def test_params_being_passed_to_search(mock_client): def test_source(): - assert {} == search.Search().source().to_dict() + assert {} == Search().source().to_dict() assert { "_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]} - } == search.Search().source(includes=["foo.bar.*"], excludes=["foo.one"]).to_dict() + } == Search().source(includes=["foo.bar.*"], excludes=["foo.one"]).to_dict() - assert {"_source": False} == search.Search().source(False).to_dict() + assert {"_source": False} == Search().source(False).to_dict() - assert {"_source": ["f1", "f2"]} == search.Search().source( + assert {"_source": ["f1", "f2"]} == Search().source( includes=["foo.bar.*"], 
excludes=["foo.one"] ).source(["f1", "f2"]).to_dict() @@ -532,21 +532,19 @@ def test_source_on_clone(): assert { "_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]}, "query": {"bool": {"filter": [{"term": {"title": "python"}}]}}, - } == search.Search().source(includes=["foo.bar.*"]).source( - excludes=["foo.one"] - ).filter( + } == Search().source(includes=["foo.bar.*"]).source(excludes=["foo.one"]).filter( "term", title="python" ).to_dict() assert { "_source": False, "query": {"bool": {"filter": [{"term": {"title": "python"}}]}}, - } == search.Search().source(False).filter("term", title="python").to_dict() + } == Search().source(False).filter("term", title="python").to_dict() def test_source_on_clear(): assert ( {} - == search.Search() + == Search() .source(includes=["foo.bar.*"]) .source(includes=None, excludes=None) .to_dict() @@ -554,7 +552,7 @@ def test_source_on_clear(): def test_suggest_accepts_global_text(): - s = search.Search.from_dict( + s = Search.from_dict( { "suggest": { "text": "the amsterdma meetpu", @@ -576,7 +574,7 @@ def test_suggest_accepts_global_text(): def test_suggest(): - s = search.Search() + s = Search() s = s.suggest("my_suggestion", "pyhton", term={"field": "title"}) assert { @@ -585,7 +583,7 @@ def test_suggest(): def test_exclude(): - s = search.Search() + s = Search() s = s.exclude("match", title="python") assert { @@ -598,7 +596,7 @@ def test_exclude(): def test_delete_by_query(mock_client): - s = search.Search(using="mock").query("match", lang="java") + s = Search(using="mock").query("match", lang="java") s.delete() mock_client.delete_by_query.assert_called_once_with( @@ -607,7 +605,7 @@ def test_delete_by_query(mock_client): def test_update_from_dict(): - s = search.Search() + s = Search() s.update_from_dict({"indices_boost": [{"important-documents": 2}]}) s.update_from_dict({"_source": ["id", "name"]}) s.update_from_dict({"collapse": {"field": "user_id"}}) @@ -620,7 +618,7 @@ def test_update_from_dict(): def 
test_rescore_query_to_dict(): - s = search.Search(index="index-name") + s = Search(index="index-name") positive_query = Q( "function_score", diff --git a/tests/test_update_by_query.py b/tests/_sync/test_update_by_query.py similarity index 100% rename from tests/test_update_by_query.py rename to tests/_sync/test_update_by_query.py diff --git a/tests/conftest.py b/tests/conftest.py index 0e5e082de..ea7ff6cbd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,21 +16,27 @@ # under the License. +import asyncio import os import re import time from datetime import datetime from unittest import SkipTest, TestCase -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock +import pytest_asyncio from elastic_transport import ObjectApiResponse -from elasticsearch import Elasticsearch +from elasticsearch import AsyncElasticsearch, Elasticsearch from elasticsearch.exceptions import ConnectionError from elasticsearch.helpers import bulk from pytest import fixture, skip +from elasticsearch_dsl.async_connections import add_connection as add_async_connection +from elasticsearch_dsl.async_connections import connections as async_connections from elasticsearch_dsl.connections import add_connection, connections +from .test_integration._async import test_document as async_document +from .test_integration._sync import test_document as sync_document from .test_integration.test_data import ( DATA, FLAT_DATA, @@ -38,7 +44,6 @@ create_flat_git_index, create_git_index, ) -from .test_integration.test_document import Comment, History, PullRequest, User if "ELASTICSEARCH_URL" in os.environ: ELASTICSEARCH_URL = os.environ["ELASTICSEARCH_URL"] @@ -73,6 +78,34 @@ def get_test_client(wait=True, **kwargs): raise SkipTest("Elasticsearch failed to start.") +async def get_async_test_client(wait=True, **kwargs): + # construct kwargs from the environment + kw = {"request_timeout": 30} + + if "PYTHON_CONNECTION_CLASS" in os.environ: + from elasticsearch import connection + + 
kw["connection_class"] = getattr( + connection, os.environ["PYTHON_CONNECTION_CLASS"] + ) + + kw.update(kwargs) + client = AsyncElasticsearch(ELASTICSEARCH_URL, **kw) + + # wait for yellow status + for tries_left in range(100 if wait else 1, 0, -1): + try: + await client.cluster.health(wait_for_status="yellow") + return client + except ConnectionError: + if wait and tries_left == 1: + raise + await asyncio.sleep(0.1) + + await client.close() + raise SkipTest("Elasticsearch failed to start.") + + class ElasticsearchTestCase(TestCase): @staticmethod def _get_client(): @@ -119,6 +152,17 @@ def client(): skip() +@pytest_asyncio.fixture +async def async_client(): + try: + connection = await get_async_test_client(wait="WAIT_FOR_ES" in os.environ) + add_async_connection("default", connection) + yield connection + await connection.close() + except SkipTest: + skip() + + @fixture(scope="session") def es_version(client): info = client.info() @@ -137,16 +181,36 @@ def write_client(client): client.options(ignore_status=404).indices.delete_template(name="test-template") +@pytest_asyncio.fixture +async def async_write_client(write_client, async_client): + yield async_client + + @fixture def mock_client(dummy_response): client = Mock() client.search.return_value = dummy_response add_connection("mock", client) + yield client connections._conn = {} connections._kwargs = {} +@fixture +def async_mock_client(dummy_response): + client = Mock() + client.search = AsyncMock(return_value=dummy_response) + client.indices = AsyncMock() + client.update_by_query = AsyncMock() + client.delete_by_query = AsyncMock() + add_async_connection("mock", client) + + yield client + async_connections._conn = {} + async_connections._kwargs = {} + + @fixture(scope="session") def data_client(client): # create mappings @@ -160,6 +224,11 @@ def data_client(client): client.indices.delete(index="flat-git") +@pytest_asyncio.fixture +async def async_data_client(data_client, async_client): + yield async_client + + 
@fixture def dummy_response(): return ObjectApiResponse( @@ -367,18 +436,16 @@ def aggs_data(): } -@fixture -def pull_request(write_client): - PullRequest.init() - pr = PullRequest( +def make_pr(pr_module): + return pr_module.PullRequest( _id=42, comments=[ - Comment( + pr_module.Comment( content="Hello World!", - author=User(name="honzakral"), + author=pr_module.User(name="honzakral"), created_at=datetime(2018, 1, 9, 10, 17, 3, 21184), history=[ - History( + pr_module.History( timestamp=datetime(2012, 1, 1), diff="-Ahoj Svete!\n+Hello World!", ) @@ -387,10 +454,24 @@ def pull_request(write_client): ], created_at=datetime(2018, 1, 9, 9, 17, 3, 21184), ) + + +@fixture +def pull_request(write_client): + sync_document.PullRequest.init() + pr = make_pr(sync_document) pr.save(refresh=True) return pr +@pytest_asyncio.fixture +async def async_pull_request(async_write_client): + await async_document.PullRequest.init() + pr = make_pr(async_document) + await pr.save(refresh=True) + return pr + + @fixture def setup_ubq_tests(client): index = "test-git" diff --git a/tests/test_integration/_async/__init__.py b/tests/test_integration/_async/__init__.py new file mode 100644 index 000000000..2a87d183f --- /dev/null +++ b/tests/test_integration/_async/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/test_integration/_async/test_analysis.py b/tests/test_integration/_async/test_analysis.py new file mode 100644 index 000000000..159b68d3e --- /dev/null +++ b/tests/test_integration/_async/test_analysis.py @@ -0,0 +1,46 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from elasticsearch_dsl import analyzer, token_filter, tokenizer + + +async def test_simulate_with_just__builtin_tokenizer(async_client): + a = analyzer("my-analyzer", tokenizer="keyword") + tokens = (await a.async_simulate("Hello World!", using=async_client)).tokens + + assert len(tokens) == 1 + assert tokens[0].token == "Hello World!" 
+ + +async def test_simulate_complex(async_client): + a = analyzer( + "my-analyzer", + tokenizer=tokenizer("split_words", "simple_pattern_split", pattern=":"), + filter=["lowercase", token_filter("no-ifs", "stop", stopwords=["if"])], + ) + + tokens = (await a.async_simulate("if:this:works", using=async_client)).tokens + + assert len(tokens) == 2 + assert ["this", "works"] == [t.token for t in tokens] + + +async def test_simulate_builtin(async_client): + a = analyzer("my-analyzer", "english") + tokens = (await a.async_simulate("fixes running")).tokens + + assert ["fix", "run"] == [t.token for t in tokens] diff --git a/tests/test_integration/_async/test_document.py b/tests/test_integration/_async/test_document.py new file mode 100644 index 000000000..6bb94aeaa --- /dev/null +++ b/tests/test_integration/_async/test_document.py @@ -0,0 +1,574 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from datetime import datetime +from ipaddress import ip_address + +import pytest +from elasticsearch import ConflictError, NotFoundError +from pytest import raises +from pytz import timezone + +from elasticsearch_dsl import ( + AsyncDocument, + Binary, + Boolean, + Date, + Double, + InnerDoc, + Ip, + Keyword, + Long, + Mapping, + MetaField, + Nested, + Object, + Q, + RankFeatures, + Text, + analyzer, +) +from elasticsearch_dsl.utils import AttrList + +snowball = analyzer("my_snow", tokenizer="standard", filter=["lowercase", "snowball"]) + + +class User(InnerDoc): + name = Text(fields={"raw": Keyword()}) + + +class Wiki(AsyncDocument): + owner = Object(User) + views = Long() + ranked = RankFeatures() + + class Index: + name = "test-wiki" + + +class Repository(AsyncDocument): + owner = Object(User) + created_at = Date() + description = Text(analyzer=snowball) + tags = Keyword() + + @classmethod + def search(cls): + return super().search().filter("term", commit_repo="repo") + + class Index: + name = "git" + + +class Commit(AsyncDocument): + committed_date = Date() + authored_date = Date() + description = Text(analyzer=snowball) + + class Index: + name = "flat-git" + + class Meta: + mapping = Mapping() + + +class History(InnerDoc): + timestamp = Date() + diff = Text() + + +class Comment(InnerDoc): + content = Text() + created_at = Date() + author = Object(User) + history = Nested(History) + + class Meta: + dynamic = MetaField(False) + + +class PullRequest(AsyncDocument): + comments = Nested(Comment) + created_at = Date() + + class Index: + name = "test-prs" + + +class SerializationDoc(AsyncDocument): + i = Long() + b = Boolean() + d = Double() + bin = Binary() + ip = Ip() + + class Index: + name = "test-serialization" + + +async def test_serialization(async_write_client): + await SerializationDoc.init() + await async_write_client.index( + index="test-serialization", + id=42, + body={ + "i": [1, 2, "3", None], + "b": [True, False, "true", "false", None], + "d": [0.1, 
"-0.1", None], + "bin": ["SGVsbG8gV29ybGQ=", None], + "ip": ["::1", "127.0.0.1", None], + }, + ) + sd = await SerializationDoc.get(id=42) + + assert sd.i == [1, 2, 3, None] + assert sd.b == [True, False, True, False, None] + assert sd.d == [0.1, -0.1, None] + assert sd.bin == [b"Hello World", None] + assert sd.ip == [ip_address("::1"), ip_address("127.0.0.1"), None] + + assert sd.to_dict() == { + "b": [True, False, True, False, None], + "bin": ["SGVsbG8gV29ybGQ=", None], + "d": [0.1, -0.1, None], + "i": [1, 2, 3, None], + "ip": ["::1", "127.0.0.1", None], + } + + +async def test_nested_inner_hits_are_wrapped_properly(async_pull_request): + history_query = Q( + "nested", + path="comments.history", + inner_hits={}, + query=Q("match", comments__history__diff="ahoj"), + ) + s = PullRequest.search().query( + "nested", inner_hits={}, path="comments", query=history_query + ) + + response = await s.execute() + pr = response.hits[0] + assert isinstance(pr, PullRequest) + assert isinstance(pr.comments[0], Comment) + assert isinstance(pr.comments[0].history[0], History) + + comment = pr.meta.inner_hits.comments.hits[0] + assert isinstance(comment, Comment) + assert comment.author.name == "honzakral" + assert isinstance(comment.history[0], History) + + history = comment.meta.inner_hits["comments.history"].hits[0] + assert isinstance(history, History) + assert history.timestamp == datetime(2012, 1, 1) + assert "score" in history.meta + + +async def test_nested_inner_hits_are_deserialized_properly(async_pull_request): + s = PullRequest.search().query( + "nested", + inner_hits={}, + path="comments", + query=Q("match", comments__content="hello"), + ) + + response = await s.execute() + pr = response.hits[0] + assert isinstance(pr.created_at, datetime) + assert isinstance(pr.comments[0], Comment) + assert isinstance(pr.comments[0].created_at, datetime) + + +async def test_nested_top_hits_are_wrapped_properly(async_pull_request): + s = PullRequest.search() + s.aggs.bucket("comments", 
"nested", path="comments").metric( + "hits", "top_hits", size=1 + ) + + r = await s.execute() + + print(r._d_) + assert isinstance(r.aggregations.comments.hits.hits[0], Comment) + + +async def test_update_object_field(async_write_client): + await Wiki.init() + w = Wiki( + owner=User(name="Honza Kral"), + _id="elasticsearch-py", + ranked={"test1": 0.1, "topic2": 0.2}, + ) + await w.save() + + assert "updated" == await w.update(owner=[{"name": "Honza"}, {"name": "Nick"}]) + assert w.owner[0].name == "Honza" + assert w.owner[1].name == "Nick" + + w = await Wiki.get(id="elasticsearch-py") + assert w.owner[0].name == "Honza" + assert w.owner[1].name == "Nick" + + assert w.ranked == {"test1": 0.1, "topic2": 0.2} + + +async def test_update_script(async_write_client): + await Wiki.init() + w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) + await w.save() + + await w.update(script="ctx._source.views += params.inc", inc=5) + w = await Wiki.get(id="elasticsearch-py") + assert w.views == 47 + + +async def test_update_retry_on_conflict(async_write_client): + await Wiki.init() + w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) + await w.save() + + w1 = await Wiki.get(id="elasticsearch-py") + w2 = await Wiki.get(id="elasticsearch-py") + await w1.update( + script="ctx._source.views += params.inc", inc=5, retry_on_conflict=1 + ) + await w2.update( + script="ctx._source.views += params.inc", inc=5, retry_on_conflict=1 + ) + + w = await Wiki.get(id="elasticsearch-py") + assert w.views == 52 + + +@pytest.mark.parametrize("retry_on_conflict", [None, 0]) +async def test_update_conflicting_version(async_write_client, retry_on_conflict): + await Wiki.init() + w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) + await w.save() + + w1 = await Wiki.get(id="elasticsearch-py") + w2 = await Wiki.get(id="elasticsearch-py") + await w1.update(script="ctx._source.views += params.inc", inc=5) + + with raises(ConflictError): + 
await w2.update( + script="ctx._source.views += params.inc", + inc=5, + retry_on_conflict=retry_on_conflict, + ) + + +async def test_save_and_update_return_doc_meta(async_write_client): + await Wiki.init() + w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) + resp = await w.save(return_doc_meta=True) + assert resp["_index"] == "test-wiki" + assert resp["result"] == "created" + assert set(resp.keys()) == { + "_id", + "_index", + "_primary_term", + "_seq_no", + "_shards", + "_version", + "result", + } + + resp = await w.update( + script="ctx._source.views += params.inc", inc=5, return_doc_meta=True + ) + assert resp["_index"] == "test-wiki" + assert resp["result"] == "updated" + assert set(resp.keys()) == { + "_id", + "_index", + "_primary_term", + "_seq_no", + "_shards", + "_version", + "result", + } + + +async def test_init(async_write_client): + await Repository.init(index="test-git") + + assert await async_write_client.indices.exists(index="test-git") + + +async def test_get_raises_404_on_index_missing(async_data_client): + with raises(NotFoundError): + await Repository.get("elasticsearch-dsl-php", index="not-there") + + +async def test_get_raises_404_on_non_existent_id(async_data_client): + with raises(NotFoundError): + await Repository.get("elasticsearch-dsl-php") + + +async def test_get_returns_none_if_404_ignored(async_data_client): + assert None is await Repository.get( + "elasticsearch-dsl-php", using=async_data_client.options(ignore_status=404) + ) + + +async def test_get_returns_none_if_404_ignored_and_index_doesnt_exist( + async_data_client, +): + assert None is await Repository.get( + "42", index="not-there", using=async_data_client.options(ignore_status=404) + ) + + +async def test_get(async_data_client): + elasticsearch_repo = await Repository.get("elasticsearch-dsl-py") + + assert isinstance(elasticsearch_repo, Repository) + assert elasticsearch_repo.owner.name == "elasticsearch" + assert datetime(2014, 3, 3) == 
elasticsearch_repo.created_at + + +async def test_exists_return_true(async_data_client): + assert await Repository.exists("elasticsearch-dsl-py") + + +async def test_exists_false(async_data_client): + assert not await Repository.exists("elasticsearch-dsl-php") + + +async def test_get_with_tz_date(async_data_client): + first_commit = await Commit.get( + id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="elasticsearch-dsl-py" + ) + + tzinfo = timezone("Europe/Prague") + assert ( + tzinfo.localize(datetime(2014, 5, 2, 13, 47, 19, 123000)) + == first_commit.authored_date + ) + + +async def test_save_with_tz_date(async_data_client): + tzinfo = timezone("Europe/Prague") + first_commit = await Commit.get( + id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="elasticsearch-dsl-py" + ) + first_commit.committed_date = tzinfo.localize( + datetime(2014, 5, 2, 13, 47, 19, 123456) + ) + await first_commit.save() + + first_commit = await Commit.get( + id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="elasticsearch-dsl-py" + ) + assert ( + tzinfo.localize(datetime(2014, 5, 2, 13, 47, 19, 123456)) + == first_commit.committed_date + ) + + +COMMIT_DOCS_WITH_MISSING = [ + {"_id": "0"}, # Missing + {"_id": "3ca6e1e73a071a705b4babd2f581c91a2a3e5037"}, # Existing + {"_id": "f"}, # Missing + {"_id": "eb3e543323f189fd7b698e66295427204fff5755"}, # Existing +] + + +async def test_mget(async_data_client): + commits = await Commit.mget(COMMIT_DOCS_WITH_MISSING) + assert commits[0] is None + assert commits[1].meta.id == "3ca6e1e73a071a705b4babd2f581c91a2a3e5037" + assert commits[2] is None + assert commits[3].meta.id == "eb3e543323f189fd7b698e66295427204fff5755" + + +async def test_mget_raises_exception_when_missing_param_is_invalid(async_data_client): + with raises(ValueError): + await Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="raj") + + +async def test_mget_raises_404_when_missing_param_is_raise(async_data_client): + with raises(NotFoundError): + await 
Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="raise") + + +async def test_mget_ignores_missing_docs_when_missing_param_is_skip(async_data_client): + commits = await Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="skip") + assert commits[0].meta.id == "3ca6e1e73a071a705b4babd2f581c91a2a3e5037" + assert commits[1].meta.id == "eb3e543323f189fd7b698e66295427204fff5755" + + +async def test_update_works_from_search_response(async_data_client): + elasticsearch_repo = (await Repository.search().execute())[0] + + await elasticsearch_repo.update(owner={"other_name": "elastic"}) + assert "elastic" == elasticsearch_repo.owner.other_name + + new_version = await Repository.get("elasticsearch-dsl-py") + assert "elastic" == new_version.owner.other_name + assert "elasticsearch" == new_version.owner.name + + +async def test_update(async_data_client): + elasticsearch_repo = await Repository.get("elasticsearch-dsl-py") + v = elasticsearch_repo.meta.version + + old_seq_no = elasticsearch_repo.meta.seq_no + await elasticsearch_repo.update( + owner={"new_name": "elastic"}, new_field="testing-update" + ) + + assert "elastic" == elasticsearch_repo.owner.new_name + assert "testing-update" == elasticsearch_repo.new_field + + # assert version has been updated + assert elasticsearch_repo.meta.version == v + 1 + + new_version = await Repository.get("elasticsearch-dsl-py") + assert "testing-update" == new_version.new_field + assert "elastic" == new_version.owner.new_name + assert "elasticsearch" == new_version.owner.name + assert "seq_no" in new_version.meta + assert new_version.meta.seq_no != old_seq_no + assert "primary_term" in new_version.meta + + +async def test_save_updates_existing_doc(async_data_client): + elasticsearch_repo = await Repository.get("elasticsearch-dsl-py") + + elasticsearch_repo.new_field = "testing-save" + old_seq_no = elasticsearch_repo.meta.seq_no + assert "updated" == await elasticsearch_repo.save() + + new_repo = await async_data_client.get(index="git", 
id="elasticsearch-dsl-py") + assert "testing-save" == new_repo["_source"]["new_field"] + assert new_repo["_seq_no"] != old_seq_no + assert new_repo["_seq_no"] == elasticsearch_repo.meta.seq_no + + +async def test_save_automatically_uses_seq_no_and_primary_term(async_data_client): + elasticsearch_repo = await Repository.get("elasticsearch-dsl-py") + elasticsearch_repo.meta.seq_no += 1 + + with raises(ConflictError): + await elasticsearch_repo.save() + + +async def test_delete_automatically_uses_seq_no_and_primary_term(async_data_client): + elasticsearch_repo = await Repository.get("elasticsearch-dsl-py") + elasticsearch_repo.meta.seq_no += 1 + + with raises(ConflictError): + await elasticsearch_repo.delete() + + +def assert_doc_equals(expected, actual): + for f in expected: + assert f in actual + assert actual[f] == expected[f] + + +async def test_can_save_to_different_index(async_write_client): + test_repo = Repository(description="testing", meta={"id": 42}) + assert await test_repo.save(index="test-document") + + assert_doc_equals( + { + "found": True, + "_index": "test-document", + "_id": "42", + "_source": {"description": "testing"}, + }, + await async_write_client.get(index="test-document", id=42), + ) + + +async def test_save_without_skip_empty_will_include_empty_fields(async_write_client): + test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42}) + assert await test_repo.save(index="test-document", skip_empty=False) + + assert_doc_equals( + { + "found": True, + "_index": "test-document", + "_id": "42", + "_source": {"field_1": [], "field_2": None, "field_3": {}}, + }, + await async_write_client.get(index="test-document", id=42), + ) + + +async def test_delete(async_write_client): + await async_write_client.create( + index="test-document", + id="elasticsearch-dsl-py", + body={ + "organization": "elasticsearch", + "created_at": "2014-03-03", + "owner": {"name": "elasticsearch"}, + }, + ) + + test_repo = Repository(meta={"id": 
"elasticsearch-dsl-py"}) + test_repo.meta.index = "test-document" + await test_repo.delete() + + assert not await async_write_client.exists( + index="test-document", + id="elasticsearch-dsl-py", + ) + + +async def test_search(async_data_client): + assert await Repository.search().count() == 1 + + +async def test_search_returns_proper_doc_classes(async_data_client): + result = await Repository.search().execute() + + elasticsearch_repo = result.hits[0] + + assert isinstance(elasticsearch_repo, Repository) + assert elasticsearch_repo.owner.name == "elasticsearch" + + +async def test_refresh_mapping(async_data_client): + class Commit(AsyncDocument): + class Index: + name = "git" + + await Commit._index.load_mappings() + + assert "stats" in Commit._index._mapping + assert "committer" in Commit._index._mapping + assert "description" in Commit._index._mapping + assert "committed_date" in Commit._index._mapping + assert isinstance(Commit._index._mapping["committed_date"], Date) + + +async def test_highlight_in_meta(async_data_client): + commit = ( + await Commit.search() + .query("match", description="inverting") + .highlight("description") + .execute() + )[0] + + assert isinstance(commit, Commit) + assert "description" in commit.meta.highlight + assert isinstance(commit.meta.highlight["description"], AttrList) + assert len(commit.meta.highlight["description"]) > 0 diff --git a/tests/test_integration/_async/test_faceted_search.py b/tests/test_integration/_async/test_faceted_search.py new file mode 100644 index 000000000..fa63559e0 --- /dev/null +++ b/tests/test_integration/_async/test_faceted_search.py @@ -0,0 +1,282 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. 
licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime + +import pytest + +from elasticsearch_dsl import A, AsyncDocument, Boolean, Date, Keyword +from elasticsearch_dsl.faceted_search import ( + AsyncFacetedSearch, + DateHistogramFacet, + NestedFacet, + RangeFacet, + TermsFacet, +) + +from .test_document import PullRequest + + +class Repos(AsyncDocument): + is_public = Boolean() + created_at = Date() + + class Index: + name = "git" + + +class Commit(AsyncDocument): + files = Keyword() + committed_date = Date() + + class Index: + name = "git" + + +class MetricSearch(AsyncFacetedSearch): + index = "git" + doc_types = [Commit] + + facets = { + "files": TermsFacet(field="files", metric=A("max", field="committed_date")), + } + + +@pytest.fixture(scope="session") +def commit_search_cls(es_version): + if es_version >= (7, 2): + interval_kwargs = {"fixed_interval": "1d"} + else: + interval_kwargs = {"interval": "day"} + + class CommitSearch(AsyncFacetedSearch): + index = "flat-git" + fields = ( + "description", + "files", + ) + + facets = { + "files": TermsFacet(field="files"), + "frequency": DateHistogramFacet( + field="authored_date", min_doc_count=1, **interval_kwargs + ), + "deletions": RangeFacet( + field="stats.deletions", + ranges=[("ok", (None, 1)), ("good", (1, 5)), ("better", (5, None))], + ), + } + + return CommitSearch + + +@pytest.fixture(scope="session") +def repo_search_cls(es_version): + interval_type = 
"calendar_interval" if es_version >= (7, 2) else "interval" + + class RepoSearch(AsyncFacetedSearch): + index = "git" + doc_types = [Repos] + facets = { + "public": TermsFacet(field="is_public"), + "created": DateHistogramFacet( + field="created_at", **{interval_type: "month"} + ), + } + + def search(self): + s = super().search() + return s.filter("term", commit_repo="repo") + + return RepoSearch + + +@pytest.fixture(scope="session") +def pr_search_cls(es_version): + interval_type = "calendar_interval" if es_version >= (7, 2) else "interval" + + class PRSearch(AsyncFacetedSearch): + index = "test-prs" + doc_types = [PullRequest] + facets = { + "comments": NestedFacet( + "comments", + DateHistogramFacet( + field="comments.created_at", **{interval_type: "month"} + ), + ) + } + + return PRSearch + + +async def test_facet_with_custom_metric(async_data_client): + ms = MetricSearch() + r = await ms.execute() + + dates = [f[1] for f in r.facets.files] + assert dates == list(sorted(dates, reverse=True)) + assert dates[0] == 1399038439000 + + +async def test_nested_facet(async_pull_request, pr_search_cls): + prs = pr_search_cls() + r = await prs.execute() + + assert r.hits.total.value == 1 + assert [(datetime(2018, 1, 1, 0, 0), 1, False)] == r.facets.comments + + +async def test_nested_facet_with_filter(async_pull_request, pr_search_cls): + prs = pr_search_cls(filters={"comments": datetime(2018, 1, 1, 0, 0)}) + r = await prs.execute() + + assert r.hits.total.value == 1 + assert [(datetime(2018, 1, 1, 0, 0), 1, True)] == r.facets.comments + + prs = pr_search_cls(filters={"comments": datetime(2018, 2, 1, 0, 0)}) + r = await prs.execute() + assert not r.hits + + +async def test_datehistogram_facet(async_data_client, repo_search_cls): + rs = repo_search_cls() + r = await rs.execute() + + assert r.hits.total.value == 1 + assert [(datetime(2014, 3, 1, 0, 0), 1, False)] == r.facets.created + + +async def test_boolean_facet(async_data_client, repo_search_cls): + rs = 
repo_search_cls() + r = await rs.execute() + + assert r.hits.total.value == 1 + assert [(True, 1, False)] == r.facets.public + value, count, selected = r.facets.public[0] + assert value is True + + +async def test_empty_search_finds_everything( + async_data_client, es_version, commit_search_cls +): + cs = commit_search_cls() + r = await cs.execute() + + assert r.hits.total.value == 52 + assert [ + ("elasticsearch_dsl", 40, False), + ("test_elasticsearch_dsl", 35, False), + ("elasticsearch_dsl/query.py", 19, False), + ("test_elasticsearch_dsl/test_search.py", 15, False), + ("elasticsearch_dsl/utils.py", 14, False), + ("test_elasticsearch_dsl/test_query.py", 13, False), + ("elasticsearch_dsl/search.py", 12, False), + ("elasticsearch_dsl/aggs.py", 11, False), + ("test_elasticsearch_dsl/test_result.py", 5, False), + ("elasticsearch_dsl/result.py", 3, False), + ] == r.facets.files + + assert [ + (datetime(2014, 3, 3, 0, 0), 2, False), + (datetime(2014, 3, 4, 0, 0), 1, False), + (datetime(2014, 3, 5, 0, 0), 3, False), + (datetime(2014, 3, 6, 0, 0), 3, False), + (datetime(2014, 3, 7, 0, 0), 9, False), + (datetime(2014, 3, 10, 0, 0), 2, False), + (datetime(2014, 3, 15, 0, 0), 4, False), + (datetime(2014, 3, 21, 0, 0), 2, False), + (datetime(2014, 3, 23, 0, 0), 2, False), + (datetime(2014, 3, 24, 0, 0), 10, False), + (datetime(2014, 4, 20, 0, 0), 2, False), + (datetime(2014, 4, 22, 0, 0), 2, False), + (datetime(2014, 4, 25, 0, 0), 3, False), + (datetime(2014, 4, 26, 0, 0), 2, False), + (datetime(2014, 4, 27, 0, 0), 2, False), + (datetime(2014, 5, 1, 0, 0), 2, False), + (datetime(2014, 5, 2, 0, 0), 1, False), + ] == r.facets.frequency + + assert [ + ("ok", 19, False), + ("good", 14, False), + ("better", 19, False), + ] == r.facets.deletions + + +async def test_term_filters_are_shown_as_selected_and_data_is_filtered( + async_data_client, commit_search_cls +): + cs = commit_search_cls(filters={"files": "test_elasticsearch_dsl"}) + + r = await cs.execute() + + assert 35 == 
r.hits.total.value + assert [ + ("elasticsearch_dsl", 40, False), + ("test_elasticsearch_dsl", 35, True), # selected + ("elasticsearch_dsl/query.py", 19, False), + ("test_elasticsearch_dsl/test_search.py", 15, False), + ("elasticsearch_dsl/utils.py", 14, False), + ("test_elasticsearch_dsl/test_query.py", 13, False), + ("elasticsearch_dsl/search.py", 12, False), + ("elasticsearch_dsl/aggs.py", 11, False), + ("test_elasticsearch_dsl/test_result.py", 5, False), + ("elasticsearch_dsl/result.py", 3, False), + ] == r.facets.files + + assert [ + (datetime(2014, 3, 3, 0, 0), 1, False), + (datetime(2014, 3, 5, 0, 0), 2, False), + (datetime(2014, 3, 6, 0, 0), 3, False), + (datetime(2014, 3, 7, 0, 0), 6, False), + (datetime(2014, 3, 10, 0, 0), 1, False), + (datetime(2014, 3, 15, 0, 0), 3, False), + (datetime(2014, 3, 21, 0, 0), 2, False), + (datetime(2014, 3, 23, 0, 0), 1, False), + (datetime(2014, 3, 24, 0, 0), 7, False), + (datetime(2014, 4, 20, 0, 0), 1, False), + (datetime(2014, 4, 25, 0, 0), 3, False), + (datetime(2014, 4, 26, 0, 0), 2, False), + (datetime(2014, 4, 27, 0, 0), 1, False), + (datetime(2014, 5, 1, 0, 0), 1, False), + (datetime(2014, 5, 2, 0, 0), 1, False), + ] == r.facets.frequency + + assert [ + ("ok", 12, False), + ("good", 10, False), + ("better", 13, False), + ] == r.facets.deletions + + +async def test_range_filters_are_shown_as_selected_and_data_is_filtered( + async_data_client, commit_search_cls +): + cs = commit_search_cls(filters={"deletions": "better"}) + + r = await cs.execute() + + assert 19 == r.hits.total.value + + +async def test_pagination(async_data_client, commit_search_cls): + cs = commit_search_cls() + cs = cs[0:20] + + assert 52 == await cs.count() + assert 20 == len(await cs.execute()) diff --git a/tests/test_integration/_async/test_index.py b/tests/test_integration/_async/test_index.py new file mode 100644 index 000000000..7a7082a1b --- /dev/null +++ b/tests/test_integration/_async/test_index.py @@ -0,0 +1,122 @@ +# Licensed to 
Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from elasticsearch_dsl import ( + AsyncDocument, + AsyncIndex, + AsyncIndexTemplate, + Date, + Text, + analysis, +) + + +class Post(AsyncDocument): + title = Text(analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword")) + published_from = Date() + + +async def test_index_template_works(async_write_client): + it = AsyncIndexTemplate("test-template", "test-*") + it.document(Post) + it.settings(number_of_replicas=0, number_of_shards=1) + await it.save() + + i = AsyncIndex("test-blog") + await i.create() + + assert { + "test-blog": { + "mappings": { + "properties": { + "title": {"type": "text", "analyzer": "my_analyzer"}, + "published_from": {"type": "date"}, + } + } + } + } == await async_write_client.indices.get_mapping(index="test-blog") + + +async def test_index_can_be_saved_even_with_settings(async_write_client): + i = AsyncIndex("test-blog", using=async_write_client) + i.settings(number_of_shards=3, number_of_replicas=0) + await i.save() + i.settings(number_of_replicas=1) + await i.save() + + assert ( + "1" + == (await i.get_settings())["test-blog"]["settings"]["index"][ + "number_of_replicas" + ] + ) + + +async def test_index_exists(async_data_client): + assert await AsyncIndex("git").exists() 
+ assert not await AsyncIndex("not-there").exists() + + +async def test_index_can_be_created_with_settings_and_mappings(async_write_client): + i = AsyncIndex("test-blog", using=async_write_client) + i.document(Post) + i.settings(number_of_replicas=0, number_of_shards=1) + await i.create() + + assert { + "test-blog": { + "mappings": { + "properties": { + "title": {"type": "text", "analyzer": "my_analyzer"}, + "published_from": {"type": "date"}, + } + } + } + } == await async_write_client.indices.get_mapping(index="test-blog") + + settings = await async_write_client.indices.get_settings(index="test-blog") + assert settings["test-blog"]["settings"]["index"]["number_of_replicas"] == "0" + assert settings["test-blog"]["settings"]["index"]["number_of_shards"] == "1" + assert settings["test-blog"]["settings"]["index"]["analysis"] == { + "analyzer": {"my_analyzer": {"type": "custom", "tokenizer": "keyword"}} + } + + +async def test_delete(async_write_client): + await async_write_client.indices.create( + index="test-index", + body={"settings": {"number_of_replicas": 0, "number_of_shards": 1}}, + ) + + i = AsyncIndex("test-index", using=async_write_client) + await i.delete() + assert not await async_write_client.indices.exists(index="test-index") + + +async def test_multiple_indices_with_same_doc_type_work(async_write_client): + i1 = AsyncIndex("test-index-1", using=async_write_client) + i2 = AsyncIndex("test-index-2", using=async_write_client) + + for i in (i1, i2): + i.document(Post) + await i.create() + + for i in ("test-index-1", "test-index-2"): + settings = await async_write_client.indices.get_settings(index=i) + assert settings[i]["settings"]["index"]["analysis"] == { + "analyzer": {"my_analyzer": {"type": "custom", "tokenizer": "keyword"}} + } diff --git a/tests/test_integration/_async/test_mapping.py b/tests/test_integration/_async/test_mapping.py new file mode 100644 index 000000000..54762c005 --- /dev/null +++ b/tests/test_integration/_async/test_mapping.py @@ 
-0,0 +1,163 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pytest import raises + +from elasticsearch_dsl import AsyncMapping, analysis, exceptions + + +async def test_mapping_saved_into_es(async_write_client): + m = AsyncMapping() + m.field( + "name", "text", analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword") + ) + m.field("tags", "keyword") + await m.save("test-mapping", using=async_write_client) + + assert { + "test-mapping": { + "mappings": { + "properties": { + "name": {"type": "text", "analyzer": "my_analyzer"}, + "tags": {"type": "keyword"}, + } + } + } + } == await async_write_client.indices.get_mapping(index="test-mapping") + + +async def test_mapping_saved_into_es_when_index_already_exists_closed( + async_write_client, +): + m = AsyncMapping() + m.field( + "name", "text", analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword") + ) + await async_write_client.indices.create(index="test-mapping") + + with raises(exceptions.IllegalOperation): + await m.save("test-mapping", using=async_write_client) + + await async_write_client.cluster.health( + index="test-mapping", wait_for_status="yellow" + ) + await async_write_client.indices.close(index="test-mapping") + await m.save("test-mapping", 
using=async_write_client) + + assert { + "test-mapping": { + "mappings": { + "properties": {"name": {"type": "text", "analyzer": "my_analyzer"}} + } + } + } == await async_write_client.indices.get_mapping(index="test-mapping") + + +async def test_mapping_saved_into_es_when_index_already_exists_with_analysis( + async_write_client, +): + m = AsyncMapping() + analyzer = analysis.analyzer("my_analyzer", tokenizer="keyword") + m.field("name", "text", analyzer=analyzer) + + new_analysis = analyzer.get_analysis_definition() + new_analysis["analyzer"]["other_analyzer"] = { + "type": "custom", + "tokenizer": "whitespace", + } + await async_write_client.indices.create( + index="test-mapping", body={"settings": {"analysis": new_analysis}} + ) + + m.field("title", "text", analyzer=analyzer) + await m.save("test-mapping", using=async_write_client) + + assert { + "test-mapping": { + "mappings": { + "properties": { + "name": {"type": "text", "analyzer": "my_analyzer"}, + "title": {"type": "text", "analyzer": "my_analyzer"}, + } + } + } + } == await async_write_client.indices.get_mapping(index="test-mapping") + + +async def test_mapping_gets_updated_from_es(async_write_client): + await async_write_client.indices.create( + index="test-mapping", + body={ + "settings": {"number_of_shards": 1, "number_of_replicas": 0}, + "mappings": { + "date_detection": False, + "properties": { + "title": { + "type": "text", + "analyzer": "snowball", + "fields": {"raw": {"type": "keyword"}}, + }, + "created_at": {"type": "date"}, + "comments": { + "type": "nested", + "properties": { + "created": {"type": "date"}, + "author": { + "type": "text", + "analyzer": "snowball", + "fields": {"raw": {"type": "keyword"}}, + }, + }, + }, + }, + }, + }, + ) + + m = await AsyncMapping.from_es("test-mapping", using=async_write_client) + + assert ["comments", "created_at", "title"] == list( + sorted(m.properties.properties._d_.keys()) + ) + assert { + "date_detection": False, + "properties": { + "comments": { + 
"type": "nested", + "properties": { + "created": {"type": "date"}, + "author": { + "analyzer": "snowball", + "fields": {"raw": {"type": "keyword"}}, + "type": "text", + }, + }, + }, + "created_at": {"type": "date"}, + "title": { + "analyzer": "snowball", + "fields": {"raw": {"type": "keyword"}}, + "type": "text", + }, + }, + } == m.to_dict() + + # test same with alias + await async_write_client.indices.put_alias(index="test-mapping", name="test-alias") + + m2 = await AsyncMapping.from_es("test-alias", using=async_write_client) + assert m2.to_dict() == m.to_dict() diff --git a/tests/test_integration/_async/test_search.py b/tests/test_integration/_async/test_search.py new file mode 100644 index 000000000..cd93ed562 --- /dev/null +++ b/tests/test_integration/_async/test_search.py @@ -0,0 +1,178 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +from elasticsearch import ApiError +from pytest import raises + +from elasticsearch_dsl import ( + AsyncDocument, + AsyncMultiSearch, + AsyncSearch, + Date, + Keyword, + Q, + Text, +) +from elasticsearch_dsl.response import aggs + +from ..test_data import FLAT_DATA + + +class Repository(AsyncDocument): + created_at = Date() + description = Text(analyzer="snowball") + tags = Keyword() + + @classmethod + def search(cls): + return super().search().filter("term", commit_repo="repo") + + class Index: + name = "git" + + +class Commit(AsyncDocument): + class Index: + name = "flat-git" + + +async def test_filters_aggregation_buckets_are_accessible(async_data_client): + has_tests_query = Q("term", files="test_elasticsearch_dsl") + s = Commit.search()[0:0] + s.aggs.bucket("top_authors", "terms", field="author.name.raw").bucket( + "has_tests", "filters", filters={"yes": has_tests_query, "no": ~has_tests_query} + ).metric("lines", "stats", field="stats.lines") + + response = await s.execute() + + assert isinstance( + response.aggregations.top_authors.buckets[0].has_tests.buckets.yes, aggs.Bucket + ) + assert ( + 35 + == response.aggregations.top_authors.buckets[0].has_tests.buckets.yes.doc_count + ) + assert ( + 228 + == response.aggregations.top_authors.buckets[0].has_tests.buckets.yes.lines.max + ) + + +async def test_top_hits_are_wrapped_in_response(async_data_client): + s = Commit.search()[0:0] + s.aggs.bucket("top_authors", "terms", field="author.name.raw").metric( + "top_commits", "top_hits", size=5 + ) + response = await s.execute() + + top_commits = response.aggregations.top_authors.buckets[0].top_commits + assert isinstance(top_commits, aggs.TopHitsData) + assert 5 == len(top_commits) + + hits = [h for h in top_commits] + assert 5 == len(hits) + assert isinstance(hits[0], Commit) + + +async def test_inner_hits_are_wrapped_in_response(async_data_client): + s = AsyncSearch(index="git")[0:1].query( + "has_parent", parent_type="repo", inner_hits={}, 
query=Q("match_all") + ) + response = await s.execute() + + commit = response.hits[0] + assert isinstance(commit.meta.inner_hits.repo, response.__class__) + assert repr(commit.meta.inner_hits.repo[0]).startswith( + " 0 + assert not response.timed_out + assert response.updated == 52 + assert response.deleted == 0 + assert response.took > 0 + assert response.success() + + +async def test_update_by_query_with_script(async_write_client, setup_ubq_tests): + index = setup_ubq_tests + + ubq = ( + AsyncUpdateByQuery(using=async_write_client) + .index(index) + .filter(~Q("exists", field="parent_shas")) + .script(source="ctx._source.is_public = false") + ) + ubq = ubq.params(conflicts="proceed") + + response = await ubq.execute() + assert response.total == 2 + assert response.updated == 2 + assert response.version_conflicts == 0 + + +async def test_delete_by_query_with_script(async_write_client, setup_ubq_tests): + index = setup_ubq_tests + + ubq = ( + AsyncUpdateByQuery(using=async_write_client) + .index(index) + .filter(Q("match", parent_shas="1dd19210b5be92b960f7db6f66ae526288edccc3")) + .script(source='ctx.op = "delete"') + ) + ubq = ubq.params(conflicts="proceed") + + response = await ubq.execute() + + assert response.total == 1 + assert response.deleted == 1 + assert response.success() diff --git a/tests/test_integration/_sync/__init__.py b/tests/test_integration/_sync/__init__.py new file mode 100644 index 000000000..2a87d183f --- /dev/null +++ b/tests/test_integration/_sync/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/test_integration/test_analysis.py b/tests/test_integration/_sync/test_analysis.py similarity index 89% rename from tests/test_integration/test_analysis.py rename to tests/test_integration/_sync/test_analysis.py index 140099d4a..0356cf1a6 100644 --- a/tests/test_integration/test_analysis.py +++ b/tests/test_integration/_sync/test_analysis.py @@ -20,7 +20,7 @@ def test_simulate_with_just__builtin_tokenizer(client): a = analyzer("my-analyzer", tokenizer="keyword") - tokens = a.simulate("Hello World!", using=client).tokens + tokens = (a.simulate("Hello World!", using=client)).tokens assert len(tokens) == 1 assert tokens[0].token == "Hello World!" 
@@ -33,7 +33,7 @@ def test_simulate_complex(client): filter=["lowercase", token_filter("no-ifs", "stop", stopwords=["if"])], ) - tokens = a.simulate("if:this:works", using=client).tokens + tokens = (a.simulate("if:this:works", using=client)).tokens assert len(tokens) == 2 assert ["this", "works"] == [t.token for t in tokens] @@ -41,6 +41,6 @@ def test_simulate_complex(client): def test_simulate_builtin(client): a = analyzer("my-analyzer", "english") - tokens = a.simulate("fixes running").tokens + tokens = (a.simulate("fixes running")).tokens assert ["fix", "run"] == [t.token for t in tokens] diff --git a/tests/test_integration/test_document.py b/tests/test_integration/_sync/test_document.py similarity index 99% rename from tests/test_integration/test_document.py rename to tests/test_integration/_sync/test_document.py index 6fb2f4eaa..153a98632 100644 --- a/tests/test_integration/test_document.py +++ b/tests/test_integration/_sync/test_document.py @@ -321,7 +321,9 @@ def test_get_returns_none_if_404_ignored(data_client): ) -def test_get_returns_none_if_404_ignored_and_index_doesnt_exist(data_client): +def test_get_returns_none_if_404_ignored_and_index_doesnt_exist( + data_client, +): assert None is Repository.get( "42", index="not-there", using=data_client.options(ignore_status=404) ) @@ -407,7 +409,7 @@ def test_mget_ignores_missing_docs_when_missing_param_is_skip(data_client): def test_update_works_from_search_response(data_client): - elasticsearch_repo = Repository.search().execute()[0] + elasticsearch_repo = (Repository.search().execute())[0] elasticsearch_repo.update(owner={"other_name": "elastic"}) assert "elastic" == elasticsearch_repo.owner.other_name @@ -557,8 +559,8 @@ def test_highlight_in_meta(data_client): Commit.search() .query("match", description="inverting") .highlight("description") - .execute()[0] - ) + .execute() + )[0] assert isinstance(commit, Commit) assert "description" in commit.meta.highlight diff --git 
a/tests/test_integration/test_faceted_search.py b/tests/test_integration/_sync/test_faceted_search.py similarity index 100% rename from tests/test_integration/test_faceted_search.py rename to tests/test_integration/_sync/test_faceted_search.py diff --git a/tests/test_integration/test_index.py b/tests/test_integration/_sync/test_index.py similarity index 97% rename from tests/test_integration/test_index.py rename to tests/test_integration/_sync/test_index.py index c03223000..35e91141a 100644 --- a/tests/test_integration/test_index.py +++ b/tests/test_integration/_sync/test_index.py @@ -52,7 +52,8 @@ def test_index_can_be_saved_even_with_settings(write_client): i.save() assert ( - "1" == i.get_settings()["test-blog"]["settings"]["index"]["number_of_replicas"] + "1" + == (i.get_settings())["test-blog"]["settings"]["index"]["number_of_replicas"] ) diff --git a/tests/test_integration/test_mapping.py b/tests/test_integration/_sync/test_mapping.py similarity index 94% rename from tests/test_integration/test_mapping.py rename to tests/test_integration/_sync/test_mapping.py index ff266a777..7d55ca6c2 100644 --- a/tests/test_integration/test_mapping.py +++ b/tests/test_integration/_sync/test_mapping.py @@ -17,11 +17,11 @@ from pytest import raises -from elasticsearch_dsl import analysis, exceptions, mapping +from elasticsearch_dsl import Mapping, analysis, exceptions def test_mapping_saved_into_es(write_client): - m = mapping.Mapping() + m = Mapping() m.field( "name", "text", analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword") ) @@ -40,8 +40,10 @@ def test_mapping_saved_into_es(write_client): } == write_client.indices.get_mapping(index="test-mapping") -def test_mapping_saved_into_es_when_index_already_exists_closed(write_client): - m = mapping.Mapping() +def test_mapping_saved_into_es_when_index_already_exists_closed( + write_client, +): + m = Mapping() m.field( "name", "text", analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword") ) @@ -63,8 +65,10 @@ def 
test_mapping_saved_into_es_when_index_already_exists_closed(write_client): } == write_client.indices.get_mapping(index="test-mapping") -def test_mapping_saved_into_es_when_index_already_exists_with_analysis(write_client): - m = mapping.Mapping() +def test_mapping_saved_into_es_when_index_already_exists_with_analysis( + write_client, +): + m = Mapping() analyzer = analysis.analyzer("my_analyzer", tokenizer="keyword") m.field("name", "text", analyzer=analyzer) @@ -122,7 +126,7 @@ def test_mapping_gets_updated_from_es(write_client): }, ) - m = mapping.Mapping.from_es("test-mapping", using=write_client) + m = Mapping.from_es("test-mapping", using=write_client) assert ["comments", "created_at", "title"] == list( sorted(m.properties.properties._d_.keys()) @@ -153,5 +157,5 @@ def test_mapping_gets_updated_from_es(write_client): # test same with alias write_client.indices.put_alias(index="test-mapping", name="test-alias") - m2 = mapping.Mapping.from_es("test-alias", using=write_client) + m2 = Mapping.from_es("test-alias", using=write_client) assert m2.to_dict() == m.to_dict() diff --git a/tests/test_integration/test_search.py b/tests/test_integration/_sync/test_search.py similarity index 96% rename from tests/test_integration/test_search.py rename to tests/test_integration/_sync/test_search.py index 99fb51847..a72c0fe1c 100644 --- a/tests/test_integration/test_search.py +++ b/tests/test_integration/_sync/test_search.py @@ -22,7 +22,7 @@ from elasticsearch_dsl import Date, Document, Keyword, MultiSearch, Q, Search, Text from elasticsearch_dsl.response import aggs -from .test_data import FLAT_DATA +from ..test_data import FLAT_DATA class Repository(Document): @@ -49,6 +49,7 @@ def test_filters_aggregation_buckets_are_accessible(data_client): s.aggs.bucket("top_authors", "terms", field="author.name.raw").bucket( "has_tests", "filters", filters={"yes": has_tests_query, "no": ~has_tests_query} ).metric("lines", "stats", field="stats.lines") + response = s.execute() assert 
isinstance( @@ -94,7 +95,7 @@ def test_inner_hits_are_wrapped_in_response(data_client): def test_scan_respects_doc_types(data_client): - repos = list(Repository.search().scan()) + repos = [repo for repo in Repository.search().scan()] assert 1 == len(repos) assert isinstance(repos[0], Repository) @@ -104,7 +105,7 @@ def test_scan_respects_doc_types(data_client): def test_scan_iterates_through_all_docs(data_client): s = Search(index="flat-git") - commits = list(s.scan()) + commits = [commit for commit in s.scan()] assert 52 == len(commits) assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} @@ -112,7 +113,7 @@ def test_scan_iterates_through_all_docs(data_client): def test_response_is_cached(data_client): s = Repository.search() - repos = list(s) + repos = [repo for repo in s] assert hasattr(s, "_response") assert s._response.hits == repos diff --git a/tests/test_integration/test_update_by_query.py b/tests/test_integration/_sync/test_update_by_query.py similarity index 97% rename from tests/test_integration/test_update_by_query.py rename to tests/test_integration/_sync/test_update_by_query.py index 64485391a..ed9fd36c2 100644 --- a/tests/test_integration/test_update_by_query.py +++ b/tests/test_integration/_sync/test_update_by_query.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. +from elasticsearch_dsl import UpdateByQuery from elasticsearch_dsl.search import Q -from elasticsearch_dsl.update_by_query import UpdateByQuery def test_update_by_query_no_script(write_client, setup_ubq_tests): diff --git a/tests/test_integration/test_examples/_async/__init__.py b/tests/test_integration/test_examples/_async/__init__.py new file mode 100644 index 000000000..2a87d183f --- /dev/null +++ b/tests/test_integration/test_examples/_async/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/test_integration/test_examples/_async/test_alias_migration.py b/tests/test_integration/test_examples/_async/test_alias_migration.py new file mode 100644 index 000000000..9bedd7b28 --- /dev/null +++ b/tests/test_integration/test_examples/_async/test_alias_migration.py @@ -0,0 +1,68 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from ..async_examples import alias_migration +from ..async_examples.alias_migration import ALIAS, PATTERN, BlogPost, migrate + + +async def test_alias_migration(async_write_client): + # create the index + await alias_migration.setup() + + # verify that template, index, and alias has been set up + assert await async_write_client.indices.exists_template(name=ALIAS) + assert await async_write_client.indices.exists(index=PATTERN) + assert await async_write_client.indices.exists_alias(name=ALIAS) + + indices = await async_write_client.indices.get(index=PATTERN) + assert len(indices) == 1 + index_name, _ = indices.popitem() + + # which means we can now save a document + with open(__file__) as f: + bp = BlogPost( + _id=0, + title="Hello World!", + tags=["testing", "dummy"], + content=f.read(), + ) + await bp.save(refresh=True) + + assert await BlogPost.search().count() == 1 + + # _matches work which means we get BlogPost instance + bp = (await BlogPost.search().execute())[0] + assert isinstance(bp, BlogPost) + assert not bp.is_published() + assert "0" == bp.meta.id + + # create new index + await migrate() + + indices = await async_write_client.indices.get(index=PATTERN) + assert 2 == len(indices) + alias = await async_write_client.indices.get(index=ALIAS) + assert 1 == len(alias) + assert index_name not in alias + + # data has been moved properly + assert await BlogPost.search().count() == 1 + + # _matches work which means we get BlogPost instance + bp = (await BlogPost.search().execute())[0] + assert isinstance(bp, BlogPost) + assert "0" == bp.meta.id diff --git a/tests/test_integration/test_examples/_async/test_completion.py b/tests/test_integration/test_examples/_async/test_completion.py new file mode 100644 index 000000000..f71fcb1e6 --- /dev/null +++ b/tests/test_integration/test_examples/_async/test_completion.py @@ -0,0 +1,34 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from ..async_examples.completion import Person + + +async def test_person_suggests_on_all_variants_of_name(async_write_client): + await Person.init(using=async_write_client) + + await Person(name="Honza Král", popularity=42).save(refresh=True) + + s = Person.search().suggest("t", "kra", completion={"field": "suggest"}) + response = await s.execute() + + opts = response.suggest.t[0].options + + assert 1 == len(opts) + assert opts[0]._score == 42 + assert opts[0]._source.name == "Honza Král" diff --git a/tests/test_integration/test_examples/_async/test_composite_aggs.py b/tests/test_integration/test_examples/_async/test_composite_aggs.py new file mode 100644 index 000000000..ba07b4533 --- /dev/null +++ b/tests/test_integration/test_examples/_async/test_composite_aggs.py @@ -0,0 +1,46 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from elasticsearch_dsl import A, AsyncSearch + +from ..async_examples.composite_agg import scan_aggs + + +async def test_scan_aggs_exhausts_all_files(async_data_client): + s = AsyncSearch(index="flat-git") + key_aggs = {"files": A("terms", field="files")} + file_list = [f async for f in scan_aggs(s, key_aggs)] + + assert len(file_list) == 26 + + +async def test_scan_aggs_with_multiple_aggs(async_data_client): + s = AsyncSearch(index="flat-git") + key_aggs = [ + {"files": A("terms", field="files")}, + { + "months": { + "date_histogram": { + "field": "committed_date", + "calendar_interval": "month", + } + } + }, + ] + file_list = [f async for f in scan_aggs(s, key_aggs)] + + assert len(file_list) == 47 diff --git a/tests/test_integration/test_examples/_async/test_parent_child.py b/tests/test_integration/test_examples/_async/test_parent_child.py new file mode 100644 index 000000000..66fdac3c0 --- /dev/null +++ b/tests/test_integration/test_examples/_async/test_parent_child.py @@ -0,0 +1,105 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime + +from pytest import fixture + +from elasticsearch_dsl import Q + +from ..async_examples.parent_child import Answer, Comment, Question, User, setup + +honza = User( + id=42, + signed_up=datetime(2013, 4, 3), + username="honzakral", + email="honza@elastic.co", + location="Prague", +) + +nick = User( + id=47, + signed_up=datetime(2017, 4, 3), + username="fxdgear", + email="nick.lang@elastic.co", + location="Colorado", +) + + +@fixture +async def question(async_write_client): + await setup() + assert await async_write_client.indices.exists_template(name="base") + + # create a question object + q = Question( + _id=1, + author=nick, + tags=["elasticsearch", "python"], + title="How do I use elasticsearch from Python?", + body=""" + I want to use elasticsearch, how do I do it from Python? 
+ """, + ) + await q.save() + return q + + +async def test_comment(async_write_client, question): + await question.add_comment(nick, "Just use elasticsearch-py") + + q = await Question.get(1) + assert isinstance(q, Question) + assert 1 == len(q.comments) + + c = q.comments[0] + assert isinstance(c, Comment) + assert c.author.username == "fxdgear" + + +async def test_question_answer(async_write_client, question): + a = await question.add_answer(honza, "Just use `elasticsearch-py`!") + + assert isinstance(a, Answer) + + # refresh the index so we can search right away + await Question._index.refresh() + + # we can now fetch answers from elasticsearch + answers = await question.get_answers() + assert 1 == len(answers) + assert isinstance(answers[0], Answer) + + search = Question.search().query( + "has_child", + type="answer", + inner_hits={}, + query=Q("term", author__username__keyword="honzakral"), + ) + response = await search.execute() + + assert 1 == len(response.hits) + + q = response.hits[0] + assert isinstance(q, Question) + assert 1 == len(q.meta.inner_hits.answer.hits) + assert q.meta.inner_hits.answer.hits is await q.get_answers() + + a = q.meta.inner_hits.answer.hits[0] + assert isinstance(a, Answer) + assert isinstance(await a.get_question(), Question) + assert (await a.get_question()).meta.id == "1" diff --git a/tests/test_integration/test_examples/_async/test_percolate.py b/tests/test_integration/test_examples/_async/test_percolate.py new file mode 100644 index 000000000..a538b9e47 --- /dev/null +++ b/tests/test_integration/test_examples/_async/test_percolate.py @@ -0,0 +1,31 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from ..async_examples.percolate import BlogPost, setup + + +async def test_post_gets_tagged_automatically(async_write_client): + await setup() + + bp = BlogPost(_id=47, content="nothing about snakes here!") + bp_py = BlogPost(_id=42, content="something about Python here!") + + await bp.save() + await bp_py.save() + + assert [] == bp.tags + assert {"programming", "development", "python"} == set(bp_py.tags) diff --git a/tests/test_integration/test_examples/_sync/__init__.py b/tests/test_integration/test_examples/_sync/__init__.py new file mode 100644 index 000000000..2a87d183f --- /dev/null +++ b/tests/test_integration/test_examples/_sync/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/test_integration/test_examples/test_alias_migration.py b/tests/test_integration/test_examples/_sync/test_alias_migration.py similarity index 91% rename from tests/test_integration/test_examples/test_alias_migration.py rename to tests/test_integration/test_examples/_sync/test_alias_migration.py index c76e65e32..ee8b70cc5 100644 --- a/tests/test_integration/test_examples/test_alias_migration.py +++ b/tests/test_integration/test_examples/_sync/test_alias_migration.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. -from . import alias_migration -from .alias_migration import ALIAS, PATTERN, BlogPost, migrate +from ..examples import alias_migration +from ..examples.alias_migration import ALIAS, PATTERN, BlogPost, migrate def test_alias_migration(write_client): @@ -45,7 +45,7 @@ def test_alias_migration(write_client): assert BlogPost.search().count() == 1 # _matches work which means we get BlogPost instance - bp = BlogPost.search().execute()[0] + bp = (BlogPost.search().execute())[0] assert isinstance(bp, BlogPost) assert not bp.is_published() assert "0" == bp.meta.id @@ -63,6 +63,6 @@ def test_alias_migration(write_client): assert BlogPost.search().count() == 1 # _matches work which means we get BlogPost instance - bp = BlogPost.search().execute()[0] + bp = (BlogPost.search().execute())[0] assert isinstance(bp, BlogPost) assert "0" == bp.meta.id diff --git a/tests/test_integration/test_examples/test_completion.py b/tests/test_integration/test_examples/_sync/test_completion.py similarity index 96% rename from tests/test_integration/test_examples/test_completion.py rename to tests/test_integration/test_examples/_sync/test_completion.py index 92f6d80a0..0b49400fc 100644 --- a/tests/test_integration/test_examples/test_completion.py +++ b/tests/test_integration/test_examples/_sync/test_completion.py @@ -16,7 +16,7 @@ # under the License. 
-from .completion import Person +from ..examples.completion import Person def test_person_suggests_on_all_variants_of_name(write_client): diff --git a/tests/test_integration/test_examples/test_composite_aggs.py b/tests/test_integration/test_examples/_sync/test_composite_aggs.py similarity index 90% rename from tests/test_integration/test_examples/test_composite_aggs.py rename to tests/test_integration/test_examples/_sync/test_composite_aggs.py index 373696927..0c59b3f70 100644 --- a/tests/test_integration/test_examples/test_composite_aggs.py +++ b/tests/test_integration/test_examples/_sync/test_composite_aggs.py @@ -17,13 +17,13 @@ from elasticsearch_dsl import A, Search -from .composite_agg import scan_aggs +from ..examples.composite_agg import scan_aggs def test_scan_aggs_exhausts_all_files(data_client): s = Search(index="flat-git") key_aggs = {"files": A("terms", field="files")} - file_list = list(scan_aggs(s, key_aggs)) + file_list = [f for f in scan_aggs(s, key_aggs)] assert len(file_list) == 26 @@ -41,6 +41,6 @@ def test_scan_aggs_with_multiple_aggs(data_client): } }, ] - file_list = list(scan_aggs(s, key_aggs)) + file_list = [f for f in scan_aggs(s, key_aggs)] assert len(file_list) == 47 diff --git a/tests/test_integration/test_examples/test_parent_child.py b/tests/test_integration/test_examples/_sync/test_parent_child.py similarity index 94% rename from tests/test_integration/test_examples/test_parent_child.py rename to tests/test_integration/test_examples/_sync/test_parent_child.py index fa2a3ab5e..2aad5f64b 100644 --- a/tests/test_integration/test_examples/test_parent_child.py +++ b/tests/test_integration/test_examples/_sync/test_parent_child.py @@ -21,7 +21,7 @@ from elasticsearch_dsl import Q -from .parent_child import Answer, Comment, Question, User, setup +from ..examples.parent_child import Answer, Comment, Question, User, setup honza = User( id=42, @@ -101,5 +101,5 @@ def test_question_answer(write_client, question): a = 
q.meta.inner_hits.answer.hits[0] assert isinstance(a, Answer) - assert isinstance(a.question, Question) - assert a.question.meta.id == "1" + assert isinstance(a.get_question(), Question) + assert (a.get_question()).meta.id == "1" diff --git a/tests/test_integration/test_examples/test_percolate.py b/tests/test_integration/test_examples/_sync/test_percolate.py similarity index 95% rename from tests/test_integration/test_examples/test_percolate.py rename to tests/test_integration/test_examples/_sync/test_percolate.py index 30fcf972b..ef81a1099 100644 --- a/tests/test_integration/test_examples/test_percolate.py +++ b/tests/test_integration/test_examples/_sync/test_percolate.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from .percolate import BlogPost, setup +from ..examples.percolate import BlogPost, setup def test_post_gets_tagged_automatically(write_client): diff --git a/tests/test_integration/test_examples/alias_migration.py b/tests/test_integration/test_examples/alias_migration.py deleted file mode 120000 index 6aba6fefb..000000000 --- a/tests/test_integration/test_examples/alias_migration.py +++ /dev/null @@ -1 +0,0 @@ -../../../examples/alias_migration.py \ No newline at end of file diff --git a/tests/test_integration/test_examples/async_examples b/tests/test_integration/test_examples/async_examples new file mode 120000 index 000000000..1cebed2ce --- /dev/null +++ b/tests/test_integration/test_examples/async_examples @@ -0,0 +1 @@ +../../../examples/async \ No newline at end of file diff --git a/tests/test_integration/test_examples/completion.py b/tests/test_integration/test_examples/completion.py deleted file mode 120000 index 8efd561b2..000000000 --- a/tests/test_integration/test_examples/completion.py +++ /dev/null @@ -1 +0,0 @@ -../../../examples/completion.py \ No newline at end of file diff --git a/tests/test_integration/test_examples/composite_agg.py 
b/tests/test_integration/test_examples/composite_agg.py deleted file mode 120000 index 1b2d9f10d..000000000 --- a/tests/test_integration/test_examples/composite_agg.py +++ /dev/null @@ -1 +0,0 @@ -../../../examples/composite_agg.py \ No newline at end of file diff --git a/tests/test_integration/test_examples/examples b/tests/test_integration/test_examples/examples new file mode 120000 index 000000000..9f9d1de88 --- /dev/null +++ b/tests/test_integration/test_examples/examples @@ -0,0 +1 @@ +../../../examples \ No newline at end of file diff --git a/tests/test_integration/test_examples/parent_child.py b/tests/test_integration/test_examples/parent_child.py deleted file mode 120000 index dc7c7c120..000000000 --- a/tests/test_integration/test_examples/parent_child.py +++ /dev/null @@ -1 +0,0 @@ -../../../examples/parent_child.py \ No newline at end of file diff --git a/tests/test_integration/test_examples/percolate.py b/tests/test_integration/test_examples/percolate.py deleted file mode 120000 index 9c578bbda..000000000 --- a/tests/test_integration/test_examples/percolate.py +++ /dev/null @@ -1 +0,0 @@ -../../../examples/percolate.py \ No newline at end of file diff --git a/utils/run-unasync.py b/utils/run-unasync.py new file mode 100644 index 000000000..a0a908f5c --- /dev/null +++ b/utils/run-unasync.py @@ -0,0 +1,126 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import subprocess +import sys +from glob import glob +from pathlib import Path + +import unasync + + +def main(check=False): + # the list of directories that need to be processed with unasync + # each entry has two paths: + # - the source path with the async sources + # - the destination path where the sync sources should be written + source_dirs = [ + ( + "elasticsearch_dsl/_async/", + "elasticsearch_dsl/_sync/", + ), + ("tests/_async/", "tests/_sync/"), + ( + "tests/test_integration/_async/", + "tests/test_integration/_sync/", + ), + ( + "tests/test_integration/test_examples/_async", + "tests/test_integration/test_examples/_sync/", + ), + ("examples/async/", "examples/"), + ] + + # Unasync all the generated async code + additional_replacements = { + "_async": "_sync", + "AsyncElasticsearch": "Elasticsearch", + "AsyncSearch": "Search", + "AsyncMultiSearch": "MultiSearch", + "AsyncDocument": "Document", + "AsyncIndexMeta": "IndexMeta", + "AsyncIndexTemplate": "IndexTemplate", + "AsyncIndex": "Index", + "AsyncUpdateByQuery": "UpdateByQuery", + "AsyncMapping": "Mapping", + "AsyncFacetedSearch": "FacetedSearch", + "async_connections": "connections", + "async_scan": "scan", + "async_simulate": "simulate", + "async_mock_client": "mock_client", + "async_client": "client", + "async_data_client": "data_client", + "async_write_client": "write_client", + "async_pull_request": "pull_request", + "async_examples": "examples", + "assert_awaited_once_with": "assert_called_once_with", + } + rules = [ + unasync.Rule( + fromdir=dir[0], + 
todir=f"{dir[0]}_sync_check/" if check else dir[1], + additional_replacements=additional_replacements, + ) + for dir in source_dirs + ] + + filepaths = [] + for root, _, filenames in os.walk(Path(__file__).absolute().parent.parent): + for filename in filenames: + if filename.rpartition(".")[-1] in ( + "py", + "pyi", + ) and not filename.startswith("utils.py"): + filepaths.append(os.path.join(root, filename)) + + unasync.unasync_files(filepaths, rules) + for dir in source_dirs: + output_dir = f"{dir[0]}_sync_check/" if check else dir[1] + subprocess.check_call(["black", "--target-version=py38", output_dir]) + subprocess.check_call(["isort", output_dir]) + for file in glob("*.py", root_dir=dir[0]): + # remove asyncio from sync files + subprocess.check_call( + ["sed", "-i.bak", "/^import asyncio$/d", f"{output_dir}{file}"] + ) + subprocess.check_call( + [ + "sed", + "-i.bak", + "s/asyncio\\.run(main())/main()/", + f"{output_dir}{file}", + ] + ) + subprocess.check_call(["rm", f"{output_dir}{file}.bak"]) + + if check: + # make sure there are no differences between _sync and _sync_check + subprocess.check_call( + [ + "diff", + f"{dir[1]}{file}", + f"{output_dir}{file}", + ] + ) + + if check: + subprocess.check_call(["rm", "-rf", output_dir]) + + +if __name__ == "__main__": + main(check="--check" in sys.argv)