Skip to content

Commit

Permalink
Fix linting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Oct 2, 2023
1 parent a4bfef8 commit 1b6f9cc
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 25 deletions.
23 changes: 12 additions & 11 deletions src/elasticsearch_haystack/document_store.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
# SPDX-FileCopyrightText: 2023-present Silvano Cerza <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, Dict, List, Optional, Union, Mapping
import json
import logging
from typing import Any, Dict, List, Mapping, Optional, Union

from elasticsearch import Elasticsearch, helpers
from elastic_transport import NodeConfig
import numpy as np
from pandas import DataFrame

from elastic_transport import NodeConfig
from elasticsearch import Elasticsearch, helpers
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores.decorator import document_store
from haystack.preview.document_stores.errors import DuplicateDocumentError
from haystack.preview.document_stores.protocols import DuplicatePolicy
from pandas import DataFrame

from elasticsearch_haystack.filters import _normalize_filters

Expand All @@ -36,7 +35,7 @@ def __init__(self, *, hosts: Optional[Hosts] = None, index: str = "default", **k
:param hosts: List of hosts running the Elasticsearch client. Defaults to None
:param index: Name of index in Elasticsearch, if it doesn't exist it will be created. Defaults to "default"
:param \*\*kwargs: Optional arguments that ``Elasticsearch`` takes.
:param **kwargs: Optional arguments that ``Elasticsearch`` takes.
"""
self._client = Elasticsearch(hosts, **kwargs)
self._index = index
Expand Down Expand Up @@ -149,7 +148,8 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
"""
if len(documents) > 0:
if not isinstance(documents[0], Document):
raise ValueError("param 'documents' must contain a list of objects of type Document")
msg = "param 'documents' must contain a list of objects of type Document"
raise ValueError(msg)

action = "index" if policy == DuplicatePolicy.OVERWRITE else "create"
_, errors = helpers.bulk(
Expand All @@ -165,8 +165,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
if errors and policy == DuplicatePolicy.FAIL:
# TODO: Handle errors in a better way; we're assuming that all errors
# are related to duplicate documents, but that could very well be wrong.
ids = ', '.join((e["create"]["_id"] for e in errors))
raise DuplicateDocumentError(f"IDs '{ids}' already exist in the document store.")
ids = ", ".join(e["create"]["_id"] for e in errors)
msg = f"IDs '{ids}' already exist in the document store."
raise DuplicateDocumentError(msg)

def _deserialize_document(self, hit: Dict[str, Any]) -> Document:
"""
Expand Down Expand Up @@ -231,7 +232,7 @@ def delete_documents(self, document_ids: List[str]) -> None:
#
helpers.bulk(
client=self._client,
actions=({"_op_type": "delete", "_id": id} for id in document_ids),
actions=({"_op_type": "delete", "_id": id_} for id_ in document_ids),
refresh="wait_for",
index=self._index,
raise_on_error=False,
Expand Down
26 changes: 16 additions & 10 deletions src/elasticsearch_haystack/filters.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from typing import Any, Dict, List, Union

import numpy as np
from pandas import DataFrame
from haystack.preview.errors import FilterError
from pandas import DataFrame


def _normalize_filters(filters: Union[List[Dict], Dict], logical_condition="") -> Dict[str, Any]:
"""
Converts Haystack filters into Elasticsearch-compatible filters.
"""
if not isinstance(filters, dict) and not isinstance(filters, list):
raise FilterError("Filters must be either a dictionary or a list")
msg = "Filters must be either a dictionary or a list"
raise FilterError(msg)
conditions = []
if isinstance(filters, dict):
filters = [filters]
for filter in filters:
for operator, value in filter.items():
for filter_ in filters:
for operator, value in filter_.items():
if operator in ["$not", "$and", "$or"]:
# Logical operators
conditions.append(_normalize_filters(value, operator))
Expand Down Expand Up @@ -58,19 +59,23 @@ def _parse_comparison(field: str, comparison: Union[Dict, List, str, float]) ->
result.append({"term": {field: val}})
elif comparator == "$ne":
if isinstance(val, list):
raise FilterError(f"{field}'s value can't be a list when using '{comparator}' comparator")
msg = f"{field}'s value can't be a list when using '{comparator}' comparator"
raise FilterError(msg)
result.append({"bool": {"must_not": {"term": {field: val}}}})
elif comparator == "$in":
if not isinstance(val, list):
raise FilterError(f"{field}'s value must be a list when using '{comparator}' comparator")
msg = f"{field}'s value must be a list when using '{comparator}' comparator"
raise FilterError(msg)
result.append({"terms": {field: val}})
elif comparator == "$nin":
if not isinstance(val, list):
raise FilterError(f"{field}'s value must be a list when using '{comparator}' comparator")
msg = f"{field}'s value must be a list when using '{comparator}' comparator"
raise FilterError(msg)
result.append({"bool": {"must_not": {"terms": {field: val}}}})
elif comparator in ["$gt", "$gte", "$lt", "$lte"]:
if isinstance(val, list):
raise FilterError(f"{field}'s value can't be a list when using '{comparator}' comparator")
msg = f"{field}'s value can't be a list when using '{comparator}' comparator"
raise FilterError(msg)
result.append({"range": {field: {comparator[1:]: val}}})
elif comparator in ["$not", "$or"]:
result.append(_normalize_filters(val, comparator))
Expand All @@ -81,7 +86,8 @@ def _parse_comparison(field: str, comparison: Union[Dict, List, str, float]) ->
elif comparator == "$and":
result.append(_normalize_filters({field: val}, comparator))
else:
raise FilterError(f"Unknown comparator '{comparator}'")
msg = f"Unknown comparator '{comparator}'"
raise FilterError(msg)
elif isinstance(comparison, list):
result.append({"terms": {field: comparison}})
elif isinstance(comparison, np.ndarray):
Expand Down Expand Up @@ -116,7 +122,7 @@ def _normalize_ranges(conditions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
]
```
"""
range_conditions = [list(c["range"].items())[0] for c in conditions if "range" in c]
range_conditions = [next(iter(c["range"].items())) for c in conditions if "range" in c]
if range_conditions:
conditions = [c for c in conditions if "range" not in c]
range_conditions_dict = {}
Expand Down
6 changes: 2 additions & 4 deletions tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# SPDX-FileCopyrightText: 2023-present Silvano Cerza <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import List

import pytest
from haystack.preview.testing.document_store import DocumentStoreBaseTests
from haystack.preview.dataclasses.document import Document
from haystack.preview.document_stores.protocols import DuplicatePolicy
from haystack.preview.document_stores.errors import DuplicateDocumentError
from haystack.preview.document_stores.protocols import DuplicatePolicy
from haystack.preview.testing.document_store import DocumentStoreBaseTests

from elasticsearch_haystack.document_store import ElasticsearchDocumentStore

Expand Down

0 comments on commit 1b6f9cc

Please sign in to comment.