forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
bedrock: add unit test for retriever (langchain-ai#21485)
This was implemented in langchain-ai#21349 but dropped before merge.
- Loading branch information
1 parent
e5232ae
commit 6ae784b
Showing
1 changed file
with
68 additions
and
0 deletions.
There are no files selected for viewing
68 changes: 68 additions & 0 deletions
68
libs/community/tests/unit_tests/retrievers/test_bedrock.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from typing import List | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
from langchain_core.documents import Document | ||
|
||
from langchain_community.retrievers import AmazonKnowledgeBasesRetriever | ||
|
||
|
||
@pytest.fixture | ||
def mock_client() -> MagicMock: | ||
return MagicMock() | ||
|
||
|
||
@pytest.fixture | ||
def mock_retriever_config() -> dict: | ||
return {"vectorSearchConfiguration": {"numberOfResults": 4}} | ||
|
||
|
||
@pytest.fixture | ||
def amazon_retriever( | ||
mock_client: MagicMock, mock_retriever_config: dict | ||
) -> AmazonKnowledgeBasesRetriever: | ||
return AmazonKnowledgeBasesRetriever( | ||
knowledge_base_id="test_kb_id", | ||
retrieval_config=mock_retriever_config, | ||
client=mock_client, | ||
) | ||
|
||
|
||
def test_create_client(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None: | ||
with pytest.raises(ImportError): | ||
amazon_retriever.create_client({}) | ||
|
||
|
||
def test_get_relevant_documents( | ||
amazon_retriever: AmazonKnowledgeBasesRetriever, mock_client: MagicMock | ||
) -> None: | ||
query: str = "test query" | ||
mock_client.retrieve.return_value = { | ||
"retrievalResults": [ | ||
{"content": {"text": "result1"}, "metadata": {"key": "value1"}}, | ||
{ | ||
"content": {"text": "result2"}, | ||
"metadata": {"key": "value2"}, | ||
"score": 1, | ||
"location": "testLocation", | ||
}, | ||
{"content": {"text": "result3"}}, | ||
] | ||
} | ||
documents: List[Document] = amazon_retriever._get_relevant_documents( | ||
query, | ||
run_manager=None, # type: ignore | ||
) | ||
|
||
assert len(documents) == 3 | ||
assert isinstance(documents[0], Document) | ||
assert documents[0].page_content == "result1" | ||
assert documents[0].metadata == {"score": 0, "source_metadata": {"key": "value1"}} | ||
assert documents[1].page_content == "result2" | ||
assert documents[1].metadata == { | ||
"score": 1, | ||
"source_metadata": {"key": "value2"}, | ||
"location": "testLocation", | ||
} | ||
assert documents[2].page_content == "result3" | ||
assert documents[2].metadata == {"score": 0} |