Skip to content

Commit

Permalink
Added tests for knowlegebases retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins committed Apr 3, 2024
1 parent 23a712a commit 9e1d84c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 2 deletions.
11 changes: 9 additions & 2 deletions libs/aws/langchain_aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from langchain_aws.llms import SagemakerEndpoint
from langchain_aws.retrievers import AmazonKendraRetriever
from langchain_aws.retrievers import (
AmazonKendraRetriever,
AmazonKnowledgeBasesRetriever,
)

__all__ = ["SagemakerEndpoint", "AmazonKendraRetriever"]
__all__ = [
"SagemakerEndpoint",
"AmazonKendraRetriever",
"AmazonKnowledgeBasesRetriever",
]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from unittest.mock import Mock

import pytest
from langchain_core.documents import Document

from langchain_aws import AmazonKnowledgeBasesRetriever


@pytest.fixture
def mock_client() -> Mock:
return Mock()


@pytest.fixture
def retriever(mock_client: Mock) -> AmazonKnowledgeBasesRetriever:
return AmazonKnowledgeBasesRetriever(
knowledge_base_id="test-knowledge-base",
client=mock_client,
retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}}, # type: ignore[arg-type]
)


def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore[no-untyped-def]
response = {
"retrievalResults": [
{
"content": {"text": "This is the first result."},
"location": "location1",
"score": 0.9,
},
{
"content": {"text": "This is the second result."},
"location": "location2",
"score": 0.8,
},
{"content": {"text": "This is the third result."}, "location": "location3"},
]
}
mock_client.retrieve.return_value = response

query = "test query"

expected_documents = [
Document(
page_content="This is the first result.",
metadata={"location": "location1", "score": 0.9},
),
Document(
page_content="This is the second result.",
metadata={"location": "location2", "score": 0.8},
),
Document(
page_content="This is the third result.",
metadata={"location": "location3", "score": 0.0},
),
]

documents = retriever.get_relevant_documents(query)

assert documents == expected_documents

mock_client.retrieve.assert_called_once_with(
retrievalQuery={"text": "test query"},
knowledgeBaseId="test-knowledge-base",
retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": 4}},
)

0 comments on commit 9e1d84c

Please sign in to comment.