From 1b6f9cc8b2684acbd73134ea3a31642f53d037bf Mon Sep 17 00:00:00 2001
From: Silvano Cerza
Date: Sun, 1 Oct 2023 17:02:48 -0700
Subject: [PATCH] Fix linting issues

---
 src/elasticsearch_haystack/document_store.py | 23 ++++++++---------
 src/elasticsearch_haystack/filters.py        | 26 ++++++++++++--------
 tests/test_document_store.py                 |  6 ++---
 3 files changed, 30 insertions(+), 25 deletions(-)

diff --git a/src/elasticsearch_haystack/document_store.py b/src/elasticsearch_haystack/document_store.py
index 4ffc88c..9b08fcf 100644
--- a/src/elasticsearch_haystack/document_store.py
+++ b/src/elasticsearch_haystack/document_store.py
@@ -1,19 +1,18 @@
 # SPDX-FileCopyrightText: 2023-present Silvano Cerza
 #
 # SPDX-License-Identifier: Apache-2.0
-import logging
-from typing import Any, Dict, List, Optional, Union, Mapping
 import json
+import logging
+from typing import Any, Dict, List, Mapping, Optional, Union
 
-from elasticsearch import Elasticsearch, helpers
-from elastic_transport import NodeConfig
 import numpy as np
-from pandas import DataFrame
-
+from elastic_transport import NodeConfig
+from elasticsearch import Elasticsearch, helpers
 from haystack.preview.dataclasses import Document
 from haystack.preview.document_stores.decorator import document_store
 from haystack.preview.document_stores.errors import DuplicateDocumentError
 from haystack.preview.document_stores.protocols import DuplicatePolicy
+from pandas import DataFrame
 
 from elasticsearch_haystack.filters import _normalize_filters
 
@@ -36,7 +35,7 @@ def __init__(self, *, hosts: Optional[Hosts] = None, index: str = "default", **k
         :param hosts: List of hosts running the Elasticsearch client. Defaults to None
         :param index: Name of index in Elasticsearch, if it doesn't exist it will be created. Defaults to "default"
-        :param \*\*kwargs: Optional arguments that ``Elasticsearch`` takes.
+        :param **kwargs: Optional arguments that ``Elasticsearch`` takes.
         """
         self._client = Elasticsearch(hosts, **kwargs)
         self._index = index
@@ -149,7 +148,8 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
         """
         if len(documents) > 0:
             if not isinstance(documents[0], Document):
-                raise ValueError("param 'documents' must contain a list of objects of type Document")
+                msg = "param 'documents' must contain a list of objects of type Document"
+                raise ValueError(msg)
 
         action = "index" if policy == DuplicatePolicy.OVERWRITE else "create"
         _, errors = helpers.bulk(
@@ -165,8 +165,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
         if errors and policy == DuplicatePolicy.FAIL:
             # TODO: Handle errors in a better way, we're assuming that all errors
             # are related to duplicate documents but that could be very well be wrong.
-            ids = ', '.join((e["create"]["_id"] for e in errors))
-            raise DuplicateDocumentError(f"IDs '{ids}' already exist in the document store.")
+            ids = ", ".join(e["create"]["_id"] for e in errors)
+            msg = f"IDs '{ids}' already exist in the document store."
+            raise DuplicateDocumentError(msg)
 
     def _deserialize_document(self, hit: Dict[str, Any]) -> Document:
         """
@@ -231,7 +232,7 @@ def delete_documents(self, document_ids: List[str]) -> None:
         #
         helpers.bulk(
             client=self._client,
-            actions=({"_op_type": "delete", "_id": id} for id in document_ids),
+            actions=({"_op_type": "delete", "_id": id_} for id_ in document_ids),
             refresh="wait_for",
             index=self._index,
             raise_on_error=False,
diff --git a/src/elasticsearch_haystack/filters.py b/src/elasticsearch_haystack/filters.py
index f79e2a7..66041fa 100644
--- a/src/elasticsearch_haystack/filters.py
+++ b/src/elasticsearch_haystack/filters.py
@@ -1,8 +1,8 @@
 from typing import Any, Dict, List, Union
 
 import numpy as np
-from pandas import DataFrame
 from haystack.preview.errors import FilterError
+from pandas import DataFrame
 
 
 def _normalize_filters(filters: Union[List[Dict], Dict], logical_condition="") -> Dict[str, Any]:
@@ -10,12 +10,13 @@ def _normalize_filters(filters: Union[List[Dict], Dict], logical_condition="") -
     Converts Haystack filters in ElasticSearch compatible filters.
     """
     if not isinstance(filters, dict) and not isinstance(filters, list):
-        raise FilterError("Filters must be either a dictionary or a list")
+        msg = "Filters must be either a dictionary or a list"
+        raise FilterError(msg)
     conditions = []
     if isinstance(filters, dict):
         filters = [filters]
-    for filter in filters:
-        for operator, value in filter.items():
+    for filter_ in filters:
+        for operator, value in filter_.items():
             if operator in ["$not", "$and", "$or"]:
                 # Logical operators
                 conditions.append(_normalize_filters(value, operator))
@@ -58,19 +59,23 @@ def _parse_comparison(field: str, comparison: Union[Dict, List, str, float]) ->
                 result.append({"term": {field: val}})
             elif comparator == "$ne":
                 if isinstance(val, list):
-                    raise FilterError(f"{field}'s value can't be a list when using '{comparator}' comparator")
+                    msg = f"{field}'s value can't be a list when using '{comparator}' comparator"
+                    raise FilterError(msg)
                 result.append({"bool": {"must_not": {"term": {field: val}}}})
             elif comparator == "$in":
                 if not isinstance(val, list):
-                    raise FilterError(f"{field}'s value must be a list when using '{comparator}' comparator")
+                    msg = f"{field}'s value must be a list when using '{comparator}' comparator"
+                    raise FilterError(msg)
                 result.append({"terms": {field: val}})
             elif comparator == "$nin":
                 if not isinstance(val, list):
-                    raise FilterError(f"{field}'s value must be a list when using '{comparator}' comparator")
+                    msg = f"{field}'s value must be a list when using '{comparator}' comparator"
+                    raise FilterError(msg)
                 result.append({"bool": {"must_not": {"terms": {field: val}}}})
             elif comparator in ["$gt", "$gte", "$lt", "$lte"]:
                 if isinstance(val, list):
-                    raise FilterError(f"{field}'s value can't be a list when using '{comparator}' comparator")
+                    msg = f"{field}'s value can't be a list when using '{comparator}' comparator"
+                    raise FilterError(msg)
                 result.append({"range": {field: {comparator[1:]: val}}})
             elif comparator in ["$not", "$or"]:
                 result.append(_normalize_filters(val, comparator))
@@ -81,7 +86,8 @@ def _parse_comparison(field: str, comparison: Union[Dict, List, str, float]) ->
             elif comparator == "$and":
                 result.append(_normalize_filters({field: val}, comparator))
             else:
-                raise FilterError(f"Unknown comparator '{comparator}'")
+                msg = f"Unknown comparator '{comparator}'"
+                raise FilterError(msg)
     elif isinstance(comparison, list):
         result.append({"terms": {field: comparison}})
     elif isinstance(comparison, np.ndarray):
@@ -116,7 +122,7 @@ def _normalize_ranges(conditions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         ]
     ```
     """
-    range_conditions = [list(c["range"].items())[0] for c in conditions if "range" in c]
+    range_conditions = [next(iter(c["range"].items())) for c in conditions if "range" in c]
     if range_conditions:
         conditions = [c for c in conditions if "range" not in c]
         range_conditions_dict = {}
diff --git a/tests/test_document_store.py b/tests/test_document_store.py
index 076001b..8dd7fcd 100644
--- a/tests/test_document_store.py
+++ b/tests/test_document_store.py
@@ -1,13 +1,11 @@
 # SPDX-FileCopyrightText: 2023-present Silvano Cerza
 #
 # SPDX-License-Identifier: Apache-2.0
-from typing import List
-
 import pytest
-from haystack.preview.testing.document_store import DocumentStoreBaseTests
 from haystack.preview.dataclasses.document import Document
-from haystack.preview.document_stores.protocols import DuplicatePolicy
 from haystack.preview.document_stores.errors import DuplicateDocumentError
+from haystack.preview.document_stores.protocols import DuplicatePolicy
+from haystack.preview.testing.document_store import DocumentStoreBaseTests
 
 from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
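
Most of the edits above repeat two lint-driven patterns: exception messages are assigned to a local `msg` variable before being raised (the style enforced by flake8-errmsg style rules, such as Ruff's EM101/EM102), and loop variables that shadow Python builtins are renamed with a trailing underscore (`filter` to `filter_`, `id` to `id_`, as flagged by rules such as Ruff's A001). The sketch below shows the before/after shape of both patterns; `collect_ids` is a hypothetical helper used only for illustration, not code from this repository.

```python
from typing import Iterable, List


# Before: the loop variable shadows the builtin `id` and the message is written
# inline in the raise statement; both patterns are commonly flagged by linters.
def collect_ids_before(records: Iterable[str]) -> List[str]:
    ids = []
    for id in records:
        ids.append(id)
    if not ids:
        raise ValueError("no records were provided")
    return ids


# After: the loop variable gets a trailing underscore and the message is
# assigned to a variable before raising, mirroring the edits in this patch.
def collect_ids_after(records: Iterable[str]) -> List[str]:
    ids = [id_ for id_ in records]
    if not ids:
        msg = "no records were provided"
        raise ValueError(msg)
    return ids
```

The import reshuffling in all three files likewise appears to be isort-style grouping (standard library, then third party, then first party), the ordering Ruff reports as I001.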