Skip to content

Commit

Permalink
CrateDB vector: Test non-deterministic values by using pytest.approx
Browse files Browse the repository at this point in the history
The test cases can be written substantially more elegant.
  • Loading branch information
amotl committed Nov 27, 2023
1 parent ce34aa6 commit ee0d1f5
Showing 1 changed file with 14 additions and 57 deletions.
71 changes: 14 additions & 57 deletions libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
import os
import re
from typing import Dict, Generator, List, Tuple
from typing import Dict, Generator, List

import pytest
import sqlalchemy as sa
Expand Down Expand Up @@ -85,20 +85,6 @@ def prune_tables(engine: sa.Engine) -> None:
pass


def decode_output(
output: List[Tuple[Document, float]]
) -> Tuple[List[Document], List[float]]:
"""
Decode a typical API result into separate `documents` and `scores`.
It is needed as utility function in some test cases to compensate
for different and/or flaky score values, when compared to the
original implementation.
"""
documents = [item[0] for item in output]
scores = [round(item[1], 1) for item in output]
return documents, scores


def ensure_collection(session: sa.orm.Session, name: str) -> None:
"""
Create a (fake) collection item.
Expand Down Expand Up @@ -241,12 +227,11 @@ def test_cratedb_with_filter_match() -> None:
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
# TODO: Original:
# assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501
assert output in [
[(Document(page_content="foo", metadata={"page": "0"}), 2.1307645)],
[(Document(page_content="foo", metadata={"page": "0"}), 2.3150668)],
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
assert output == [
(Document(page_content="foo", metadata={"page": "0"}), pytest.approx(2.2, 0.1))
]


Expand All @@ -263,20 +248,9 @@ def test_cratedb_with_filter_distant_match() -> None:
pre_delete_collection=True,
)
output = docsearch.similarity_search_with_score("foo", k=2, filter={"page": "2"})
# TODO: Original:
# output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) # noqa: E501
# assert output == [
# (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) # noqa: E501
# ]
documents, scores = decode_output(output)
assert documents == [
Document(page_content="baz", metadata={"page": "2"}),
]
assert scores in [
[1.3],
[1.5],
[1.6],
[1.7],
# Original score value: 0.0013003906671379406
assert output == [
(Document(page_content="baz", metadata={"page": "2"}), pytest.approx(1.5, 0.2))
]


Expand Down Expand Up @@ -429,19 +403,11 @@ def test_cratedb_with_filter_in_set() -> None:
output = docsearch.similarity_search_with_score(
"foo", k=2, filter={"page": {"IN": ["0", "2"]}}
)
# TODO: Original:
"""
# Original score values: 0.0, 0.0013003906671379406
assert output == [
(Document(page_content="foo", metadata={"page": "0"}), 0.0),
(Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406),
]
"""
documents, scores = decode_output(output)
assert documents == [
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="baz", metadata={"page": "2"}),
(Document(page_content="foo", metadata={"page": "0"}), pytest.approx(3.0, 0.1)),
(Document(page_content="baz", metadata={"page": "2"}), pytest.approx(2.2, 0.1)),
]
assert scores == [3.0, 2.2]


def test_cratedb_delete_docs() -> None:
Expand Down Expand Up @@ -486,21 +452,12 @@ def test_cratedb_relevance_score() -> None:
)

output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
"""
# TODO: Original code, where the `distance` is stable.
# Original score values: 1.0, 0.9996744261675065, 0.9986996093328621
assert output == [
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
(Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065),
(Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621),
]
"""
documents, scores = decode_output(output)
assert documents == [
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
Document(page_content="baz", metadata={"page": "2"}),
(Document(page_content="foo", metadata={"page": "0"}), pytest.approx(1.4, 0.1)),
(Document(page_content="bar", metadata={"page": "1"}), pytest.approx(1.1, 0.1)),
(Document(page_content="baz", metadata={"page": "2"}), pytest.approx(0.8, 0.1)),
]
assert scores == [1.4, 1.1, 0.8]


def test_cratedb_retriever_search_threshold() -> None:
Expand Down

0 comments on commit ee0d1f5

Please sign in to comment.