diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 55d32837..44fd6851 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -1,3 +1,4 @@ +import json from typing import Any, Dict, List, Literal, Optional, Union import boto3 @@ -65,11 +66,10 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever): specified. If not specified, the default credential profile or, if on an EC2 instance, credentials from IMDS will be used. client: boto3 client for bedrock agent runtime. - retrieval_config: Configuration for retrieval. - + retrieval_config: Optional configuration for retrieval specified as a + Python object (RetrievalConfig) or as a dictionary. Example: .. code-block:: python - from langchain_community.retrievers import AmazonKnowledgeBasesRetriever retriever = AmazonKnowledgeBasesRetriever( @@ -87,7 +87,7 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever): credentials_profile_name: Optional[str] = None endpoint_url: Optional[str] = None client: Any - retrieval_config: RetrievalConfig + retrieval_config: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None min_score_confidence: Annotated[ Optional[float], Field(ge=0.0, le=1.0, default=None) ] @@ -136,7 +136,7 @@ def create_client(cls, values: Dict[str, Any]) -> Any: "profile name are valid." ) from e - def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]: + def __filter_by_score_confidence(self, docs: List[Document]) -> List[Document]: """ Filter out the records that have a score confidence less than the required threshold. 
@@ -159,17 +159,58 @@ def _get_relevant_documents( *, run_manager: CallbackManagerForRetrieverRun, ) -> List[Document]: - response = self.client.retrieve( - retrievalQuery={"text": query.strip()}, - knowledgeBaseId=self.knowledge_base_id, - retrievalConfiguration=self.retrieval_config.model_dump( - exclude_none=True, by_alias=True - ), - ) + """ + Get relevant documents from a knowledge base + + :param query: the user's query + :param run_manager: The callback handler to use + :return: List of relevant documents + """ + retrieve_request: Dict[str, Any] = self.__get_retrieve_request(query) + response = self.client.retrieve(**retrieve_request) results = response["retrievalResults"] + documents: List[ + Document + ] = AmazonKnowledgeBasesRetriever.__retrieval_results_to_documents(results) + + return self.__filter_by_score_confidence(docs=documents) + + def __get_retrieve_request(self, query: str) -> Dict[str, Any]: + """ + Build the keyword arguments for the Retrieve API call + + :param query: the user's query (surrounding whitespace is stripped) + :return: keyword arguments to pass to ``client.retrieve`` + """ + request: Dict[str, Any] = { + "retrievalQuery": {"text": query.strip()}, + "knowledgeBaseId": self.knowledge_base_id, + } + config = self.retrieval_config + if config: + # A dict is passed through unchanged; only a RetrievalConfig model + # has model_dump() and needs serialising to the API's aliased keys. + request["retrievalConfiguration"] = ( + config + if isinstance(config, dict) + else config.model_dump(exclude_none=True, by_alias=True) + ) + return request + + @staticmethod + def __retrieval_results_to_documents( + results: List[Dict[str, Any]], + ) -> List[Document]: + """ + Convert the Retrieve API results to LangChain Documents + + :param results: Retrieve API results list + :return: List of LangChain Documents + """ documents = [] for result in results: - content = result["content"]["text"] + content = AmazonKnowledgeBasesRetriever.__get_content_from_result(result) + result["type"] = result.get("content", {}).get("type", "TEXT") result.pop("content") if "score" not in result: result["score"] = 0 @@ -181,5 +222,33 @@ def _get_relevant_documents( metadata=result, ) ) + return documents - return self._filter_by_score_confidence(docs=documents) + @staticmethod + def 
__get_content_from_result(result: Dict[str, Any]) -> Optional[str]: + """ + Convert the content from one Retrieve API result to string + + :param result: Retrieve API search result + :return: string representation of the content attribute + """ + if not result: + raise ValueError("Invalid search result") + content: Optional[Dict[str, Any]] = result.get("content") + if not content: + raise ValueError( + "Invalid search result, content is missing from the result" + ) + if not content.get("type"): + return content.get("text") + if content["type"] == "TEXT": + return content.get("text") + elif content["type"] == "IMAGE": + return content.get("byteContent") + elif content["type"] == "ROW": + row: Optional[List[dict]] = content.get("row", []) + return json.dumps(row if row else []) + else: + # future proofing this class to prevent code breaks if new types + # are introduced + return None diff --git a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py index 98357428..cc448fca 100644 --- a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py +++ b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py @@ -1,5 +1,5 @@ # type: ignore - +from typing import Any, List from unittest.mock import MagicMock import pytest @@ -28,6 +28,16 @@ def mock_retriever_config(): ) +@pytest.fixture +def mock_retriever_config_dict(): + return { + "vectorSearchConfiguration": { + "numberOfResults": 5, + "filter": {"in": {"key": "key", "value": ["value1", "value2"]}}, + } + } + + @pytest.fixture def amazon_retriever(mock_client, mock_retriever_config): return AmazonKnowledgeBasesRetriever( @@ -37,6 +47,23 @@ def amazon_retriever(mock_client, mock_retriever_config): ) +@pytest.fixture +def amazon_retriever_no_retrieval_config(mock_client): + return AmazonKnowledgeBasesRetriever( + knowledge_base_id="test_kb_id", + client=mock_client, + ) + + +@pytest.fixture +def amazon_retriever_retrieval_config_dict(mock_client, 
mock_retriever_config_dict): + return AmazonKnowledgeBasesRetriever( + knowledge_base_id="test_kb_id", + retrieval_config=mock_retriever_config_dict, + client=mock_client, + ) + + def test_retriever_invoke(amazon_retriever, mock_client): query = "test query" mock_client.retrieve.return_value = { @@ -67,15 +94,20 @@ def test_retriever_invoke(amazon_retriever, mock_client): assert len(documents) == 3 assert isinstance(documents[0], Document) assert documents[0].page_content == "result1" - assert documents[0].metadata == {"score": 0, "source_metadata": {"key": "value1"}} + assert documents[0].metadata == { + "score": 0, + "source_metadata": {"key": "value1"}, + "type": "TEXT", + } assert documents[1].page_content == "result2" assert documents[1].metadata == { "score": 1, "source_metadata": {"key": "value2"}, "location": "testLocation", + "type": "TEXT", } assert documents[2].page_content == "result3" - assert documents[2].metadata == {"score": 0} + assert documents[2].metadata == {"score": 0, "type": "TEXT"} def test_retriever_invoke_with_score(amazon_retriever, mock_client): @@ -88,6 +120,7 @@ def test_retriever_invoke_with_score(amazon_retriever, mock_client): "metadata": {"key": "value2"}, "score": 1, "location": "testLocation", + "type": "TEXT", }, {"content": {"text": "result3"}}, ] @@ -103,4 +136,439 @@ def test_retriever_invoke_with_score(amazon_retriever, mock_client): "score": 1, "source_metadata": {"key": "value2"}, "location": "testLocation", + "type": "TEXT", + } + + +def test_retriever_retrieval_config_dict_invoke( + amazon_retriever_retrieval_config_dict, mock_client +): + documents = set_return_value_and_query( + mock_client, amazon_retriever_retrieval_config_dict + ) + validate_query_response_no_cutoff(documents) + mock_client.retrieve.assert_called_once_with( + retrievalQuery={"text": "test query"}, + knowledgeBaseId="test_kb_id", + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": 5, + # Expecting to be called with 
correct "in" operator instead of "in_" + "filter": {"in": {"key": "key", "value": ["value1", "value2"]}}, + } + }, + ) + + +def test_retriever_retrieval_config_dict_invoke_with_score( + amazon_retriever_retrieval_config_dict, mock_client +): + amazon_retriever_retrieval_config_dict.min_score_confidence = 0.6 + documents = set_return_value_and_query( + mock_client, amazon_retriever_retrieval_config_dict + ) + validate_query_response_with_cutoff(documents) + + +def test_retriever_no_retrieval_config_invoke( + amazon_retriever_no_retrieval_config, mock_client +): + documents = set_return_value_and_query( + mock_client, amazon_retriever_no_retrieval_config + ) + validate_query_response_no_cutoff(documents) + mock_client.retrieve.assert_called_once_with( + retrievalQuery={"text": "test query"}, knowledgeBaseId="test_kb_id" + ) + + +def test_retriever_no_retrieval_config_invoke_with_score( + amazon_retriever_no_retrieval_config, mock_client +): + amazon_retriever_no_retrieval_config.min_score_confidence = 0.6 + documents = set_return_value_and_query( + mock_client, amazon_retriever_no_retrieval_config + ) + validate_query_response_with_cutoff(documents) + + +@pytest.mark.parametrize( + "search_results,expected_documents", + [ + ( + [ + { + "content": {"text": "result"}, + "metadata": {"key": "value1"}, + "score": 1, + "location": "testLocation", + }, + { + "content": {"text": "result"}, + "metadata": {"key": "value1"}, + "score": 0.5, + "location": "testLocation", + }, + ], + [ + Document( + page_content="result", + metadata={ + "score": 1, + "location": "testLocation", + "source_metadata": {"key": "value1"}, + "type": "TEXT", + }, + ), + Document( + page_content="result", + metadata={ + "score": 0.5, + "location": "testLocation", + "source_metadata": {"key": "value1"}, + "type": "TEXT", + }, + ), + ], + ), + # text type + ( + [ + { + "content": {"text": "result", "type": "TEXT"}, + "metadata": {"key": "value1"}, + "score": 1, + "location": "testLocation", + }, + { + 
"content": {"text": "result", "type": "TEXT"}, + "metadata": {"key": "value1"}, + "score": 0.5, + "location": "testLocation", + }, + ], + [ + Document( + page_content="result", + metadata={ + "score": 1, + "location": "testLocation", + "source_metadata": {"key": "value1"}, + "type": "TEXT", + }, + ), + Document( + page_content="result", + metadata={ + "score": 0.5, + "location": "testLocation", + "source_metadata": {"key": "value1"}, + "type": "TEXT", + }, + ), + ], + ), + # image type + ( + [ + { + "content": {"byteContent": "bytecontent", "type": "IMAGE"}, + "metadata": {"key": "value1"}, + "score": 1, + "location": "testLocation", + }, + { + "content": {"byteContent": "bytecontent", "type": "IMAGE"}, + "metadata": {"key": "value1"}, + "score": 0.5, + "location": "testLocation", + }, + ], + [ + Document( + page_content="bytecontent", + metadata={ + "score": 1, + "location": "testLocation", + "source_metadata": {"key": "value1"}, + "type": "IMAGE", + }, + ), + Document( + page_content="bytecontent", + metadata={ + "score": 0.5, + "location": "testLocation", + "source_metadata": {"key": "value1"}, + "type": "IMAGE", + }, + ), + ], + ), + # row type + ( + [ + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "type": "ROW", + }, + "score": 1, + "metadata": {"key": "value1"}, + "location": "testLocation", + }, + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "type": "ROW", + }, + "score": 0.5, + "metadata": {"key": "value1"}, + "location": "testLocation", + }, + ], + [ + Document( + page_content='[{"columnName": "someName1", "columnValue": "someValue1"}, ' + '{"columnName": "someName2", "columnValue": "someValue2"}]', + metadata={ + "score": 1, + "location": "testLocation", + "source_metadata": {"key": "value1"}, + "type": "ROW", + }, + ), + Document( + 
page_content='[{"columnName": "someName1", "columnValue": "someValue1"}, ' + '{"columnName": "someName2", "columnValue": "someValue2"}]', + metadata={ + "score": 0.5, + "location": "testLocation", + "source_metadata": {"key": "value1"}, + "type": "ROW", + }, + ), + ], + ), + ], +) +def test_retriever_with_multi_modal_types_then_get_valid_documents( + mock_client, amazon_retriever, search_results, expected_documents +): + query = "test query" + mock_client.retrieve.return_value = {"retrievalResults": search_results} + documents = amazon_retriever.invoke(query, run_manager=None) + assert documents == expected_documents + + +@pytest.mark.parametrize( + "search_result_input,expected_output", + [ + # VALID INPUTS + # no type + ({"content": {"text": "result"}}, "result"), + # text type + ({"content": {"text": "result", "type": "TEXT"}}, "result"), + # image type + ({"content": {"byteContent": "bytecontent", "type": "IMAGE"}}, "bytecontent"), + # row type + ( + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "type": "ROW", + } + }, + '[{"columnName": "someName1", "columnValue": "someValue1"}, ' + '{"columnName": "someName2", "columnValue": "someValue2"}]', + ), + # VALID INPUTS w/ metadata + # no type + ({"content": {"text": "result"}, "metadata": {"key": "value1"}}, "result"), + # text type + ( + { + "content": {"text": "result", "type": "TEXT"}, + "metadata": {"key": "value1"}, + }, + "result", + ), + # image type + ( + { + "content": {"byteContent": "bytecontent", "type": "IMAGE"}, + "metadata": {"key": "value1"}, + }, + "bytecontent", + ), + # row type + ( + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "type": "ROW", + }, + "metadata": {"key": "value1"}, + }, + '[{"columnName": "someName1", "columnValue": "someValue1"}, ' + '{"columnName": "someName2", 
"columnValue": "someValue2"}]', + ), + # invalid type + ({"content": {"invalid": "invalid", "type": "INVALID"}}, None), + # EMPTY VALUES + # no type + ({"content": {"text": ""}}, ""), + # text type + ({"content": {"text": "", "type": "TEXT"}}, ""), + # image type + ({"content": {"byteContent": "", "type": "IMAGE"}}, ""), + # row type + ({"content": {"row": [], "type": "ROW"}}, "[]"), + # NONE VALUES + ({"content": {"text": None}}, None), + # text type + ({"content": {"text": None, "type": "TEXT"}}, None), + # image type + ({"content": {"byteContent": None, "type": "IMAGE"}}, None), + # row type + ({"content": {"row": None, "type": "ROW"}}, "[]"), + # WRONG CONTENT + # text + ({"content": {"text": "result", "type": "IMAGE"}}, None), + ({"content": {"text": "result", "type": "ROW"}}, "[]"), + # byteContent + ({"content": {"byteContent": "result"}}, None), + ({"content": {"byteContent": "result", "type": "TEXT"}}, None), + ({"content": {"byteContent": "result", "type": "ROW"}}, "[]"), + # row + ( + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ] + } + }, + None, + ), + ( + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "type": "TEXT", + } + }, + None, + ), + ( + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "type": "IMAGE", + } + }, + None, + ), + ], +) +def test_when_get_content_from_result_then_get_expected_content( + search_result_input, expected_output +): + assert ( + AmazonKnowledgeBasesRetriever._AmazonKnowledgeBasesRetriever__get_content_from_result( + search_result_input + ) + == expected_output + ) + + +@pytest.mark.parametrize( + "search_result_input", + [ + # empty content + ({"content": {}}), + # None content + ({"content": None}), + # 
empty dict + ({}), + # None search result + None, + ], +) +def test_when_get_content_from_result_with_invalid_content_then_raise_error( + search_result_input, +): + with pytest.raises(ValueError): + AmazonKnowledgeBasesRetriever._AmazonKnowledgeBasesRetriever__get_content_from_result( + search_result_input + ) + + +def set_return_value_and_query( + client: Any, retriever: AmazonKnowledgeBasesRetriever +) -> List[Document]: + query = "test query" + client.retrieve.return_value = { + "retrievalResults": [ + {"content": {"text": "result1"}, "metadata": {"key": "value1"}}, + { + "content": {"text": "result2"}, + "metadata": {"key": "value2"}, + "score": 1, + "location": "testLocation", + }, + {"content": {"text": "result3"}}, + ] + } + return retriever.invoke(query, run_manager=None) + + +def validate_query_response_no_cutoff(documents: List[Document]): + assert len(documents) == 3 + assert isinstance(documents[0], Document) + assert documents[0].page_content == "result1" + assert documents[0].metadata == { + "score": 0, + "source_metadata": {"key": "value1"}, + "type": "TEXT", + } + assert documents[1].page_content == "result2" + assert documents[1].metadata == { + "score": 1, + "source_metadata": {"key": "value2"}, + "location": "testLocation", + "type": "TEXT", + } + assert documents[2].page_content == "result3" + assert documents[2].metadata == {"score": 0, "type": "TEXT"} + + +def validate_query_response_with_cutoff(documents: List[Document]): + assert len(documents) == 1 + assert isinstance(documents[0], Document) + assert documents[0].page_content == "result2" + assert documents[0].metadata == { + "score": 1, + "source_metadata": {"key": "value2"}, + "location": "testLocation", + "type": "TEXT", }